Eric Bower
·
2025-04-08
router.go
1package shared
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net"
8 "net/http"
9 "net/http/pprof"
10 "regexp"
11 "strings"
12
13 "github.com/hashicorp/golang-lru/v2/expirable"
14 "github.com/picosh/pico/pkg/cache"
15 "github.com/picosh/pico/pkg/db"
16 "github.com/picosh/pico/pkg/pssh"
17 "github.com/picosh/pico/pkg/shared/storage"
18)
19
// Route associates an HTTP method and a path regular expression with the
// handler that CreateServeBasic dispatches to when both match a request.
type Route struct {
	// Method is the HTTP method this route serves, e.g. "GET".
	Method string

	// Regex is matched against the request's URL path. Its capture groups
	// (minus the full match) are made available to Handler via GetField.
	Regex *regexp.Regexp

	// Handler serves requests whose path and method match this route.
	Handler http.HandlerFunc

	// CorsEnabled makes the router add CORS headers to responses and answer
	// OPTIONS requests for this path directly.
	CorsEnabled bool
}
26
27func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
28 return Route{
29 method,
30 regexp.MustCompile("^" + pattern + "$"),
31 handler,
32 false,
33 }
34}
35
36func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
37 return Route{
38 method,
39 regexp.MustCompile("^" + pattern + "$"),
40 handler,
41 true,
42 }
43}
44
45func CreatePProfRoutes(routes []Route) []Route {
46 return append(routes,
47 NewRoute("GET", "/debug/pprof/cmdline", pprof.Cmdline),
48 NewRoute("GET", "/debug/pprof/profile", pprof.Profile),
49 NewRoute("GET", "/debug/pprof/symbol", pprof.Symbol),
50 NewRoute("GET", "/debug/pprof/trace", pprof.Trace),
51 NewRoute("GET", "/debug/pprof/(.*)", pprof.Index),
52 NewRoute("POST", "/debug/pprof/cmdline", pprof.Cmdline),
53 NewRoute("POST", "/debug/pprof/profile", pprof.Profile),
54 NewRoute("POST", "/debug/pprof/symbol", pprof.Symbol),
55 NewRoute("POST", "/debug/pprof/trace", pprof.Trace),
56 NewRoute("POST", "/debug/pprof/(.*)", pprof.Index),
57 )
58}
59
60func CreatePProfRoutesMux(mux *http.ServeMux) {
61 mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
62 mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
63 mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
64 mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
65 mux.HandleFunc("GET /debug/pprof/(.*)", pprof.Index)
66 mux.HandleFunc("POST /debug/pprof/cmdline", pprof.Cmdline)
67 mux.HandleFunc("POST /debug/pprof/profile", pprof.Profile)
68 mux.HandleFunc("POST /debug/pprof/symbol", pprof.Symbol)
69 mux.HandleFunc("POST /debug/pprof/trace", pprof.Trace)
70 mux.HandleFunc("POST /debug/pprof/(.*)", pprof.Index)
71}
72
73type ApiConfig struct {
74 Cfg *ConfigSite
75 Dbpool db.DB
76 Storage storage.StorageServe
77}
78
79func (hc *ApiConfig) HasPrivilegedAccess(apiToken string) bool {
80 user, err := hc.Dbpool.FindUserForToken(apiToken)
81 if err != nil {
82 return false
83 }
84 return hc.Dbpool.HasFeatureForUser(user.ID, "auth")
85}
86
87func (hc *ApiConfig) HasPlusOrSpace(user *db.User, space string) bool {
88 return hc.Dbpool.HasFeatureForUser(user.ID, "plus") || hc.Dbpool.HasFeatureForUser(user.ID, space)
89}
90
91func (hc *ApiConfig) CreateCtx(prevCtx context.Context, subdomain string) context.Context {
92 ctx := context.WithValue(prevCtx, ctxLoggerKey{}, hc.Cfg.Logger)
93 ctx = context.WithValue(ctx, CtxSubdomainKey{}, subdomain)
94 ctx = context.WithValue(ctx, ctxDBKey{}, hc.Dbpool)
95 ctx = context.WithValue(ctx, ctxStorageKey{}, hc.Storage)
96 ctx = context.WithValue(ctx, ctxCfg{}, hc.Cfg)
97 return ctx
98}
99
100func CreateServeBasic(routes []Route, ctx context.Context) http.HandlerFunc {
101 return func(w http.ResponseWriter, r *http.Request) {
102 var allow []string
103 for _, route := range routes {
104 matches := route.Regex.FindStringSubmatch(r.URL.Path)
105 if len(matches) > 0 {
106 if r.Method == "OPTIONS" && route.CorsEnabled {
107 CorsHeaders(w.Header())
108 w.WriteHeader(http.StatusOK)
109 return
110 } else if r.Method != route.Method {
111 allow = append(allow, route.Method)
112 continue
113 }
114
115 if route.CorsEnabled {
116 CorsHeaders(w.Header())
117 }
118
119 finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
120 route.Handler(w, r.WithContext(finctx))
121 return
122 }
123 }
124 if len(allow) > 0 {
125 w.Header().Set("Allow", strings.Join(allow, ", "))
126 http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
127 return
128 }
129 http.NotFound(w, r)
130 }
131}
132
133func GetSubdomainFromRequest(r *http.Request, domain, space string) string {
134 hostDomain := strings.ToLower(strings.Split(r.Host, ":")[0])
135 appDomain := strings.ToLower(strings.Split(domain, ":")[0])
136
137 if hostDomain != appDomain {
138 if strings.Contains(hostDomain, appDomain) {
139 subdomain := strings.TrimSuffix(hostDomain, fmt.Sprintf(".%s", appDomain))
140 return subdomain
141 } else {
142 subdomain := GetCustomDomain(hostDomain, space)
143 return subdomain
144 }
145 }
146
147 return ""
148}
149
150func findRouteConfig(r *http.Request, routes []Route, subdomainRoutes []Route, cfg *ConfigSite) ([]Route, string) {
151 if len(subdomainRoutes) == 0 {
152 return routes, ""
153 }
154
155 subdomain := GetSubdomainFromRequest(r, cfg.Domain, cfg.Space)
156 if subdomain == "" {
157 return routes, subdomain
158 }
159 return subdomainRoutes, subdomain
160}
161
162func CreateServe(routes []Route, subdomainRoutes []Route, apiConfig *ApiConfig) http.HandlerFunc {
163 return func(w http.ResponseWriter, r *http.Request) {
164 curRoutes, subdomain := findRouteConfig(r, routes, subdomainRoutes, apiConfig.Cfg)
165 ctx := apiConfig.CreateCtx(r.Context(), subdomain)
166 router := CreateServeBasic(curRoutes, ctx)
167 router(w, r)
168 }
169}
170
// Context keys for dependencies stored by ApiConfig.CreateCtx. Each key is
// its own empty struct type, so it cannot collide with context keys of any
// other type.
type (
	ctxDBKey      struct{} // db.DB, read by GetDB
	ctxStorageKey struct{} // storage.StorageServe, read by GetStorage
	ctxLoggerKey  struct{} // *slog.Logger, read by GetLogger
	ctxCfg        struct{} // *ConfigSite, read by GetCfg
)

// Context keys for per-request values.
type (
	// CtxSubdomainKey holds the request's subdomain string (see GetSubdomain).
	CtxSubdomainKey struct{}
	// ctxKey holds the matched route's regex capture groups (see GetField).
	ctxKey struct{}
	// CtxSessionKey holds the *pssh.SSHServerConnSession read by GetSshCtx;
	// it is not set anywhere in this file (presumably by the SSH integration).
	CtxSessionKey struct{}
)
179
180func GetSshCtx(r *http.Request) (*pssh.SSHServerConnSession, error) {
181 payload, ok := r.Context().Value(CtxSessionKey{}).(*pssh.SSHServerConnSession)
182 if payload == nil || !ok {
183 return payload, fmt.Errorf("ssh session not set on `r.Context()` for connection")
184 }
185 return payload, nil
186}
187
188func GetCfg(r *http.Request) *ConfigSite {
189 return r.Context().Value(ctxCfg{}).(*ConfigSite)
190}
191
192func GetLogger(r *http.Request) *slog.Logger {
193 return r.Context().Value(ctxLoggerKey{}).(*slog.Logger)
194}
195
196func GetDB(r *http.Request) db.DB {
197 return r.Context().Value(ctxDBKey{}).(db.DB)
198}
199
200func GetStorage(r *http.Request) storage.StorageServe {
201 return r.Context().Value(ctxStorageKey{}).(storage.StorageServe)
202}
203
204func GetField(r *http.Request, index int) string {
205 fields := r.Context().Value(ctxKey{}).([]string)
206 if index >= len(fields) {
207 return ""
208 }
209 return fields[index]
210}
211
212func GetSubdomain(r *http.Request) string {
213 return r.Context().Value(CtxSubdomainKey{}).(string)
214}
215
216var txtCache = expirable.NewLRU[string, string](2048, nil, cache.CacheTimeout)
217
218func GetCustomDomain(host string, space string) string {
219 txt := fmt.Sprintf("_%s.%s", space, host)
220 record, found := txtCache.Get(txt)
221 if found {
222 return record
223 }
224
225 records, err := net.LookupTXT(txt)
226 if err != nil {
227 return ""
228 }
229
230 for _, v := range records {
231 rec := strings.TrimSpace(v)
232 txtCache.Add(txt, rec)
233 return rec
234 }
235
236 return ""
237}
238
// GetApiToken extracts the API token from r's Authorization header.
//
// A leading "Bearer " is stripped, matching the scheme case-insensitively as
// HTTP auth schemes are (RFC 9110 §11.1). A header without that prefix is
// returned unchanged, so clients that send a bare token keep working. A
// missing header yields "".
func GetApiToken(r *http.Request) string {
	const bearer = "Bearer "
	authHeader := r.Header.Get("authorization")
	if len(authHeader) >= len(bearer) && strings.EqualFold(authHeader[:len(bearer)], bearer) {
		return authHeader[len(bearer):]
	}
	return authHeader
}