repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pipe
Antonio Mika  ·  2025-03-12

api.go

  1package pipe
  2
  3import (
  4	"bufio"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"net/url"
 11	"os"
 12	"regexp"
 13	"strings"
 14	"sync"
 15	"time"
 16
 17	"github.com/google/uuid"
 18	"github.com/gorilla/websocket"
 19	"github.com/picosh/pico/pkg/db/postgres"
 20	"github.com/picosh/pico/pkg/shared"
 21	"github.com/picosh/utils/pipe"
 22	"github.com/prometheus/client_golang/prometheus/promhttp"
 23)
 24
 25var (
 26	cleanRegex = regexp.MustCompile(`[^0-9a-zA-Z,/]`)
 27	sshClient  *pipe.Client
 28	upgrader   = websocket.Upgrader{
 29		CheckOrigin: func(r *http.Request) bool {
 30			return true
 31		},
 32	}
 33)
 34
 35func serveFile(file string, contentType string) http.HandlerFunc {
 36	return func(w http.ResponseWriter, r *http.Request) {
 37		logger := shared.GetLogger(r)
 38		cfg := shared.GetCfg(r)
 39
 40		contents, err := os.ReadFile(cfg.StaticPath(fmt.Sprintf("public/%s", file)))
 41		if err != nil {
 42			logger.Error("could not read statis file", "err", err.Error())
 43			http.Error(w, "file not found", 404)
 44		}
 45		w.Header().Add("Content-Type", contentType)
 46
 47		_, err = w.Write(contents)
 48		if err != nil {
 49			logger.Error("could not write static file", "err", err.Error())
 50			http.Error(w, "server error", http.StatusInternalServerError)
 51		}
 52	}
 53}
 54
 55func createStaticRoutes() []shared.Route {
 56	return []shared.Route{
 57		shared.NewRoute("GET", "/main.css", serveFile("main.css", "text/css")),
 58		shared.NewRoute("GET", "/smol.css", serveFile("smol.css", "text/css")),
 59		shared.NewRoute("GET", "/syntax.css", serveFile("syntax.css", "text/css")),
 60		shared.NewRoute("GET", "/card.png", serveFile("card.png", "image/png")),
 61		shared.NewRoute("GET", "/favicon-16x16.png", serveFile("favicon-16x16.png", "image/png")),
 62		shared.NewRoute("GET", "/favicon-32x32.png", serveFile("favicon-32x32.png", "image/png")),
 63		shared.NewRoute("GET", "/apple-touch-icon.png", serveFile("apple-touch-icon.png", "image/png")),
 64		shared.NewRoute("GET", "/favicon.ico", serveFile("favicon.ico", "image/x-icon")),
 65		shared.NewRoute("GET", "/robots.txt", serveFile("robots.txt", "text/plain")),
 66		shared.NewRoute("GET", "/anim.js", serveFile("anim.js", "text/javascript")),
 67	}
 68}
 69
 70type writeFlusher struct {
 71	responseWriter http.ResponseWriter
 72	controller     *http.ResponseController
 73}
 74
 75func (w writeFlusher) Write(p []byte) (n int, err error) {
 76	n, err = w.responseWriter.Write(p)
 77	if err == nil {
 78		err = w.controller.Flush()
 79	}
 80	return
 81}
 82
 83var _ io.Writer = writeFlusher{}
 84
 85func handleSub(pubsub bool) http.HandlerFunc {
 86	return func(w http.ResponseWriter, r *http.Request) {
 87		logger := shared.GetLogger(r)
 88
 89		clientInfo := shared.NewPicoPipeClient()
 90		topic, _ := url.PathUnescape(shared.GetField(r, 0))
 91
 92		topic = cleanRegex.ReplaceAllString(topic, "")
 93
 94		logger.Info("sub", "topic", topic, "info", clientInfo, "pubsub", pubsub)
 95
 96		params := "-p"
 97		if r.URL.Query().Get("persist") == "true" {
 98			params += " -k"
 99		}
100
101		if accessList := r.URL.Query().Get("access"); accessList != "" {
102			logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
103			cleanList := cleanRegex.ReplaceAllString(accessList, "")
104			params += fmt.Sprintf(" -a=%s", cleanList)
105		}
106
107		id := uuid.NewString()
108
109		p, err := sshClient.AddSession(id, fmt.Sprintf("sub %s %s", params, topic), 0, -1, -1)
110		if err != nil {
111			logger.Error("sub error", "topic", topic, "info", clientInfo, "err", err.Error())
112			http.Error(w, "server error", http.StatusInternalServerError)
113			return
114		}
115
116		go func() {
117			<-r.Context().Done()
118			err := sshClient.RemoveSession(id)
119			if err != nil {
120				logger.Error("sub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
121			}
122		}()
123
124		if mime := r.URL.Query().Get("mime"); mime != "" {
125			w.Header().Add("Content-Type", r.URL.Query().Get("mime"))
126		}
127
128		w.WriteHeader(http.StatusOK)
129
130		_, err = io.Copy(writeFlusher{w, http.NewResponseController(w)}, p)
131		if err != nil {
132			logger.Error("sub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
133			return
134		}
135	}
136}
137
// handlePub builds the POST handler behind /topic/{name} and /pubsub/{name}
// (see createMainRoutes). It forwards the request body to a "pub" session on
// the shared sshClient:
//
//  1. Read one byte of the body. If none was read, "-e" is added to the command.
//  2. Open the session and scan its output line by line until the
//     "sending msg ..." line appears (or the output ends). A line starting
//     with "  ssh " whose last field differs from the sanitized topic aborts
//     the request with 401 "access denied". A scanner error aborts with 500.
//  3. Write the byte read in step 1, retrying every 10ms while Write reports
//     0 bytes written. Then copy the rest of the body into the session.
//  4. Reply 200 with an empty body.
//
// pubsub=true adds "-b=false" to the command. The access=LIST query parameter
// adds "-a=LIST" after stripping characters cleanRegex rejects. The session is
// removed once the request context is done.
func handlePub(pubsub bool) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		logger := shared.GetLogger(r)

		clientInfo := shared.NewPicoPipeClient()
		topic, _ := url.PathUnescape(shared.GetField(r, 0))

		topic = cleanRegex.ReplaceAllString(topic, "")

		logger.Info("pub", "topic", topic, "info", clientInfo)

		params := "-p"
		if pubsub {
			params += " -b=false"
		}

		if accessList := r.URL.Query().Get("access"); accessList != "" {
			logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
			cleanList := cleanRegex.ReplaceAllString(accessList, "")
			params += fmt.Sprintf(" -a=%s", cleanList)
		}

		var wg sync.WaitGroup

		// NOTE(review): bufio.NewReaderSize only guarantees "at least" the
		// requested size, so this reader may buffer more than one byte. That is
		// harmless because the rest of the body is later copied from this same
		// reader, not from r.Body.
		reader := bufio.NewReaderSize(r.Body, 1)

		first := make([]byte, 1)

		// Probe whether the body is empty. io.EOF is not an error here.
		// NOTE(review): a single Read may legally return (0, nil), and that
		// case is treated as an empty body below. io.ReadFull would retry;
		// confirm whether this matters for the body readers in use.
		nFirst, err := reader.Read(first)
		if err != nil && !errors.Is(err, io.EOF) {
			logger.Error("pub peek error", "topic", topic, "info", clientInfo, "err", err.Error())
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}

		// Empty body: ask the upstream for an empty message.
		// NOTE(review): the write loop below still writes `first` (one 0x00
		// byte, since it was never filled) in this case. Confirm the upstream
		// ignores session input when -e is set.
		if nFirst == 0 {
			params += " -e"
		}

		id := uuid.NewString()

		p, err := sshClient.AddSession(id, fmt.Sprintf("pub %s %s", params, topic), 0, -1, -1)
		if err != nil {
			logger.Error("pub error", "topic", topic, "info", clientInfo, "err", err.Error())
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}

		// Remove the upstream session once the request context is done
		// (client gone, or this handler returned).
		go func() {
			<-r.Context().Done()
			err := sshClient.RemoveSession(id)
			if err != nil {
				logger.Error("pub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
			}
		}()

		// Outcome of the handshake scan below. scanStatus is only used when
		// scanErr is non-nil.
		var scanErr error
		scanStatus := http.StatusInternalServerError

		// A WaitGroup around a single goroutine that is waited on immediately:
		// the scan effectively runs synchronously. Wait also makes the
		// goroutine's writes to scanErr/scanStatus visible afterwards.
		wg.Add(1)

		go func() {
			defer wg.Done()

			// Lines longer than 32 KiB make Scan stop with bufio.ErrTooLong,
			// which surfaces below as a 500.
			s := bufio.NewScanner(p)
			s.Buffer(make([]byte, 32*1024), 32*1024)

			for s.Scan() {
				// NOTE(review): presumably the upstream's signal that it is
				// ready for the payload. The 10ms sleep looks like a settling
				// delay; confirm.
				if s.Text() == "sending msg ..." {
					time.Sleep(10 * time.Millisecond)
					break
				}

				// NOTE(review): presumably the upstream echoes the topic it
				// resolved on this line. A mismatch with the requested topic is
				// reported to the client as "access denied".
				if strings.HasPrefix(s.Text(), "  ssh ") {
					f := strings.Fields(s.Text())
					if len(f) > 1 && f[len(f)-1] != topic {
						scanErr = fmt.Errorf("pub is not same as used, expected `%s` and received `%s`", topic, f[len(f)-1])
						scanStatus = http.StatusUnauthorized
						return
					}
				}
			}

			if err := s.Err(); err != nil {
				scanErr = err
				return
			}
		}()

		wg.Wait()

		if scanErr != nil {
			logger.Error("pub scan error", "topic", topic, "info", clientInfo, "err", scanErr.Error())

			msg := "server error"
			if scanStatus == http.StatusUnauthorized {
				msg = "access denied"
			}

			http.Error(w, msg, scanStatus)
			return
		}

		// Push the probed first byte, retrying every 10ms while Write reports
		// 0 bytes written. If the request context is done first, the loop is
		// abandoned and execution falls through to io.Copy below.
	outer:
		for {
			select {
			case <-r.Context().Done():
				break outer
			default:
				n, err := p.Write(first)
				if err != nil {
					logger.Error("pub write error", "topic", topic, "info", clientInfo, "err", err.Error())
					http.Error(w, "server error", http.StatusInternalServerError)
					return
				}

				if n > 0 {
					break outer
				}

				time.Sleep(10 * time.Millisecond)
			}
		}

		// Stream the remainder of the body (everything after the first byte).
		_, err = io.Copy(p, reader)
		if err != nil {
			logger.Error("pub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}

		w.WriteHeader(http.StatusOK)

		// NOTE(review): the purpose of this trailing sleep is not visible here.
		// Presumably it gives the upstream time to consume the payload before
		// the handler returns, which cancels the request context and triggers
		// RemoveSession above. Confirm.
		time.Sleep(10 * time.Millisecond)
	}
}
274
275func handlePipe() http.HandlerFunc {
276	return func(w http.ResponseWriter, r *http.Request) {
277		logger := shared.GetLogger(r)
278
279		c, err := upgrader.Upgrade(w, r, nil)
280		if err != nil {
281			logger.Error("pipe upgrade error", "err", err.Error())
282			return
283		}
284
285		defer c.Close()
286
287		clientInfo := shared.NewPicoPipeClient()
288		topic, _ := url.PathUnescape(shared.GetField(r, 0))
289
290		topic = cleanRegex.ReplaceAllString(topic, "")
291
292		logger.Info("pipe", "topic", topic, "info", clientInfo)
293
294		params := "-p -c"
295		if r.URL.Query().Get("status") == "true" {
296			params = params[:len(params)-3]
297		}
298
299		if r.URL.Query().Get("replay") == "true" {
300			params += " -r"
301		}
302
303		messageType := websocket.TextMessage
304		if r.URL.Query().Get("binary") == "true" {
305			messageType = websocket.BinaryMessage
306		}
307
308		if accessList := r.URL.Query().Get("access"); accessList != "" {
309			logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
310			cleanList := cleanRegex.ReplaceAllString(accessList, "")
311			params += fmt.Sprintf(" -a=%s", cleanList)
312		}
313
314		id := uuid.NewString()
315
316		p, err := sshClient.AddSession(id, fmt.Sprintf("pipe %s %s", params, topic), 0, -1, -1)
317		if err != nil {
318			logger.Error("pipe error", "topic", topic, "info", clientInfo, "err", err.Error())
319			http.Error(w, "server error", http.StatusInternalServerError)
320			return
321		}
322
323		go func() {
324			<-r.Context().Done()
325			err := sshClient.RemoveSession(id)
326			if err != nil {
327				logger.Error("pipe remove error", "topic", topic, "info", clientInfo, "err", err.Error())
328			}
329			c.Close()
330		}()
331
332		var wg sync.WaitGroup
333		wg.Add(2)
334
335		go func() {
336			defer func() {
337				p.Close()
338				c.Close()
339				wg.Done()
340			}()
341
342			for {
343				_, message, err := c.ReadMessage()
344				if err != nil {
345					logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error())
346					break
347				}
348
349				_, err = p.Write(message)
350				if err != nil {
351					logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
352					break
353				}
354			}
355		}()
356
357		go func() {
358			defer func() {
359				p.Close()
360				c.Close()
361				wg.Done()
362			}()
363
364			for {
365				buf := make([]byte, 32*1024)
366
367				n, err := p.Read(buf)
368				if err != nil {
369					logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error())
370					break
371				}
372
373				buf = buf[:n]
374
375				err = c.WriteMessage(messageType, buf)
376				if err != nil {
377					logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error())
378					break
379				}
380			}
381		}()
382
383		wg.Wait()
384	}
385}
386
387func createMainRoutes(staticRoutes []shared.Route) []shared.Route {
388	routes := []shared.Route{
389		shared.NewRoute("GET", "/", shared.CreatePageHandler("html/marketing.page.tmpl")),
390		shared.NewRoute("GET", "/check", shared.CheckHandler),
391		shared.NewRoute("GET", "/_metrics", promhttp.Handler().ServeHTTP),
392	}
393
394	pipeRoutes := []shared.Route{
395		shared.NewRoute("GET", "/topic/(.+)", handleSub(false)),
396		shared.NewRoute("POST", "/topic/(.+)", handlePub(false)),
397		shared.NewRoute("GET", "/pubsub/(.+)", handleSub(true)),
398		shared.NewRoute("POST", "/pubsub/(.+)", handlePub(true)),
399		shared.NewRoute("GET", "/pipe/(.+)", handlePipe()),
400	}
401
402	for _, route := range pipeRoutes {
403		route.CorsEnabled = true
404		routes = append(routes, route)
405	}
406
407	routes = append(
408		routes,
409		staticRoutes...,
410	)
411
412	return routes
413}
414
415func StartApiServer() {
416	cfg := NewConfigSite("pipe-web")
417	db := postgres.NewDB(cfg.DbURL, cfg.Logger)
418	defer db.Close()
419	logger := cfg.Logger
420
421	staticRoutes := createStaticRoutes()
422
423	if cfg.Debug {
424		staticRoutes = shared.CreatePProfRoutes(staticRoutes)
425	}
426
427	mainRoutes := createMainRoutes(staticRoutes)
428	subdomainRoutes := staticRoutes
429
430	info := shared.NewPicoPipeClient()
431
432	client, err := pipe.NewClient(context.Background(), logger.With("info", info), info)
433	if err != nil {
434		panic(err)
435	}
436
437	sshClient = client
438
439	pingSession, err := sshClient.AddSession("ping", "pub -b=false -c ping", 0, -1, -1)
440	if err != nil {
441		panic(err)
442	}
443
444	go func() {
445		for {
446			_, err := pingSession.Write([]byte(fmt.Sprintf("%s: pipe-web ping\n", time.Now().UTC().Format(time.RFC3339))))
447			if err != nil {
448				logger.Error("pipe ping error", "err", err.Error())
449			}
450
451			time.Sleep(5 * time.Second)
452		}
453	}()
454
455	apiConfig := &shared.ApiConfig{
456		Cfg:    cfg,
457		Dbpool: db,
458	}
459	handler := shared.CreateServe(mainRoutes, subdomainRoutes, apiConfig)
460	router := http.HandlerFunc(handler)
461
462	portStr := fmt.Sprintf(":%s", cfg.Port)
463	logger.Info(
464		"Starting server on port",
465		"port", cfg.Port,
466		"domain", cfg.Domain,
467	)
468
469	logger.Error("listen", "err", http.ListenAndServe(portStr, router).Error())
470}