repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pipe
Antonio Mika  ·  2025-03-15

cli.go

  1package pipe
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"slices"
 11	"strings"
 12	"text/tabwriter"
 13	"time"
 14
 15	"github.com/antoniomika/syncmap"
 16	"github.com/google/uuid"
 17	"github.com/picosh/pico/pkg/db"
 18	"github.com/picosh/pico/pkg/pssh"
 19	"github.com/picosh/pico/pkg/shared"
 20	psub "github.com/picosh/pubsub"
 21	gossh "golang.org/x/crypto/ssh"
 22)
 23
 24func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
 25	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 26	cmd.SetOutput(sesh)
 27	cmd.Usage = func() {
 28		fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
 29		cmd.PrintDefaults()
 30	}
 31	return cmd
 32}
 33
 34func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 35	err := cmd.Parse(cmdArgs)
 36
 37	if err != nil || posArg == "help" {
 38		if posArg == "help" {
 39			cmd.Usage()
 40		}
 41		return false
 42	}
 43	return true
 44}
 45
 46func NewTabWriter(out io.Writer) *tabwriter.Writer {
 47	return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
 48}
 49
 50// scope topic to user by prefixing name.
 51func toTopic(userName, topic string) string {
 52	return fmt.Sprintf("%s/%s", userName, topic)
 53}
 54
 55func toPublicTopic(topic string) string {
 56	return fmt.Sprintf("public/%s", topic)
 57}
 58
 59func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
 60	if len(clients) == 0 {
 61		return ""
 62	}
 63
 64	outputData := fmt.Sprintf("    %s:\r\n", clientType)
 65
 66	for _, client := range clients {
 67		if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
 68			continue
 69		}
 70
 71		outputData += fmt.Sprintf("    - %s\r\n", client.ID)
 72	}
 73
 74	return outputData
 75}
 76
 77var helpStr = func(sshCmd string) string {
 78	data := fmt.Sprintf(`Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]
 79
 80The simplest authenticated pubsub system.  Send messages through
 81user-defined topics.  Topics are private to the authenticated
 82ssh user.  The default pubsub model is multicast with bidirectional
 83blocking, meaning a publisher ("pub") will send its message to all
 84subscribers ("sub").  Further, both "pub" and "sub" will wait for
 85at least one event to be sent or received. Pipe ("pipe") allows
 86for bidirectional messages to be sent between any clients connected
 87to a pipe.
 88
 89Think of these different commands in terms of the direction the
 90data is being sent:
 91
 92- pub => writes to client
 93- sub => reads from client
 94- pipe => read and write between clients
 95`, sshCmd)
 96
 97	data = strings.ReplaceAll(data, "\n", "\r\n")
 98
 99	return data
100}
101
102type CliHandler struct {
103	DBPool  db.DB
104	Logger  *slog.Logger
105	PubSub  psub.PubSub
106	Cfg     *shared.ConfigSite
107	Waiters *syncmap.Map[string, []string]
108	Access  *syncmap.Map[string, []string]
109}
110
111func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
112	return h.Logger
113}
114
115func toSshCmd(cfg *shared.ConfigSite) string {
116	port := ""
117	if cfg.PortOverride != "22" {
118		port = fmt.Sprintf("-p %s ", cfg.PortOverride)
119	}
120	return fmt.Sprintf("%s%s", port, cfg.Domain)
121}
122
// parseArgList splits a comma separated argument into its entries and trims
// surrounding whitespace from each one. Empty entries are kept, so an empty
// argument yields a single empty string.
func parseArgList(arg string) []string {
	fields := strings.Split(arg, ",")

	trimmed := make([]string, 0, len(fields))
	for _, field := range fields {
		trimmed = append(trimmed, strings.TrimSpace(field))
	}

	return trimmed
}
131
132// checkAccess checks if the user has access to a topic based on an access list.
133func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
134	for _, acc := range accessList {
135		if acc == userName {
136			return true
137		}
138
139		if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
140			return true
141		}
142	}
143
144	return false
145}
146
147func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
148	pubsub := handler.PubSub
149
150	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
151		return func(sesh *pssh.SSHServerConnSession) error {
152			ctx := sesh.Context()
153			logger := pssh.GetLogger(sesh)
154			user := pssh.GetUser(sesh)
155
156			args := sesh.Command()
157
158			if len(args) == 0 {
159				fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
160				return next(sesh)
161			}
162
163			userName := "public"
164
165			userNameAddition := ""
166
167			isAdmin := false
168			impersonate := false
169			if user != nil {
170				isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
171				if isAdmin && strings.HasPrefix(sesh.User(), "admin__") {
172					impersonate = true
173				}
174
175				userName = user.Name
176				if user.PublicKey != nil && user.PublicKey.Name != "" {
177					userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
178				}
179			}
180
181			pipeCtx, cancel := context.WithCancel(ctx)
182
183			go func() {
184				defer cancel()
185
186				for {
187					select {
188					case <-pipeCtx.Done():
189						return
190					default:
191						_, err := sesh.SendRequest("ping@pico.sh", false, nil)
192						if err != nil {
193							logger.Error("error sending ping", "err", err)
194							return
195						}
196
197						time.Sleep(5 * time.Second)
198					}
199				}
200			}()
201
202			cmd := strings.TrimSpace(args[0])
203			if cmd == "help" {
204				fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
205				return next(sesh)
206			} else if cmd == "ls" {
207				if userName == "public" {
208					err := fmt.Errorf("access denied")
209					sesh.Fatal(err)
210					return err
211				}
212
213				topicFilter := fmt.Sprintf("%s/", userName)
214				if isAdmin {
215					topicFilter = ""
216					if len(args) > 1 {
217						topicFilter = args[1]
218					}
219				}
220
221				var channels []*psub.Channel
222				waitingChannels := map[string][]string{}
223
224				for topic, channel := range pubsub.GetChannels() {
225					if strings.HasPrefix(topic, topicFilter) {
226						channels = append(channels, channel)
227					}
228				}
229
230				for channel, clients := range handler.Waiters.Range {
231					if strings.HasPrefix(channel, topicFilter) {
232						waitingChannels[channel] = clients
233					}
234				}
235
236				if len(channels) == 0 && len(waitingChannels) == 0 {
237					fmt.Fprintln(sesh, "no pubsub channels found")
238				} else {
239					var outputData string
240					if len(channels) > 0 || len(waitingChannels) > 0 {
241						outputData += "Channel Information\r\n"
242						for _, channel := range channels {
243							extraData := ""
244
245							if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
246								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
247							}
248
249							outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
250							outputData += "  Clients:\r\n"
251
252							var pubs []*psub.Client
253							var subs []*psub.Client
254							var pipes []*psub.Client
255
256							for _, client := range channel.GetClients() {
257								if client.Direction == psub.ChannelDirectionInput {
258									pubs = append(pubs, client)
259								} else if client.Direction == psub.ChannelDirectionOutput {
260									subs = append(subs, client)
261								} else if client.Direction == psub.ChannelDirectionInputOutput {
262									pipes = append(pipes, client)
263								}
264							}
265							outputData += clientInfo(pubs, isAdmin, "Pubs")
266							outputData += clientInfo(subs, isAdmin, "Subs")
267							outputData += clientInfo(pipes, isAdmin, "Pipes")
268						}
269
270						for waitingChannel, channelPubs := range waitingChannels {
271							extraData := ""
272
273							if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
274								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
275							}
276
277							outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
278							outputData += "  Clients:\r\n"
279							outputData += fmt.Sprintf("    %s:\r\n", "Waiting Pubs")
280							for _, client := range channelPubs {
281								if strings.HasPrefix(client, "admin-") && !isAdmin {
282									continue
283								}
284								outputData += fmt.Sprintf("    - %s\r\n", client)
285							}
286						}
287					}
288
289					_, _ = sesh.Write([]byte(outputData))
290				}
291
292				return next(sesh)
293			}
294
295			topic := ""
296			cmdArgs := args[1:]
297			if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
298				topic = strings.TrimSpace(args[1])
299				cmdArgs = args[2:]
300			}
301
302			logger.Info(
303				"pubsub middleware detected command",
304				"args", args,
305				"cmd", cmd,
306				"topic", topic,
307				"cmdArgs", cmdArgs,
308			)
309
310			uuidStr := uuid.NewString()
311			if impersonate {
312				uuidStr = fmt.Sprintf("admin-%s", uuidStr)
313			}
314
315			clientID := fmt.Sprintf("%s (%s%s@%s)", uuidStr, userName, userNameAddition, sesh.RemoteAddr().String())
316
317			var err error
318
319			if cmd == "pub" {
320				pubCmd := flagSet("pub", sesh)
321				access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
322				empty := pubCmd.Bool("e", false, "Send an empty message to subs")
323				public := pubCmd.Bool("p", false, "Publish message to public topic")
324				block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
325				timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
326				clean := pubCmd.Bool("c", false, "Don't send status messages")
327
328				if !flagCheck(pubCmd, topic, cmdArgs) {
329					return err
330				}
331
332				if pubCmd.NArg() == 1 && topic == "" {
333					topic = pubCmd.Arg(0)
334				}
335
336				logger.Info(
337					"flags parsed",
338					"cmd", cmd,
339					"empty", *empty,
340					"public", *public,
341					"block", *block,
342					"timeout", *timeout,
343					"topic", topic,
344					"access", *access,
345					"clean", *clean,
346				)
347
348				var accessList []string
349
350				if *access != "" {
351					accessList = parseArgList(*access)
352				}
353
354				var rw io.ReadWriter
355				if *empty {
356					rw = bytes.NewBuffer(make([]byte, 1))
357				} else {
358					rw = sesh
359				}
360
361				if topic == "" {
362					topic = uuid.NewString()
363				}
364
365				var withoutUser string
366				var name string
367				msgFlag := ""
368
369				if isAdmin && strings.HasPrefix(topic, "/") {
370					name = strings.TrimPrefix(topic, "/")
371				} else {
372					name = toTopic(userName, topic)
373					if *public {
374						name = toPublicTopic(topic)
375						msgFlag = "-p "
376						withoutUser = name
377					} else {
378						withoutUser = topic
379					}
380				}
381
382				var accessListCreator bool
383
384				_, loaded := handler.Access.LoadOrStore(name, accessList)
385				if !loaded {
386					defer func() {
387						handler.Access.Delete(name)
388					}()
389
390					accessListCreator = true
391				}
392
393				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
394					if checkAccess(accessList, userName, sesh) || accessListCreator {
395						name = withoutUser
396					} else if !*public {
397						name = toTopic(userName, withoutUser)
398					} else {
399						topic = uuid.NewString()
400						name = toPublicTopic(topic)
401					}
402				}
403
404				if !*clean {
405					fmt.Fprintf(
406						sesh,
407						"subscribe to this channel:\n  ssh %s sub %s%s\n",
408						toSshCmd(handler.Cfg),
409						msgFlag,
410						topic,
411					)
412				}
413
414				if *block {
415					count := 0
416					for topic, channel := range pubsub.GetChannels() {
417						if topic == name {
418							for _, client := range channel.GetClients() {
419								if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
420									count++
421								}
422							}
423							break
424						}
425					}
426
427					tt := *timeout
428					if count == 0 {
429						currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
430						handler.Waiters.Store(name, append(currentWaiters, clientID))
431
432						termMsg := "no subs found ... waiting"
433						if tt > 0 {
434							termMsg += " " + tt.String()
435						}
436
437						if !*clean {
438							fmt.Fprintln(sesh, termMsg)
439						}
440
441						ready := make(chan struct{})
442
443						go func() {
444							for {
445								select {
446								case <-pipeCtx.Done():
447									cancel()
448									return
449								case <-time.After(1 * time.Millisecond):
450									count := 0
451									for topic, channel := range pubsub.GetChannels() {
452										if topic == name {
453											for _, client := range channel.GetClients() {
454												if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
455													count++
456												}
457											}
458											break
459										}
460									}
461
462									if count > 0 {
463										close(ready)
464										return
465									}
466								}
467							}
468						}()
469
470						select {
471						case <-ready:
472						case <-pipeCtx.Done():
473						case <-time.After(tt):
474							cancel()
475
476							if !*clean {
477								sesh.Fatal(fmt.Errorf("timeout reached, exiting"))
478							} else {
479								err = sesh.Exit(1)
480								if err != nil {
481									logger.Error("error exiting session", "err", err)
482								}
483
484								sesh.Close()
485							}
486						}
487
488						newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
489						newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
490							return cl == clientID
491						})
492						handler.Waiters.Store(name, newWaiters)
493
494						var toDelete []string
495
496						for channel, clients := range handler.Waiters.Range {
497							if len(clients) == 0 {
498								toDelete = append(toDelete, channel)
499							}
500						}
501
502						for _, channel := range toDelete {
503							handler.Waiters.Delete(channel)
504						}
505					}
506				}
507
508				if !*clean {
509					fmt.Fprintln(sesh, "sending msg ...")
510				}
511
512				err = pubsub.Pub(
513					pipeCtx,
514					clientID,
515					rw,
516					[]*psub.Channel{
517						psub.NewChannel(name),
518					},
519					*block,
520				)
521
522				if !*clean {
523					fmt.Fprintln(sesh, "msg sent!")
524				}
525
526				if err != nil && !*clean {
527					fmt.Fprintln(sesh.Stderr(), err)
528				}
529			} else if cmd == "sub" {
530				subCmd := flagSet("sub", sesh)
531				access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
532				public := subCmd.Bool("p", false, "Subscribe to a public topic")
533				keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
534				clean := subCmd.Bool("c", false, "Don't send status messages")
535
536				if !flagCheck(subCmd, topic, cmdArgs) {
537					return err
538				}
539
540				if subCmd.NArg() == 1 && topic == "" {
541					topic = subCmd.Arg(0)
542				}
543
544				logger.Info(
545					"flags parsed",
546					"cmd", cmd,
547					"public", *public,
548					"keepAlive", *keepAlive,
549					"topic", topic,
550					"clean", *clean,
551					"access", *access,
552				)
553
554				var accessList []string
555
556				if *access != "" {
557					accessList = parseArgList(*access)
558				}
559
560				var withoutUser string
561				var name string
562
563				if isAdmin && strings.HasPrefix(topic, "/") {
564					name = strings.TrimPrefix(topic, "/")
565				} else {
566					name = toTopic(userName, topic)
567					if *public {
568						name = toPublicTopic(topic)
569						withoutUser = name
570					} else {
571						withoutUser = topic
572					}
573				}
574
575				var accessListCreator bool
576
577				_, loaded := handler.Access.LoadOrStore(name, accessList)
578				if !loaded {
579					defer func() {
580						handler.Access.Delete(name)
581					}()
582
583					accessListCreator = true
584				}
585
586				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
587					if checkAccess(accessList, userName, sesh) || accessListCreator {
588						name = withoutUser
589					} else if !*public {
590						name = toTopic(userName, withoutUser)
591					} else {
592						fmt.Fprintln(sesh.Stderr(), "access denied")
593						return err
594					}
595				}
596
597				err = pubsub.Sub(
598					pipeCtx,
599					clientID,
600					sesh,
601					[]*psub.Channel{
602						psub.NewChannel(name),
603					},
604					*keepAlive,
605				)
606
607				if err != nil && !*clean {
608					fmt.Fprintln(sesh.Stderr(), err)
609				}
610			} else if cmd == "pipe" {
611				pipeCmd := flagSet("pipe", sesh)
612				access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
613				public := pipeCmd.Bool("p", false, "Pipe to a public topic")
614				replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
615				clean := pipeCmd.Bool("c", false, "Don't send status messages")
616
617				if !flagCheck(pipeCmd, topic, cmdArgs) {
618					return err
619				}
620
621				if pipeCmd.NArg() == 1 && topic == "" {
622					topic = pipeCmd.Arg(0)
623				}
624
625				logger.Info(
626					"flags parsed",
627					"cmd", cmd,
628					"public", *public,
629					"replay", *replay,
630					"topic", topic,
631					"access", *access,
632					"clean", *clean,
633				)
634
635				var accessList []string
636
637				if *access != "" {
638					accessList = parseArgList(*access)
639				}
640
641				isCreator := topic == ""
642				if isCreator {
643					topic = uuid.NewString()
644				}
645
646				var withoutUser string
647				var name string
648				flagMsg := ""
649
650				if isAdmin && strings.HasPrefix(topic, "/") {
651					name = strings.TrimPrefix(topic, "/")
652				} else {
653					name = toTopic(userName, topic)
654					if *public {
655						name = toPublicTopic(topic)
656						flagMsg = "-p "
657						withoutUser = name
658					} else {
659						withoutUser = topic
660					}
661				}
662
663				var accessListCreator bool
664
665				_, loaded := handler.Access.LoadOrStore(name, accessList)
666				if !loaded {
667					defer func() {
668						handler.Access.Delete(name)
669					}()
670
671					accessListCreator = true
672				}
673
674				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
675					if checkAccess(accessList, userName, sesh) || accessListCreator {
676						name = withoutUser
677					} else if !*public {
678						name = toTopic(userName, withoutUser)
679					} else {
680						topic = uuid.NewString()
681						name = toPublicTopic(topic)
682					}
683				}
684
685				if isCreator && !*clean {
686					fmt.Fprintf(
687						sesh,
688						"subscribe to this topic:\n  ssh %s sub %s%s\n",
689						toSshCmd(handler.Cfg),
690						flagMsg,
691						topic,
692					)
693				}
694
695				readErr, writeErr := pubsub.Pipe(
696					pipeCtx,
697					clientID,
698					sesh,
699					[]*psub.Channel{
700						psub.NewChannel(name),
701					},
702					*replay,
703				)
704
705				if readErr != nil && !*clean {
706					fmt.Fprintln(sesh.Stderr(), "error reading from pipe", readErr)
707				}
708
709				if writeErr != nil && !*clean {
710					fmt.Fprintln(sesh.Stderr(), "error writing to pipe", writeErr)
711				}
712			}
713
714			return next(sesh)
715		}
716	}
717}