123 lines
2.9 KiB
Go
123 lines
2.9 KiB
Go
package internal
|
|
|
|
import (
	"context"
	"net"
	"net/http"
	"regexp"
	"strings"

	"git.sr.ht/~erock/wish/cms/db"
	"go.uber.org/zap"
)
|
|
|
|
// Route pairs an HTTP method and a path regular expression with the handler
// that serves requests matching both.
//
// Keep the field order stable: Route values may be built with positional
// composite literals elsewhere in the package.
type Route struct {
	method  string         // compared verbatim against r.Method
	regex   *regexp.Regexp // fully anchored path pattern; its capture groups become the request's fields
	handler http.HandlerFunc
}

// NewRoute builds a Route for method whose path must match pattern in full.
//
// pattern is wrapped in a non-capturing group before being anchored, so a
// pattern containing a top-level alternation such as "/a|/b" matches exactly
// "/a" or "/b", rather than "any path starting with /a or ending with /b".
// The non-capturing group leaves the numbering of pattern's own capture
// groups unchanged. NewRoute panics if pattern is not a valid regular
// expression (regexp.MustCompile).
func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
	return Route{
		method:  method,
		regex:   regexp.MustCompile("^(?:" + pattern + ")$"),
		handler: handler,
	}
}
|
|
|
|
// ServeFn is the dispatcher signature returned by CreateServe. It has the
// same shape as http.HandlerFunc, so a ServeFn value can be converted to one
// with http.HandlerFunc(fn).
type ServeFn func(w http.ResponseWriter, r *http.Request)
|
|
|
|
func CreateServe(routes []Route, subdomainRoutes []Route, cfg *ConfigSite, dbpool db.DB, logger *zap.SugaredLogger) ServeFn {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
var allow []string
|
|
curRoutes := routes
|
|
subdomain := GetRequestSubdomain(r)
|
|
|
|
if cfg.IsSubdomains() && subdomain != "" {
|
|
curRoutes = subdomainRoutes
|
|
}
|
|
|
|
for _, route := range curRoutes {
|
|
matches := route.regex.FindStringSubmatch(r.URL.Path)
|
|
if len(matches) > 0 {
|
|
if r.Method != route.method {
|
|
allow = append(allow, route.method)
|
|
continue
|
|
}
|
|
loggerCtx := context.WithValue(r.Context(), ctxLoggerKey{}, logger)
|
|
subdomainCtx := context.WithValue(loggerCtx, ctxSubdomainKey{}, subdomain)
|
|
dbCtx := context.WithValue(subdomainCtx, ctxDBKey{}, dbpool)
|
|
cfgCtx := context.WithValue(dbCtx, ctxCfg{}, cfg)
|
|
ctx := context.WithValue(cfgCtx, ctxKey{}, matches[1:])
|
|
route.handler(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
}
|
|
if len(allow) > 0 {
|
|
w.Header().Set("Allow", strings.Join(allow, ", "))
|
|
http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
http.NotFound(w, r)
|
|
}
|
|
}
|
|
|
|
// Context keys used to attach per-request values in the router. Each key is
// its own unexported empty struct type, so none of them can compare equal to
// each other or to a key defined in another package.
type (
	ctxDBKey        struct{} // database handle (db.DB)
	ctxKey          struct{} // route capture groups ([]string)
	ctxLoggerKey    struct{} // *zap.SugaredLogger
	ctxSubdomainKey struct{} // request subdomain (string)
	ctxCfg          struct{} // *ConfigSite
)
|
|
|
|
func GetCfg(r *http.Request) *ConfigSite {
|
|
return r.Context().Value(ctxCfg{}).(*ConfigSite)
|
|
}
|
|
|
|
func GetLogger(r *http.Request) *zap.SugaredLogger {
|
|
return r.Context().Value(ctxLoggerKey{}).(*zap.SugaredLogger)
|
|
}
|
|
|
|
func GetDB(r *http.Request) db.DB {
|
|
return r.Context().Value(ctxDBKey{}).(db.DB)
|
|
}
|
|
|
|
func GetField(r *http.Request, index int) string {
|
|
fields := r.Context().Value(ctxKey{}).([]string)
|
|
return fields[index]
|
|
}
|
|
|
|
func GetSubdomain(r *http.Request) string {
|
|
return r.Context().Value(ctxSubdomainKey{}).(string)
|
|
}
|
|
|
|
// GetRequestSubdomain extracts the subdomain from the Host the client
// requested, or returns "" when there is none.
//
// Adapted from https://stackoverflow.com/a/66445657/1713216.
//
// Surrounding whitespace, any ":port" suffix and a trailing root dot are
// ignored. IP-literal hosts (for example "127.0.0.1:3000" or "[::1]:8080")
// never yield a subdomain. The remaining host is split on "." and classified
// by its label count:
//
//	site.com            (2 labels) -> ""
//	www.site.com        (3 labels) -> ""      ("www", any case, is not a subdomain)
//	blog.site.com       (3 labels) -> "blog"
//	www.blog.site.com   (4 labels) -> "blog"  (always the second label)
//
// Any other label count yields "". Multi-label public suffixes such as
// "co.uk" are not recognised, so "blog.site.co.uk" yields "site".
func GetRequestSubdomain(r *http.Request) string {
	host := strings.TrimSpace(r.Host)

	// Drop a ":port" suffix. SplitHostPort fails when there is no port, in
	// which case the host is used as-is.
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h
	}

	// "blog.site.com." names the same host as "blog.site.com"; without this
	// the empty final label would shift the label count.
	host = strings.TrimSuffix(host, ".")

	// Without this check, "192.168.1.10" would split into four labels and be
	// reported as subdomain "168".
	if net.ParseIP(host) != nil {
		return ""
	}

	labels := strings.Split(host, ".")
	switch len(labels) {
	case 4:
		return labels[1]
	case 3:
		// Host names are case-insensitive, so "WWW" is treated like "www".
		if strings.EqualFold(labels[0], "www") {
			return ""
		}
		return labels[0]
	default:
		return ""
	}
}
|