repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / shared
Eric Bower  ·  2025-04-08

router.go

  1package shared
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"net"
  8	"net/http"
  9	"net/http/pprof"
 10	"regexp"
 11	"strings"
 12
 13	"github.com/hashicorp/golang-lru/v2/expirable"
 14	"github.com/picosh/pico/pkg/cache"
 15	"github.com/picosh/pico/pkg/db"
 16	"github.com/picosh/pico/pkg/pssh"
 17	"github.com/picosh/pico/pkg/shared/storage"
 18)
 19
// Route pairs an HTTP method and a path regular expression with the handler
// that CreateServeBasic dispatches to when both match a request.
type Route struct {
	// Method is compared verbatim against r.Method (e.g. "GET").
	Method      string
	// Regex is matched against r.URL.Path; its capture groups are made
	// available to Handler through GetField.
	Regex       *regexp.Regexp
	Handler     http.HandlerFunc
	// CorsEnabled makes CreateServeBasic add CORS headers to responses for
	// this route and answer OPTIONS requests to its path directly.
	CorsEnabled bool
}
 26
 27func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
 28	return Route{
 29		method,
 30		regexp.MustCompile("^" + pattern + "$"),
 31		handler,
 32		false,
 33	}
 34}
 35
 36func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
 37	return Route{
 38		method,
 39		regexp.MustCompile("^" + pattern + "$"),
 40		handler,
 41		true,
 42	}
 43}
 44
 45func CreatePProfRoutes(routes []Route) []Route {
 46	return append(routes,
 47		NewRoute("GET", "/debug/pprof/cmdline", pprof.Cmdline),
 48		NewRoute("GET", "/debug/pprof/profile", pprof.Profile),
 49		NewRoute("GET", "/debug/pprof/symbol", pprof.Symbol),
 50		NewRoute("GET", "/debug/pprof/trace", pprof.Trace),
 51		NewRoute("GET", "/debug/pprof/(.*)", pprof.Index),
 52		NewRoute("POST", "/debug/pprof/cmdline", pprof.Cmdline),
 53		NewRoute("POST", "/debug/pprof/profile", pprof.Profile),
 54		NewRoute("POST", "/debug/pprof/symbol", pprof.Symbol),
 55		NewRoute("POST", "/debug/pprof/trace", pprof.Trace),
 56		NewRoute("POST", "/debug/pprof/(.*)", pprof.Index),
 57	)
 58}
 59
 60func CreatePProfRoutesMux(mux *http.ServeMux) {
 61	mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
 62	mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
 63	mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
 64	mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
 65	mux.HandleFunc("GET /debug/pprof/(.*)", pprof.Index)
 66	mux.HandleFunc("POST /debug/pprof/cmdline", pprof.Cmdline)
 67	mux.HandleFunc("POST /debug/pprof/profile", pprof.Profile)
 68	mux.HandleFunc("POST /debug/pprof/symbol", pprof.Symbol)
 69	mux.HandleFunc("POST /debug/pprof/trace", pprof.Trace)
 70	mux.HandleFunc("POST /debug/pprof/(.*)", pprof.Index)
 71}
 72
// ApiConfig bundles the site configuration, database handle and storage
// backend used by HTTP handlers. CreateCtx copies each of them onto a request
// context, and the HasPrivilegedAccess/HasPlusOrSpace helpers query Dbpool.
type ApiConfig struct {
	Cfg     *ConfigSite
	Dbpool  db.DB
	Storage storage.StorageServe
}
 78
 79func (hc *ApiConfig) HasPrivilegedAccess(apiToken string) bool {
 80	user, err := hc.Dbpool.FindUserForToken(apiToken)
 81	if err != nil {
 82		return false
 83	}
 84	return hc.Dbpool.HasFeatureForUser(user.ID, "auth")
 85}
 86
 87func (hc *ApiConfig) HasPlusOrSpace(user *db.User, space string) bool {
 88	return hc.Dbpool.HasFeatureForUser(user.ID, "plus") || hc.Dbpool.HasFeatureForUser(user.ID, space)
 89}
 90
 91func (hc *ApiConfig) CreateCtx(prevCtx context.Context, subdomain string) context.Context {
 92	ctx := context.WithValue(prevCtx, ctxLoggerKey{}, hc.Cfg.Logger)
 93	ctx = context.WithValue(ctx, CtxSubdomainKey{}, subdomain)
 94	ctx = context.WithValue(ctx, ctxDBKey{}, hc.Dbpool)
 95	ctx = context.WithValue(ctx, ctxStorageKey{}, hc.Storage)
 96	ctx = context.WithValue(ctx, ctxCfg{}, hc.Cfg)
 97	return ctx
 98}
 99
100func CreateServeBasic(routes []Route, ctx context.Context) http.HandlerFunc {
101	return func(w http.ResponseWriter, r *http.Request) {
102		var allow []string
103		for _, route := range routes {
104			matches := route.Regex.FindStringSubmatch(r.URL.Path)
105			if len(matches) > 0 {
106				if r.Method == "OPTIONS" && route.CorsEnabled {
107					CorsHeaders(w.Header())
108					w.WriteHeader(http.StatusOK)
109					return
110				} else if r.Method != route.Method {
111					allow = append(allow, route.Method)
112					continue
113				}
114
115				if route.CorsEnabled {
116					CorsHeaders(w.Header())
117				}
118
119				finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
120				route.Handler(w, r.WithContext(finctx))
121				return
122			}
123		}
124		if len(allow) > 0 {
125			w.Header().Set("Allow", strings.Join(allow, ", "))
126			http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
127			return
128		}
129		http.NotFound(w, r)
130	}
131}
132
133func GetSubdomainFromRequest(r *http.Request, domain, space string) string {
134	hostDomain := strings.ToLower(strings.Split(r.Host, ":")[0])
135	appDomain := strings.ToLower(strings.Split(domain, ":")[0])
136
137	if hostDomain != appDomain {
138		if strings.Contains(hostDomain, appDomain) {
139			subdomain := strings.TrimSuffix(hostDomain, fmt.Sprintf(".%s", appDomain))
140			return subdomain
141		} else {
142			subdomain := GetCustomDomain(hostDomain, space)
143			return subdomain
144		}
145	}
146
147	return ""
148}
149
150func findRouteConfig(r *http.Request, routes []Route, subdomainRoutes []Route, cfg *ConfigSite) ([]Route, string) {
151	if len(subdomainRoutes) == 0 {
152		return routes, ""
153	}
154
155	subdomain := GetSubdomainFromRequest(r, cfg.Domain, cfg.Space)
156	if subdomain == "" {
157		return routes, subdomain
158	}
159	return subdomainRoutes, subdomain
160}
161
162func CreateServe(routes []Route, subdomainRoutes []Route, apiConfig *ApiConfig) http.HandlerFunc {
163	return func(w http.ResponseWriter, r *http.Request) {
164		curRoutes, subdomain := findRouteConfig(r, routes, subdomainRoutes, apiConfig.Cfg)
165		ctx := apiConfig.CreateCtx(r.Context(), subdomain)
166		router := CreateServeBasic(curRoutes, ctx)
167		router(w, r)
168	}
169}
170
// Unexported request-context keys. Each is a distinct empty struct type, so
// values stored under them cannot collide with keys from other packages.
type (
	ctxDBKey      struct{} // database handle (CreateCtx / GetDB)
	ctxStorageKey struct{} // storage backend (CreateCtx / GetStorage)
	ctxLoggerKey  struct{} // logger (CreateCtx / GetLogger)
	ctxCfg        struct{} // site config (CreateCtx / GetCfg)
	ctxKey        struct{} // route capture groups (CreateServeBasic / GetField)
)

// Exported request-context keys.
type (
	// CtxSubdomainKey holds the subdomain string stored by CreateCtx.
	CtxSubdomainKey struct{}
	// CtxSessionKey holds the SSH session read by GetSshCtx; it is stored by
	// code outside this section.
	CtxSessionKey struct{}
)
179
180func GetSshCtx(r *http.Request) (*pssh.SSHServerConnSession, error) {
181	payload, ok := r.Context().Value(CtxSessionKey{}).(*pssh.SSHServerConnSession)
182	if payload == nil || !ok {
183		return payload, fmt.Errorf("ssh session not set on `r.Context()` for connection")
184	}
185	return payload, nil
186}
187
188func GetCfg(r *http.Request) *ConfigSite {
189	return r.Context().Value(ctxCfg{}).(*ConfigSite)
190}
191
192func GetLogger(r *http.Request) *slog.Logger {
193	return r.Context().Value(ctxLoggerKey{}).(*slog.Logger)
194}
195
196func GetDB(r *http.Request) db.DB {
197	return r.Context().Value(ctxDBKey{}).(db.DB)
198}
199
200func GetStorage(r *http.Request) storage.StorageServe {
201	return r.Context().Value(ctxStorageKey{}).(storage.StorageServe)
202}
203
204func GetField(r *http.Request, index int) string {
205	fields := r.Context().Value(ctxKey{}).([]string)
206	if index >= len(fields) {
207		return ""
208	}
209	return fields[index]
210}
211
212func GetSubdomain(r *http.Request) string {
213	return r.Context().Value(CtxSubdomainKey{}).(string)
214}
215
216var txtCache = expirable.NewLRU[string, string](2048, nil, cache.CacheTimeout)
217
218func GetCustomDomain(host string, space string) string {
219	txt := fmt.Sprintf("_%s.%s", space, host)
220	record, found := txtCache.Get(txt)
221	if found {
222		return record
223	}
224
225	records, err := net.LookupTXT(txt)
226	if err != nil {
227		return ""
228	}
229
230	for _, v := range records {
231		rec := strings.TrimSpace(v)
232		txtCache.Add(txt, rec)
233		return rec
234	}
235
236	return ""
237}
238
// GetApiToken extracts the API token from r's Authorization header.
//
// A leading "Bearer " scheme is stripped and the remaining token is trimmed of
// surrounding whitespace. The scheme is matched case-insensitively because HTTP
// authentication scheme names are case-insensitive (RFC 7235 §2.1). A header
// without that scheme is returned unchanged, and a missing header yields "".
func GetApiToken(r *http.Request) string {
	authHeader := r.Header.Get("authorization")
	if authHeader == "" {
		return ""
	}

	const scheme = "Bearer "
	if len(authHeader) >= len(scheme) && strings.EqualFold(authHeader[:len(scheme)], scheme) {
		return strings.TrimSpace(authHeader[len(scheme):])
	}
	return authHeader
}