Antonio Mika
·
2025-03-12
api.go
1package pipe
2
3import (
4 "bufio"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "os"
12 "regexp"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/google/uuid"
18 "github.com/gorilla/websocket"
19 "github.com/picosh/pico/pkg/db/postgres"
20 "github.com/picosh/pico/pkg/shared"
21 "github.com/picosh/utils/pipe"
22 "github.com/prometheus/client_golang/prometheus/promhttp"
23)
24
25var (
26 cleanRegex = regexp.MustCompile(`[^0-9a-zA-Z,/]`)
27 sshClient *pipe.Client
28 upgrader = websocket.Upgrader{
29 CheckOrigin: func(r *http.Request) bool {
30 return true
31 },
32 }
33)
34
35func serveFile(file string, contentType string) http.HandlerFunc {
36 return func(w http.ResponseWriter, r *http.Request) {
37 logger := shared.GetLogger(r)
38 cfg := shared.GetCfg(r)
39
40 contents, err := os.ReadFile(cfg.StaticPath(fmt.Sprintf("public/%s", file)))
41 if err != nil {
42 logger.Error("could not read statis file", "err", err.Error())
43 http.Error(w, "file not found", 404)
44 }
45 w.Header().Add("Content-Type", contentType)
46
47 _, err = w.Write(contents)
48 if err != nil {
49 logger.Error("could not write static file", "err", err.Error())
50 http.Error(w, "server error", http.StatusInternalServerError)
51 }
52 }
53}
54
55func createStaticRoutes() []shared.Route {
56 return []shared.Route{
57 shared.NewRoute("GET", "/main.css", serveFile("main.css", "text/css")),
58 shared.NewRoute("GET", "/smol.css", serveFile("smol.css", "text/css")),
59 shared.NewRoute("GET", "/syntax.css", serveFile("syntax.css", "text/css")),
60 shared.NewRoute("GET", "/card.png", serveFile("card.png", "image/png")),
61 shared.NewRoute("GET", "/favicon-16x16.png", serveFile("favicon-16x16.png", "image/png")),
62 shared.NewRoute("GET", "/favicon-32x32.png", serveFile("favicon-32x32.png", "image/png")),
63 shared.NewRoute("GET", "/apple-touch-icon.png", serveFile("apple-touch-icon.png", "image/png")),
64 shared.NewRoute("GET", "/favicon.ico", serveFile("favicon.ico", "image/x-icon")),
65 shared.NewRoute("GET", "/robots.txt", serveFile("robots.txt", "text/plain")),
66 shared.NewRoute("GET", "/anim.js", serveFile("anim.js", "text/javascript")),
67 }
68}
69
// writeFlusher is an io.Writer over an http.ResponseWriter that flushes the
// response after every successful Write, so streamed data reaches the client
// immediately instead of sitting in net/http's buffer.
type writeFlusher struct {
	responseWriter http.ResponseWriter
	controller     *http.ResponseController
}

// Write forwards data to the response writer. If that succeeds, it flushes
// through the controller. It returns the bytes written, plus either the write
// error or the flush error.
func (wf writeFlusher) Write(data []byte) (int, error) {
	written, err := wf.responseWriter.Write(data)
	if err != nil {
		return written, err
	}
	return written, wf.controller.Flush()
}

// Compile-time proof that writeFlusher satisfies io.Writer.
var _ io.Writer = writeFlusher{}
84
85func handleSub(pubsub bool) http.HandlerFunc {
86 return func(w http.ResponseWriter, r *http.Request) {
87 logger := shared.GetLogger(r)
88
89 clientInfo := shared.NewPicoPipeClient()
90 topic, _ := url.PathUnescape(shared.GetField(r, 0))
91
92 topic = cleanRegex.ReplaceAllString(topic, "")
93
94 logger.Info("sub", "topic", topic, "info", clientInfo, "pubsub", pubsub)
95
96 params := "-p"
97 if r.URL.Query().Get("persist") == "true" {
98 params += " -k"
99 }
100
101 if accessList := r.URL.Query().Get("access"); accessList != "" {
102 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
103 cleanList := cleanRegex.ReplaceAllString(accessList, "")
104 params += fmt.Sprintf(" -a=%s", cleanList)
105 }
106
107 id := uuid.NewString()
108
109 p, err := sshClient.AddSession(id, fmt.Sprintf("sub %s %s", params, topic), 0, -1, -1)
110 if err != nil {
111 logger.Error("sub error", "topic", topic, "info", clientInfo, "err", err.Error())
112 http.Error(w, "server error", http.StatusInternalServerError)
113 return
114 }
115
116 go func() {
117 <-r.Context().Done()
118 err := sshClient.RemoveSession(id)
119 if err != nil {
120 logger.Error("sub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
121 }
122 }()
123
124 if mime := r.URL.Query().Get("mime"); mime != "" {
125 w.Header().Add("Content-Type", r.URL.Query().Get("mime"))
126 }
127
128 w.WriteHeader(http.StatusOK)
129
130 _, err = io.Copy(writeFlusher{w, http.NewResponseController(w)}, p)
131 if err != nil {
132 logger.Error("sub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
133 return
134 }
135 }
136}
137
138func handlePub(pubsub bool) http.HandlerFunc {
139 return func(w http.ResponseWriter, r *http.Request) {
140 logger := shared.GetLogger(r)
141
142 clientInfo := shared.NewPicoPipeClient()
143 topic, _ := url.PathUnescape(shared.GetField(r, 0))
144
145 topic = cleanRegex.ReplaceAllString(topic, "")
146
147 logger.Info("pub", "topic", topic, "info", clientInfo)
148
149 params := "-p"
150 if pubsub {
151 params += " -b=false"
152 }
153
154 if accessList := r.URL.Query().Get("access"); accessList != "" {
155 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
156 cleanList := cleanRegex.ReplaceAllString(accessList, "")
157 params += fmt.Sprintf(" -a=%s", cleanList)
158 }
159
160 var wg sync.WaitGroup
161
162 reader := bufio.NewReaderSize(r.Body, 1)
163
164 first := make([]byte, 1)
165
166 nFirst, err := reader.Read(first)
167 if err != nil && !errors.Is(err, io.EOF) {
168 logger.Error("pub peek error", "topic", topic, "info", clientInfo, "err", err.Error())
169 http.Error(w, "server error", http.StatusInternalServerError)
170 return
171 }
172
173 if nFirst == 0 {
174 params += " -e"
175 }
176
177 id := uuid.NewString()
178
179 p, err := sshClient.AddSession(id, fmt.Sprintf("pub %s %s", params, topic), 0, -1, -1)
180 if err != nil {
181 logger.Error("pub error", "topic", topic, "info", clientInfo, "err", err.Error())
182 http.Error(w, "server error", http.StatusInternalServerError)
183 return
184 }
185
186 go func() {
187 <-r.Context().Done()
188 err := sshClient.RemoveSession(id)
189 if err != nil {
190 logger.Error("pub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
191 }
192 }()
193
194 var scanErr error
195 scanStatus := http.StatusInternalServerError
196
197 wg.Add(1)
198
199 go func() {
200 defer wg.Done()
201
202 s := bufio.NewScanner(p)
203 s.Buffer(make([]byte, 32*1024), 32*1024)
204
205 for s.Scan() {
206 if s.Text() == "sending msg ..." {
207 time.Sleep(10 * time.Millisecond)
208 break
209 }
210
211 if strings.HasPrefix(s.Text(), " ssh ") {
212 f := strings.Fields(s.Text())
213 if len(f) > 1 && f[len(f)-1] != topic {
214 scanErr = fmt.Errorf("pub is not same as used, expected `%s` and received `%s`", topic, f[len(f)-1])
215 scanStatus = http.StatusUnauthorized
216 return
217 }
218 }
219 }
220
221 if err := s.Err(); err != nil {
222 scanErr = err
223 return
224 }
225 }()
226
227 wg.Wait()
228
229 if scanErr != nil {
230 logger.Error("pub scan error", "topic", topic, "info", clientInfo, "err", scanErr.Error())
231
232 msg := "server error"
233 if scanStatus == http.StatusUnauthorized {
234 msg = "access denied"
235 }
236
237 http.Error(w, msg, scanStatus)
238 return
239 }
240
241 outer:
242 for {
243 select {
244 case <-r.Context().Done():
245 break outer
246 default:
247 n, err := p.Write(first)
248 if err != nil {
249 logger.Error("pub write error", "topic", topic, "info", clientInfo, "err", err.Error())
250 http.Error(w, "server error", http.StatusInternalServerError)
251 return
252 }
253
254 if n > 0 {
255 break outer
256 }
257
258 time.Sleep(10 * time.Millisecond)
259 }
260 }
261
262 _, err = io.Copy(p, reader)
263 if err != nil {
264 logger.Error("pub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
265 http.Error(w, "server error", http.StatusInternalServerError)
266 return
267 }
268
269 w.WriteHeader(http.StatusOK)
270
271 time.Sleep(10 * time.Millisecond)
272 }
273}
274
275func handlePipe() http.HandlerFunc {
276 return func(w http.ResponseWriter, r *http.Request) {
277 logger := shared.GetLogger(r)
278
279 c, err := upgrader.Upgrade(w, r, nil)
280 if err != nil {
281 logger.Error("pipe upgrade error", "err", err.Error())
282 return
283 }
284
285 defer c.Close()
286
287 clientInfo := shared.NewPicoPipeClient()
288 topic, _ := url.PathUnescape(shared.GetField(r, 0))
289
290 topic = cleanRegex.ReplaceAllString(topic, "")
291
292 logger.Info("pipe", "topic", topic, "info", clientInfo)
293
294 params := "-p -c"
295 if r.URL.Query().Get("status") == "true" {
296 params = params[:len(params)-3]
297 }
298
299 if r.URL.Query().Get("replay") == "true" {
300 params += " -r"
301 }
302
303 messageType := websocket.TextMessage
304 if r.URL.Query().Get("binary") == "true" {
305 messageType = websocket.BinaryMessage
306 }
307
308 if accessList := r.URL.Query().Get("access"); accessList != "" {
309 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
310 cleanList := cleanRegex.ReplaceAllString(accessList, "")
311 params += fmt.Sprintf(" -a=%s", cleanList)
312 }
313
314 id := uuid.NewString()
315
316 p, err := sshClient.AddSession(id, fmt.Sprintf("pipe %s %s", params, topic), 0, -1, -1)
317 if err != nil {
318 logger.Error("pipe error", "topic", topic, "info", clientInfo, "err", err.Error())
319 http.Error(w, "server error", http.StatusInternalServerError)
320 return
321 }
322
323 go func() {
324 <-r.Context().Done()
325 err := sshClient.RemoveSession(id)
326 if err != nil {
327 logger.Error("pipe remove error", "topic", topic, "info", clientInfo, "err", err.Error())
328 }
329 c.Close()
330 }()
331
332 var wg sync.WaitGroup
333 wg.Add(2)
334
335 go func() {
336 defer func() {
337 p.Close()
338 c.Close()
339 wg.Done()
340 }()
341
342 for {
343 _, message, err := c.ReadMessage()
344 if err != nil {
345 logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error())
346 break
347 }
348
349 _, err = p.Write(message)
350 if err != nil {
351 logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
352 break
353 }
354 }
355 }()
356
357 go func() {
358 defer func() {
359 p.Close()
360 c.Close()
361 wg.Done()
362 }()
363
364 for {
365 buf := make([]byte, 32*1024)
366
367 n, err := p.Read(buf)
368 if err != nil {
369 logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error())
370 break
371 }
372
373 buf = buf[:n]
374
375 err = c.WriteMessage(messageType, buf)
376 if err != nil {
377 logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
378 break
379 }
380 }
381 }()
382
383 wg.Wait()
384 }
385}
386
387func createMainRoutes(staticRoutes []shared.Route) []shared.Route {
388 routes := []shared.Route{
389 shared.NewRoute("GET", "/", shared.CreatePageHandler("html/marketing.page.tmpl")),
390 shared.NewRoute("GET", "/check", shared.CheckHandler),
391 shared.NewRoute("GET", "/_metrics", promhttp.Handler().ServeHTTP),
392 }
393
394 pipeRoutes := []shared.Route{
395 shared.NewRoute("GET", "/topic/(.+)", handleSub(false)),
396 shared.NewRoute("POST", "/topic/(.+)", handlePub(false)),
397 shared.NewRoute("GET", "/pubsub/(.+)", handleSub(true)),
398 shared.NewRoute("POST", "/pubsub/(.+)", handlePub(true)),
399 shared.NewRoute("GET", "/pipe/(.+)", handlePipe()),
400 }
401
402 for _, route := range pipeRoutes {
403 route.CorsEnabled = true
404 routes = append(routes, route)
405 }
406
407 routes = append(
408 routes,
409 staticRoutes...,
410 )
411
412 return routes
413}
414
415func StartApiServer() {
416 cfg := NewConfigSite("pipe-web")
417 db := postgres.NewDB(cfg.DbURL, cfg.Logger)
418 defer db.Close()
419 logger := cfg.Logger
420
421 staticRoutes := createStaticRoutes()
422
423 if cfg.Debug {
424 staticRoutes = shared.CreatePProfRoutes(staticRoutes)
425 }
426
427 mainRoutes := createMainRoutes(staticRoutes)
428 subdomainRoutes := staticRoutes
429
430 info := shared.NewPicoPipeClient()
431
432 client, err := pipe.NewClient(context.Background(), logger.With("info", info), info)
433 if err != nil {
434 panic(err)
435 }
436
437 sshClient = client
438
439 pingSession, err := sshClient.AddSession("ping", "pub -b=false -c ping", 0, -1, -1)
440 if err != nil {
441 panic(err)
442 }
443
444 go func() {
445 for {
446 _, err := pingSession.Write([]byte(fmt.Sprintf("%s: pipe-web ping\n", time.Now().UTC().Format(time.RFC3339))))
447 if err != nil {
448 logger.Error("pipe ping error", "err", err.Error())
449 }
450
451 time.Sleep(5 * time.Second)
452 }
453 }()
454
455 apiConfig := &shared.ApiConfig{
456 Cfg: cfg,
457 Dbpool: db,
458 }
459 handler := shared.CreateServe(mainRoutes, subdomainRoutes, apiConfig)
460 router := http.HandlerFunc(handler)
461
462 portStr := fmt.Sprintf(":%s", cfg.Port)
463 logger.Info(
464 "Starting server on port",
465 "port", cfg.Port,
466 "domain", cfg.Domain,
467 )
468
469 logger.Error("listen", "err", http.ListenAndServe(portStr, router).Error())
470}