repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pipe
Eric Bower  ·  2025-05-25

api.go

  1package pipe
  2
  3import (
  4	"bufio"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"net/url"
 11	"os"
 12	"regexp"
 13	"strings"
 14	"sync"
 15	"time"
 16
 17	"github.com/google/uuid"
 18	"github.com/gorilla/websocket"
 19	"github.com/picosh/pico/pkg/db/postgres"
 20	"github.com/picosh/pico/pkg/shared"
 21	"github.com/picosh/utils/pipe"
 22	"github.com/prometheus/client_golang/prometheus/promhttp"
 23)
 24
 25var (
 26	cleanRegex = regexp.MustCompile(`[^0-9a-zA-Z,/]`)
 27	sshClient  *pipe.Client
 28	upgrader   = websocket.Upgrader{
 29		CheckOrigin: func(r *http.Request) bool {
 30			return true
 31		},
 32	}
 33)
 34
 35func serveFile(file string, contentType string) http.HandlerFunc {
 36	return func(w http.ResponseWriter, r *http.Request) {
 37		logger := shared.GetLogger(r)
 38		cfg := shared.GetCfg(r)
 39
 40		contents, err := os.ReadFile(cfg.StaticPath(fmt.Sprintf("public/%s", file)))
 41		if err != nil {
 42			logger.Error("could not read statis file", "err", err.Error())
 43			http.Error(w, "file not found", 404)
 44		}
 45		w.Header().Add("Content-Type", contentType)
 46
 47		_, err = w.Write(contents)
 48		if err != nil {
 49			logger.Error("could not write static file", "err", err.Error())
 50			http.Error(w, "server error", http.StatusInternalServerError)
 51		}
 52	}
 53}
 54
 55func createStaticRoutes() []shared.Route {
 56	return []shared.Route{
 57		shared.NewRoute("GET", "/main.css", serveFile("main.css", "text/css")),
 58		shared.NewRoute("GET", "/smol.css", serveFile("smol.css", "text/css")),
 59		shared.NewRoute("GET", "/syntax.css", serveFile("syntax.css", "text/css")),
 60		shared.NewRoute("GET", "/card.png", serveFile("card.png", "image/png")),
 61		shared.NewRoute("GET", "/favicon-16x16.png", serveFile("favicon-16x16.png", "image/png")),
 62		shared.NewRoute("GET", "/favicon-32x32.png", serveFile("favicon-32x32.png", "image/png")),
 63		shared.NewRoute("GET", "/apple-touch-icon.png", serveFile("apple-touch-icon.png", "image/png")),
 64		shared.NewRoute("GET", "/favicon.ico", serveFile("favicon.ico", "image/x-icon")),
 65		shared.NewRoute("GET", "/robots.txt", serveFile("robots.txt", "text/plain")),
 66		shared.NewRoute("GET", "/anim.js", serveFile("anim.js", "text/javascript")),
 67	}
 68}
 69
 70type writeFlusher struct {
 71	responseWriter http.ResponseWriter
 72	controller     *http.ResponseController
 73}
 74
 75func (w writeFlusher) Write(p []byte) (n int, err error) {
 76	n, err = w.responseWriter.Write(p)
 77	if err == nil {
 78		err = w.controller.Flush()
 79	}
 80	return
 81}
 82
 83var _ io.Writer = writeFlusher{}
 84
 85func handleSub(pubsub bool) http.HandlerFunc {
 86	return func(w http.ResponseWriter, r *http.Request) {
 87		logger := shared.GetLogger(r)
 88
 89		clientInfo := shared.NewPicoPipeClient()
 90		topic, _ := url.PathUnescape(shared.GetField(r, 0))
 91
 92		topic = cleanRegex.ReplaceAllString(topic, "")
 93
 94		logger.Info("sub", "topic", topic, "info", clientInfo, "pubsub", pubsub)
 95
 96		params := "-p"
 97		if r.URL.Query().Get("persist") == "true" {
 98			params += " -k"
 99		}
100
101		if accessList := r.URL.Query().Get("access"); accessList != "" {
102			logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
103			cleanList := cleanRegex.ReplaceAllString(accessList, "")
104			params += fmt.Sprintf(" -a=%s", cleanList)
105		}
106
107		id := uuid.NewString()
108
109		p, err := sshClient.AddSession(id, fmt.Sprintf("sub %s %s", params, topic), 0, -1, -1)
110		if err != nil {
111			logger.Error("sub error", "topic", topic, "info", clientInfo, "err", err.Error())
112			http.Error(w, "server error", http.StatusInternalServerError)
113			return
114		}
115
116		go func() {
117			<-r.Context().Done()
118			err := sshClient.RemoveSession(id)
119			if err != nil {
120				logger.Error("sub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
121			}
122		}()
123
124		if mime := r.URL.Query().Get("mime"); mime != "" {
125			w.Header().Add("Content-Type", r.URL.Query().Get("mime"))
126		}
127
128		w.WriteHeader(http.StatusOK)
129
130		_, err = io.Copy(writeFlusher{w, http.NewResponseController(w)}, p)
131		if err != nil {
132			logger.Error("sub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
133			return
134		}
135	}
136}
137
// handlePub returns a handler that publishes the POST request body to the
// topic captured by the route's first path group, via a "pub" command on the
// shared SSH pipe client.
//
// Query parameters:
//   - access=<list> adds -a=<list> after stripping characters outside cleanRegex.
//
// When pubsub is true the command gets -b=false (presumably "do not block
// until a subscriber is connected" — TODO confirm against the pipe CLI).
// An empty request body adds -e to the command.
//
// Flow:
//  1. Read one byte of the body to detect an empty body.
//  2. Open the session.
//  3. Scan the session's output until the server prints "sending msg ...".
//     If an "  ssh ..." line ends with a different topic, reject with 401.
//  4. Write the peeked byte, then stream the rest of the body.
//
// The SSH session is removed once the request context is done.
func handlePub(pubsub bool) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		logger := shared.GetLogger(r)

		clientInfo := shared.NewPicoPipeClient()
		topic, _ := url.PathUnescape(shared.GetField(r, 0))

		topic = cleanRegex.ReplaceAllString(topic, "")

		logger.Info("pub", "topic", topic, "info", clientInfo)

		params := "-p"
		if pubsub {
			params += " -b=false"
		}

		if accessList := r.URL.Query().Get("access"); accessList != "" {
			logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
			cleanList := cleanRegex.ReplaceAllString(accessList, "")
			params += fmt.Sprintf(" -a=%s", cleanList)
		}

		var wg sync.WaitGroup

		// NOTE: bufio rounds buffer sizes below its internal minimum (16 bytes
		// in the current stdlib) up, so this is not really a 1-byte buffer.
		reader := bufio.NewReaderSize(r.Body, 1)

		first := make([]byte, 1)

		// Read a single byte up front so an empty body can be detected before
		// the pub command is built.
		nFirst, err := reader.Read(first)
		if err != nil && !errors.Is(err, io.EOF) {
			logger.Error("pub peek error", "topic", topic, "info", clientInfo, "err", err.Error())
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}

		// Nothing was read: ask the server to publish an empty message.
		if nFirst == 0 {
			params += " -e"
		}

		id := uuid.NewString()

		p, err := sshClient.AddSession(id, fmt.Sprintf("pub %s %s", params, topic), 0, -1, -1)
		if err != nil {
			logger.Error("pub error", "topic", topic, "info", clientInfo, "err", err.Error())
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}

		// Tear the SSH session down as soon as the HTTP request finishes.
		go func() {
			<-r.Context().Done()
			err := sshClient.RemoveSession(id)
			if err != nil {
				logger.Error("pub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
			}
		}()

		// Result of the output scan below; the status defaults to 500 and is
		// only changed to 401 on a topic mismatch.
		var scanErr error
		scanStatus := http.StatusInternalServerError

		// The scan runs in a goroutine but is waited on immediately below, so
		// it is effectively synchronous; wg.Wait also orders the writes to
		// scanErr/scanStatus before the reads that follow it.
		wg.Add(1)

		go func() {
			defer wg.Done()

			// Output lines longer than 32 KiB make Scan fail with
			// bufio.ErrTooLong, which surfaces as a 500 below.
			s := bufio.NewScanner(p)
			s.Buffer(make([]byte, 32*1024), 32*1024)

			for s.Scan() {
				// The server is ready to receive data. NOTE(review): the short
				// sleep presumably gives it a moment before we start writing.
				if s.Text() == "sending msg ..." {
					time.Sleep(10 * time.Millisecond)
					break
				}

				// An "  ssh ..." line whose last field differs from our topic
				// is treated as an access failure.
				if strings.HasPrefix(s.Text(), "  ssh ") {
					f := strings.Fields(s.Text())
					if len(f) > 1 && f[len(f)-1] != topic {
						scanErr = fmt.Errorf("pub is not same as used, expected `%s` and received `%s`", topic, f[len(f)-1])
						scanStatus = http.StatusUnauthorized
						return
					}
				}
			}

			if err := s.Err(); err != nil {
				scanErr = err
				return
			}
		}()

		wg.Wait()

		if scanErr != nil {
			logger.Error("pub scan error", "topic", topic, "info", clientInfo, "err", scanErr.Error())

			msg := "server error"
			if scanStatus == http.StatusUnauthorized {
				msg = "access denied"
			}

			http.Error(w, msg, scanStatus)
			return
		}

		// Write the peeked byte, retrying every 10ms until the session
		// accepts it or the request is cancelled.
		// NOTE(review): when the body was empty, first still holds a zero byte
		// and that byte is written too — confirm "pub -e" ignores its input.
	outer:
		for {
			select {
			case <-r.Context().Done():
				break outer
			default:
				n, err := p.Write(first)
				if err != nil {
					logger.Error("pub write error", "topic", topic, "info", clientInfo, "err", err.Error())
					http.Error(w, "server error", http.StatusInternalServerError)
					return
				}

				if n > 0 {
					break outer
				}

				time.Sleep(10 * time.Millisecond)
			}
		}

		// Stream the remainder of the body into the session.
		_, err = io.Copy(p, reader)
		if err != nil {
			logger.Error("pub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}

		w.WriteHeader(http.StatusOK)

		// NOTE(review): returning cancels the request context, which triggers
		// RemoveSession above. This pause presumably gives the last bytes time
		// to reach the session first — confirm.
		time.Sleep(10 * time.Millisecond)
	}
}
274
// handlePipe returns a handler that upgrades the request to a websocket and
// bridges it to a bidirectional "pipe" session on the topic captured by the
// route's first path group.
//
// Query parameters:
//   - status=true omits -c (the default command flags are "-p -c").
//   - replay=true adds -r.
//   - binary=true sends session output as websocket binary frames instead of
//     text frames.
//   - access=<list> adds -a=<list> after stripping characters outside cleanRegex.
//
// The handler returns once both copy directions have stopped. The SSH session
// is removed and the socket closed when the request context is done.
func handlePipe() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		logger := shared.GetLogger(r)

		c, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Upgrade has already replied to the client with an HTTP error.
			logger.Error("pipe upgrade error", "err", err.Error())
			return
		}

		defer func() {
			_ = c.Close()
		}()

		clientInfo := shared.NewPicoPipeClient()
		topic, _ := url.PathUnescape(shared.GetField(r, 0))

		topic = cleanRegex.ReplaceAllString(topic, "")

		logger.Info("pipe", "topic", topic, "info", clientInfo)

		query := r.URL.Query()

		params := "-p"
		if query.Get("status") != "true" {
			params += " -c"
		}

		if query.Get("replay") == "true" {
			params += " -r"
		}

		messageType := websocket.TextMessage
		if query.Get("binary") == "true" {
			messageType = websocket.BinaryMessage
		}

		if accessList := query.Get("access"); accessList != "" {
			logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
			cleanList := cleanRegex.ReplaceAllString(accessList, "")
			params += fmt.Sprintf(" -a=%s", cleanList)
		}

		id := uuid.NewString()

		p, err := sshClient.AddSession(id, fmt.Sprintf("pipe %s %s", params, topic), 0, -1, -1)
		if err != nil {
			logger.Error("pipe error", "topic", topic, "info", clientInfo, "err", err.Error())
			// The upgrade hijacked the connection, so http.Error can no longer
			// reach the client; report the failure with a websocket close frame.
			closeMsg := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "server error")
			_ = c.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(time.Second))
			return
		}

		// Tear everything down as soon as the HTTP request finishes.
		go func() {
			<-r.Context().Done()
			err := sshClient.RemoveSession(id)
			if err != nil {
				logger.Error("pipe remove error", "topic", topic, "info", clientInfo, "err", err.Error())
			}
			_ = c.Close()
		}()

		var wg sync.WaitGroup
		wg.Add(2)

		// websocket -> pipe session. Either direction failing closes both ends,
		// which unblocks the other goroutine.
		go func() {
			defer func() {
				_ = p.Close()
				_ = c.Close()
				wg.Done()
			}()

			for {
				_, message, err := c.ReadMessage()
				if err != nil {
					logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error())
					return
				}

				if _, err := p.Write(message); err != nil {
					logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
					return
				}
			}
		}()

		// pipe session -> websocket
		go func() {
			defer func() {
				_ = p.Close()
				_ = c.Close()
				wg.Done()
			}()

			// One buffer for the whole loop: WriteMessage has finished writing
			// the frame when it returns, so the slice can be reused.
			buf := make([]byte, 32*1024)
			for {
				n, readErr := p.Read(buf)

				// Per the io.Reader contract, forward any bytes returned
				// alongside an error before acting on the error.
				if n > 0 {
					if err := c.WriteMessage(messageType, buf[:n]); err != nil {
						logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
						return
					}
				}

				if readErr != nil {
					logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", readErr.Error())
					return
				}
			}
		}()

		wg.Wait()
	}
}
388
389func createMainRoutes(staticRoutes []shared.Route) []shared.Route {
390	routes := []shared.Route{
391		shared.NewRoute("GET", "/", shared.CreatePageHandler("html/marketing.page.tmpl")),
392		shared.NewRoute("GET", "/check", shared.CheckHandler),
393		shared.NewRoute("GET", "/_metrics", promhttp.Handler().ServeHTTP),
394	}
395
396	pipeRoutes := []shared.Route{
397		shared.NewRoute("GET", "/topic/(.+)", handleSub(false)),
398		shared.NewRoute("POST", "/topic/(.+)", handlePub(false)),
399		shared.NewRoute("GET", "/pubsub/(.+)", handleSub(true)),
400		shared.NewRoute("POST", "/pubsub/(.+)", handlePub(true)),
401		shared.NewRoute("GET", "/pipe/(.+)", handlePipe()),
402	}
403
404	for _, route := range pipeRoutes {
405		route.CorsEnabled = true
406		routes = append(routes, route)
407	}
408
409	routes = append(
410		routes,
411		staticRoutes...,
412	)
413
414	return routes
415}
416
// StartApiServer boots the pipe-web HTTP service and blocks until the
// listener fails. Its steps, in order:
//  1. Open the database.
//  2. Build the route tables, adding pprof routes in debug mode.
//  3. Connect the shared SSH pipe client.
//  4. Start a background ping publisher.
//  5. Serve HTTP on cfg.Port.
//
// It panics if the pipe client or the ping session cannot be created.
func StartApiServer() {
	cfg := NewConfigSite("pipe-web")
	db := postgres.NewDB(cfg.DbURL, cfg.Logger)
	defer func() {
		_ = db.Close()
	}()
	logger := cfg.Logger

	staticRoutes := createStaticRoutes()

	if cfg.Debug {
		staticRoutes = shared.CreatePProfRoutes(staticRoutes)
	}

	mainRoutes := createMainRoutes(staticRoutes)
	subdomainRoutes := staticRoutes

	info := shared.NewPicoPipeClient()

	client, err := pipe.NewClient(context.Background(), logger.With("info", info), info)
	if err != nil {
		panic(err)
	}

	sshClient = client

	pingSession, err := sshClient.AddSession("ping", "pub -b=false -c ping", 0, -1, -1)
	if err != nil {
		panic(err)
	}

	// Heartbeat for the life of the process: publish a timestamped line to the
	// "ping" topic every 5 seconds. Write errors are logged and retried.
	go func() {
		for {
			_, err := fmt.Fprintf(pingSession, "%s: pipe-web ping\n", time.Now().UTC().Format(time.RFC3339))
			if err != nil {
				logger.Error("pipe ping error", "err", err.Error())
			}

			time.Sleep(5 * time.Second)
		}
	}()

	apiConfig := &shared.ApiConfig{
		Cfg:    cfg,
		Dbpool: db,
	}
	handler := shared.CreateServe(mainRoutes, subdomainRoutes, apiConfig)
	router := http.HandlerFunc(handler)

	portStr := fmt.Sprintf(":%s", cfg.Port)
	logger.Info(
		"Starting server on port",
		"port", cfg.Port,
		"domain", cfg.Domain,
	)

	// Only header reads are bounded, which guards against slow-header
	// (Slowloris) clients. ReadTimeout/WriteTimeout are deliberately left
	// unset: POST bodies to /topic and /pubsub, GET subscriptions, and /pipe
	// websockets are long-lived streams that those timeouts would cut off.
	srv := &http.Server{
		Addr:              portStr,
		Handler:           router,
		ReadHeaderTimeout: 30 * time.Second,
	}

	// ListenAndServe always returns a non-nil error.
	logger.Error("listen", "err", srv.ListenAndServe().Error())
}