Eric Bower
·
2025-05-25
api.go
1package pipe
2
3import (
4 "bufio"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "os"
12 "regexp"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/google/uuid"
18 "github.com/gorilla/websocket"
19 "github.com/picosh/pico/pkg/db/postgres"
20 "github.com/picosh/pico/pkg/shared"
21 "github.com/picosh/utils/pipe"
22 "github.com/prometheus/client_golang/prometheus/promhttp"
23)
24
25var (
26 cleanRegex = regexp.MustCompile(`[^0-9a-zA-Z,/]`)
27 sshClient *pipe.Client
28 upgrader = websocket.Upgrader{
29 CheckOrigin: func(r *http.Request) bool {
30 return true
31 },
32 }
33)
34
35func serveFile(file string, contentType string) http.HandlerFunc {
36 return func(w http.ResponseWriter, r *http.Request) {
37 logger := shared.GetLogger(r)
38 cfg := shared.GetCfg(r)
39
40 contents, err := os.ReadFile(cfg.StaticPath(fmt.Sprintf("public/%s", file)))
41 if err != nil {
42 logger.Error("could not read statis file", "err", err.Error())
43 http.Error(w, "file not found", 404)
44 }
45 w.Header().Add("Content-Type", contentType)
46
47 _, err = w.Write(contents)
48 if err != nil {
49 logger.Error("could not write static file", "err", err.Error())
50 http.Error(w, "server error", http.StatusInternalServerError)
51 }
52 }
53}
54
55func createStaticRoutes() []shared.Route {
56 return []shared.Route{
57 shared.NewRoute("GET", "/main.css", serveFile("main.css", "text/css")),
58 shared.NewRoute("GET", "/smol.css", serveFile("smol.css", "text/css")),
59 shared.NewRoute("GET", "/syntax.css", serveFile("syntax.css", "text/css")),
60 shared.NewRoute("GET", "/card.png", serveFile("card.png", "image/png")),
61 shared.NewRoute("GET", "/favicon-16x16.png", serveFile("favicon-16x16.png", "image/png")),
62 shared.NewRoute("GET", "/favicon-32x32.png", serveFile("favicon-32x32.png", "image/png")),
63 shared.NewRoute("GET", "/apple-touch-icon.png", serveFile("apple-touch-icon.png", "image/png")),
64 shared.NewRoute("GET", "/favicon.ico", serveFile("favicon.ico", "image/x-icon")),
65 shared.NewRoute("GET", "/robots.txt", serveFile("robots.txt", "text/plain")),
66 shared.NewRoute("GET", "/anim.js", serveFile("anim.js", "text/javascript")),
67 }
68}
69
// writeFlusher is an io.Writer that flushes the underlying
// http.ResponseWriter after every successful write, so streamed data reaches
// the client immediately instead of sitting in a buffer.
type writeFlusher struct {
	responseWriter http.ResponseWriter
	controller     *http.ResponseController
}

// Write forwards p to the response writer and, only if that succeeded,
// flushes it. The byte count always comes from the underlying write; the
// error is the write error if there was one, otherwise the flush error.
func (wf writeFlusher) Write(p []byte) (int, error) {
	written, err := wf.responseWriter.Write(p)
	if err != nil {
		return written, err
	}
	return written, wf.controller.Flush()
}

// Compile-time check that writeFlusher satisfies io.Writer.
var _ io.Writer = writeFlusher{}
84
85func handleSub(pubsub bool) http.HandlerFunc {
86 return func(w http.ResponseWriter, r *http.Request) {
87 logger := shared.GetLogger(r)
88
89 clientInfo := shared.NewPicoPipeClient()
90 topic, _ := url.PathUnescape(shared.GetField(r, 0))
91
92 topic = cleanRegex.ReplaceAllString(topic, "")
93
94 logger.Info("sub", "topic", topic, "info", clientInfo, "pubsub", pubsub)
95
96 params := "-p"
97 if r.URL.Query().Get("persist") == "true" {
98 params += " -k"
99 }
100
101 if accessList := r.URL.Query().Get("access"); accessList != "" {
102 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
103 cleanList := cleanRegex.ReplaceAllString(accessList, "")
104 params += fmt.Sprintf(" -a=%s", cleanList)
105 }
106
107 id := uuid.NewString()
108
109 p, err := sshClient.AddSession(id, fmt.Sprintf("sub %s %s", params, topic), 0, -1, -1)
110 if err != nil {
111 logger.Error("sub error", "topic", topic, "info", clientInfo, "err", err.Error())
112 http.Error(w, "server error", http.StatusInternalServerError)
113 return
114 }
115
116 go func() {
117 <-r.Context().Done()
118 err := sshClient.RemoveSession(id)
119 if err != nil {
120 logger.Error("sub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
121 }
122 }()
123
124 if mime := r.URL.Query().Get("mime"); mime != "" {
125 w.Header().Add("Content-Type", r.URL.Query().Get("mime"))
126 }
127
128 w.WriteHeader(http.StatusOK)
129
130 _, err = io.Copy(writeFlusher{w, http.NewResponseController(w)}, p)
131 if err != nil {
132 logger.Error("sub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
133 return
134 }
135 }
136}
137
138func handlePub(pubsub bool) http.HandlerFunc {
139 return func(w http.ResponseWriter, r *http.Request) {
140 logger := shared.GetLogger(r)
141
142 clientInfo := shared.NewPicoPipeClient()
143 topic, _ := url.PathUnescape(shared.GetField(r, 0))
144
145 topic = cleanRegex.ReplaceAllString(topic, "")
146
147 logger.Info("pub", "topic", topic, "info", clientInfo)
148
149 params := "-p"
150 if pubsub {
151 params += " -b=false"
152 }
153
154 if accessList := r.URL.Query().Get("access"); accessList != "" {
155 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
156 cleanList := cleanRegex.ReplaceAllString(accessList, "")
157 params += fmt.Sprintf(" -a=%s", cleanList)
158 }
159
160 var wg sync.WaitGroup
161
162 reader := bufio.NewReaderSize(r.Body, 1)
163
164 first := make([]byte, 1)
165
166 nFirst, err := reader.Read(first)
167 if err != nil && !errors.Is(err, io.EOF) {
168 logger.Error("pub peek error", "topic", topic, "info", clientInfo, "err", err.Error())
169 http.Error(w, "server error", http.StatusInternalServerError)
170 return
171 }
172
173 if nFirst == 0 {
174 params += " -e"
175 }
176
177 id := uuid.NewString()
178
179 p, err := sshClient.AddSession(id, fmt.Sprintf("pub %s %s", params, topic), 0, -1, -1)
180 if err != nil {
181 logger.Error("pub error", "topic", topic, "info", clientInfo, "err", err.Error())
182 http.Error(w, "server error", http.StatusInternalServerError)
183 return
184 }
185
186 go func() {
187 <-r.Context().Done()
188 err := sshClient.RemoveSession(id)
189 if err != nil {
190 logger.Error("pub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
191 }
192 }()
193
194 var scanErr error
195 scanStatus := http.StatusInternalServerError
196
197 wg.Add(1)
198
199 go func() {
200 defer wg.Done()
201
202 s := bufio.NewScanner(p)
203 s.Buffer(make([]byte, 32*1024), 32*1024)
204
205 for s.Scan() {
206 if s.Text() == "sending msg ..." {
207 time.Sleep(10 * time.Millisecond)
208 break
209 }
210
211 if strings.HasPrefix(s.Text(), " ssh ") {
212 f := strings.Fields(s.Text())
213 if len(f) > 1 && f[len(f)-1] != topic {
214 scanErr = fmt.Errorf("pub is not same as used, expected `%s` and received `%s`", topic, f[len(f)-1])
215 scanStatus = http.StatusUnauthorized
216 return
217 }
218 }
219 }
220
221 if err := s.Err(); err != nil {
222 scanErr = err
223 return
224 }
225 }()
226
227 wg.Wait()
228
229 if scanErr != nil {
230 logger.Error("pub scan error", "topic", topic, "info", clientInfo, "err", scanErr.Error())
231
232 msg := "server error"
233 if scanStatus == http.StatusUnauthorized {
234 msg = "access denied"
235 }
236
237 http.Error(w, msg, scanStatus)
238 return
239 }
240
241 outer:
242 for {
243 select {
244 case <-r.Context().Done():
245 break outer
246 default:
247 n, err := p.Write(first)
248 if err != nil {
249 logger.Error("pub write error", "topic", topic, "info", clientInfo, "err", err.Error())
250 http.Error(w, "server error", http.StatusInternalServerError)
251 return
252 }
253
254 if n > 0 {
255 break outer
256 }
257
258 time.Sleep(10 * time.Millisecond)
259 }
260 }
261
262 _, err = io.Copy(p, reader)
263 if err != nil {
264 logger.Error("pub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
265 http.Error(w, "server error", http.StatusInternalServerError)
266 return
267 }
268
269 w.WriteHeader(http.StatusOK)
270
271 time.Sleep(10 * time.Millisecond)
272 }
273}
274
275func handlePipe() http.HandlerFunc {
276 return func(w http.ResponseWriter, r *http.Request) {
277 logger := shared.GetLogger(r)
278
279 c, err := upgrader.Upgrade(w, r, nil)
280 if err != nil {
281 logger.Error("pipe upgrade error", "err", err.Error())
282 return
283 }
284
285 defer func() {
286 _ = c.Close()
287 }()
288
289 clientInfo := shared.NewPicoPipeClient()
290 topic, _ := url.PathUnescape(shared.GetField(r, 0))
291
292 topic = cleanRegex.ReplaceAllString(topic, "")
293
294 logger.Info("pipe", "topic", topic, "info", clientInfo)
295
296 params := "-p -c"
297 if r.URL.Query().Get("status") == "true" {
298 params = params[:len(params)-3]
299 }
300
301 if r.URL.Query().Get("replay") == "true" {
302 params += " -r"
303 }
304
305 messageType := websocket.TextMessage
306 if r.URL.Query().Get("binary") == "true" {
307 messageType = websocket.BinaryMessage
308 }
309
310 if accessList := r.URL.Query().Get("access"); accessList != "" {
311 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
312 cleanList := cleanRegex.ReplaceAllString(accessList, "")
313 params += fmt.Sprintf(" -a=%s", cleanList)
314 }
315
316 id := uuid.NewString()
317
318 p, err := sshClient.AddSession(id, fmt.Sprintf("pipe %s %s", params, topic), 0, -1, -1)
319 if err != nil {
320 logger.Error("pipe error", "topic", topic, "info", clientInfo, "err", err.Error())
321 http.Error(w, "server error", http.StatusInternalServerError)
322 return
323 }
324
325 go func() {
326 <-r.Context().Done()
327 err := sshClient.RemoveSession(id)
328 if err != nil {
329 logger.Error("pipe remove error", "topic", topic, "info", clientInfo, "err", err.Error())
330 }
331 _ = c.Close()
332 }()
333
334 var wg sync.WaitGroup
335 wg.Add(2)
336
337 go func() {
338 defer func() {
339 _ = p.Close()
340 _ = c.Close()
341 wg.Done()
342 }()
343
344 for {
345 _, message, err := c.ReadMessage()
346 if err != nil {
347 logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error())
348 break
349 }
350
351 _, err = p.Write(message)
352 if err != nil {
353 logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
354 break
355 }
356 }
357 }()
358
359 go func() {
360 defer func() {
361 _ = p.Close()
362 _ = c.Close()
363 wg.Done()
364 }()
365
366 for {
367 buf := make([]byte, 32*1024)
368
369 n, err := p.Read(buf)
370 if err != nil {
371 logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error())
372 break
373 }
374
375 buf = buf[:n]
376
377 err = c.WriteMessage(messageType, buf)
378 if err != nil {
379 logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
380 break
381 }
382 }
383 }()
384
385 wg.Wait()
386 }
387}
388
389func createMainRoutes(staticRoutes []shared.Route) []shared.Route {
390 routes := []shared.Route{
391 shared.NewRoute("GET", "/", shared.CreatePageHandler("html/marketing.page.tmpl")),
392 shared.NewRoute("GET", "/check", shared.CheckHandler),
393 shared.NewRoute("GET", "/_metrics", promhttp.Handler().ServeHTTP),
394 }
395
396 pipeRoutes := []shared.Route{
397 shared.NewRoute("GET", "/topic/(.+)", handleSub(false)),
398 shared.NewRoute("POST", "/topic/(.+)", handlePub(false)),
399 shared.NewRoute("GET", "/pubsub/(.+)", handleSub(true)),
400 shared.NewRoute("POST", "/pubsub/(.+)", handlePub(true)),
401 shared.NewRoute("GET", "/pipe/(.+)", handlePipe()),
402 }
403
404 for _, route := range pipeRoutes {
405 route.CorsEnabled = true
406 routes = append(routes, route)
407 }
408
409 routes = append(
410 routes,
411 staticRoutes...,
412 )
413
414 return routes
415}
416
417func StartApiServer() {
418 cfg := NewConfigSite("pipe-web")
419 db := postgres.NewDB(cfg.DbURL, cfg.Logger)
420 defer func() {
421 _ = db.Close()
422 }()
423 logger := cfg.Logger
424
425 staticRoutes := createStaticRoutes()
426
427 if cfg.Debug {
428 staticRoutes = shared.CreatePProfRoutes(staticRoutes)
429 }
430
431 mainRoutes := createMainRoutes(staticRoutes)
432 subdomainRoutes := staticRoutes
433
434 info := shared.NewPicoPipeClient()
435
436 client, err := pipe.NewClient(context.Background(), logger.With("info", info), info)
437 if err != nil {
438 panic(err)
439 }
440
441 sshClient = client
442
443 pingSession, err := sshClient.AddSession("ping", "pub -b=false -c ping", 0, -1, -1)
444 if err != nil {
445 panic(err)
446 }
447
448 go func() {
449 for {
450 _, err := fmt.Fprintf(pingSession, "%s: pipe-web ping\n", time.Now().UTC().Format(time.RFC3339))
451 if err != nil {
452 logger.Error("pipe ping error", "err", err.Error())
453 }
454
455 time.Sleep(5 * time.Second)
456 }
457 }()
458
459 apiConfig := &shared.ApiConfig{
460 Cfg: cfg,
461 Dbpool: db,
462 }
463 handler := shared.CreateServe(mainRoutes, subdomainRoutes, apiConfig)
464 router := http.HandlerFunc(handler)
465
466 portStr := fmt.Sprintf(":%s", cfg.Port)
467 logger.Info(
468 "Starting server on port",
469 "port", cfg.Port,
470 "domain", cfg.Domain,
471 )
472
473 logger.Error("listen", "err", http.ListenAndServe(portStr, router).Error())
474}