pastes.sh/internal/router.go
2022-07-13 13:30:27 -04:00

123 lines
2.9 KiB
Go

package internal
import (
	"context"
	"net"
	"net/http"
	"regexp"
	"strings"

	"git.sr.ht/~erock/wish/cms/db"
	"go.uber.org/zap"
)
// Route binds an HTTP method and a compiled path pattern to the handler that
// serves requests matching both.
type Route struct {
	// method is the HTTP verb compared against the request method.
	method string
	// regex is the fully anchored pattern matched against the URL path.
	regex *regexp.Regexp
	// handler is invoked when both method and regex match.
	handler http.HandlerFunc
}

// NewRoute compiles pattern into a Route for the given method and handler.
//
// The pattern is anchored so that it must match the entire URL path. It is
// wrapped in a non-capturing group before anchoring so that a top-level
// alternation such as "/a|/b" is anchored as a whole rather than becoming
// "^/a" or "/b$"; the group does not capture, so submatch indices are the same
// as in the caller's pattern.
//
// NewRoute panics if pattern is not a valid regular expression.
func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
	return Route{
		method:  method,
		regex:   regexp.MustCompile("^(?:" + pattern + ")$"),
		handler: handler,
	}
}
// ServeFn is the request-serving function returned by CreateServe; it has the
// same signature as the function underlying http.HandlerFunc.
type ServeFn func(http.ResponseWriter, *http.Request)
// CreateServe returns a ServeFn that dispatches each request to the first
// route whose pattern matches the URL path and whose method equals the
// request method.
//
// subdomainRoutes is used instead of routes when cfg.IsSubdomains() is true
// and the request host carries a subdomain. Before the handler runs, the
// logger, subdomain, dbpool, cfg and the pattern's capture groups are stored
// on the request context, where the Get* helpers in this file read them.
//
// If a pattern matched but no route had the right method, the response is 405
// with an Allow header listing the methods of the matching routes; if nothing
// matched at all, the response is 404.
func CreateServe(routes []Route, subdomainRoutes []Route, cfg *ConfigSite, dbpool db.DB, logger *zap.SugaredLogger) ServeFn {
	return func(w http.ResponseWriter, r *http.Request) {
		subdomain := GetRequestSubdomain(r)

		active := routes
		if cfg.IsSubdomains() && subdomain != "" {
			active = subdomainRoutes
		}

		// Methods of routes that matched the path but not the verb.
		var allowed []string
		for _, rt := range active {
			groups := rt.regex.FindStringSubmatch(r.URL.Path)
			if len(groups) == 0 {
				continue
			}
			if r.Method != rt.method {
				allowed = append(allowed, rt.method)
				continue
			}

			ctx := r.Context()
			ctx = context.WithValue(ctx, ctxLoggerKey{}, logger)
			ctx = context.WithValue(ctx, ctxSubdomainKey{}, subdomain)
			ctx = context.WithValue(ctx, ctxDBKey{}, dbpool)
			ctx = context.WithValue(ctx, ctxCfg{}, cfg)
			// groups[0] is the whole match; handlers only see the capture groups.
			ctx = context.WithValue(ctx, ctxKey{}, groups[1:])

			rt.handler(w, r.WithContext(ctx))
			return
		}

		if len(allowed) == 0 {
			http.NotFound(w, r)
			return
		}

		w.Header().Set("Allow", strings.Join(allowed, ", "))
		http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
	}
}
// Context keys under which CreateServe stores per-request values. Each is a
// distinct unexported empty struct type, so the keys cannot collide with one
// another or with keys defined by other packages.
type (
	// ctxDBKey keys the database handle.
	ctxDBKey struct{}
	// ctxKey keys the route pattern's capture groups.
	ctxKey struct{}
	// ctxLoggerKey keys the logger.
	ctxLoggerKey struct{}
	// ctxSubdomainKey keys the request's subdomain.
	ctxSubdomainKey struct{}
	// ctxCfg keys the site configuration.
	ctxCfg struct{}
)
// GetCfg returns the site configuration that CreateServe stored on the
// request context. It panics if the context holds no *ConfigSite under that
// key.
func GetCfg(r *http.Request) *ConfigSite {
	stored := r.Context().Value(ctxCfg{})
	return stored.(*ConfigSite)
}
// GetLogger returns the logger that CreateServe stored on the request
// context. It panics if the context holds no *zap.SugaredLogger under that
// key.
func GetLogger(r *http.Request) *zap.SugaredLogger {
	stored := r.Context().Value(ctxLoggerKey{})
	return stored.(*zap.SugaredLogger)
}
// GetDB returns the database handle that CreateServe stored on the request
// context. It panics if the context holds no db.DB under that key.
func GetDB(r *http.Request) db.DB {
	stored := r.Context().Value(ctxDBKey{})
	return stored.(db.DB)
}
// GetField returns the capture group at the given zero-based index from the
// route pattern that matched this request (index 0 is the first capture
// group, not the whole match). It panics if the context holds no capture
// groups or if index is out of range.
func GetField(r *http.Request, index int) string {
	captures := r.Context().Value(ctxKey{}).([]string)
	value := captures[index]
	return value
}
// GetSubdomain returns the subdomain that CreateServe stored on the request
// context, which is the empty string when the request host had none. It
// panics if the context holds no string under that key.
func GetSubdomain(r *http.Request) string {
	stored := r.Context().Value(ctxSubdomainKey{})
	return stored.(string)
}
// GetRequestSubdomain extracts the subdomain from the request's Host header,
// returning the empty string when there is none.
//
// The host is classified by its number of dot-separated labels:
//
//	site.com           -> ""      (2 labels)
//	www.site.com       -> ""      (3 labels, "www" is not a subdomain)
//	user.site.com      -> "user"  (3 labels)
//	www.hello.site.com -> "hello" (4 labels, the second label is used)
//
// Any other label count yields "". A ":port" suffix is ignored, and IP
// literals (for example 127.0.0.1) never yield a subdomain even though an
// IPv4 address has four dot-separated parts.
//
// Adapted from https://stackoverflow.com/a/66445657/1713216
func GetRequestSubdomain(r *http.Request) string {
	host := strings.TrimSpace(r.Host)

	// r.Host may be "host:port"; SplitHostPort fails when no port is present,
	// in which case the host is already bare.
	if bare, _, err := net.SplitHostPort(host); err == nil {
		host = bare
	}

	// Without this check "127.0.0.1" would be read as four labels and report
	// "0" as its subdomain.
	if net.ParseIP(host) != nil {
		return ""
	}

	labels := strings.Split(host, ".")
	switch len(labels) {
	case 4:
		return labels[1]
	case 3:
		if labels[0] == "www" {
			return ""
		}
		return labels[0]
	default:
		return ""
	}
}