repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pipe
Eric Bower  ·  2025-05-25

cli.go

  1package pipe
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"slices"
 11	"strings"
 12	"text/tabwriter"
 13	"time"
 14
 15	"github.com/antoniomika/syncmap"
 16	"github.com/google/uuid"
 17	"github.com/picosh/pico/pkg/db"
 18	"github.com/picosh/pico/pkg/pssh"
 19	"github.com/picosh/pico/pkg/shared"
 20	psub "github.com/picosh/pubsub"
 21	gossh "golang.org/x/crypto/ssh"
 22)
 23
 24func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
 25	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 26	cmd.SetOutput(sesh)
 27	cmd.Usage = func() {
 28		_, _ = fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
 29		cmd.PrintDefaults()
 30	}
 31	return cmd
 32}
 33
 34func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 35	err := cmd.Parse(cmdArgs)
 36
 37	if err != nil || posArg == "help" {
 38		if posArg == "help" {
 39			cmd.Usage()
 40		}
 41		return false
 42	}
 43	return true
 44}
 45
 46func NewTabWriter(out io.Writer) *tabwriter.Writer {
 47	return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
 48}
 49
 50// scope topic to user by prefixing name.
 51func toTopic(userName, topic string) string {
 52	return fmt.Sprintf("%s/%s", userName, topic)
 53}
 54
 55func toPublicTopic(topic string) string {
 56	return fmt.Sprintf("public/%s", topic)
 57}
 58
 59func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
 60	if len(clients) == 0 {
 61		return ""
 62	}
 63
 64	outputData := fmt.Sprintf("    %s:\r\n", clientType)
 65
 66	for _, client := range clients {
 67		if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
 68			continue
 69		}
 70
 71		outputData += fmt.Sprintf("    - %s\r\n", client.ID)
 72	}
 73
 74	return outputData
 75}
 76
 77var helpStr = func(sshCmd string) string {
 78	data := fmt.Sprintf(`Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]
 79
 80The simplest authenticated pubsub system.  Send messages through
 81user-defined topics.  Topics are private to the authenticated
 82ssh user.  The default pubsub model is multicast with bidirectional
 83blocking, meaning a publisher ("pub") will send its message to all
 84subscribers ("sub").  Further, both "pub" and "sub" will wait for
 85at least one event to be sent or received. Pipe ("pipe") allows
 86for bidirectional messages to be sent between any clients connected
 87to a pipe.
 88
 89Think of these different commands in terms of the direction the
 90data is being sent:
 91
 92- pub => writes to client
 93- sub => reads from client
 94- pipe => read and write between clients
 95`, sshCmd)
 96
 97	data = strings.ReplaceAll(data, "\n", "\r\n")
 98
 99	return data
100}
101
102type CliHandler struct {
103	DBPool  db.DB
104	Logger  *slog.Logger
105	PubSub  psub.PubSub
106	Cfg     *shared.ConfigSite
107	Waiters *syncmap.Map[string, []string]
108	Access  *syncmap.Map[string, []string]
109}
110
111func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
112	return h.Logger
113}
114
115func toSshCmd(cfg *shared.ConfigSite) string {
116	port := ""
117	if cfg.PortOverride != "22" {
118		port = fmt.Sprintf("-p %s ", cfg.PortOverride)
119	}
120	return fmt.Sprintf("%s%s", port, cfg.Domain)
121}
122
// parseArgList splits a comma separated list and trims surrounding
// whitespace from every entry. Empty entries are kept, so "" yields [""].
func parseArgList(arg string) []string {
	fields := strings.Split(arg, ",")
	trimmed := make([]string, 0, len(fields))
	for _, field := range fields {
		trimmed = append(trimmed, strings.TrimSpace(field))
	}
	return trimmed
}
131
132// checkAccess checks if the user has access to a topic based on an access list.
133func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
134	for _, acc := range accessList {
135		if acc == userName {
136			return true
137		}
138
139		if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
140			return true
141		}
142	}
143
144	return false
145}
146
147func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
148	pubsub := handler.PubSub
149
150	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
151		return func(sesh *pssh.SSHServerConnSession) error {
152			ctx := sesh.Context()
153			logger := pssh.GetLogger(sesh)
154			user := pssh.GetUser(sesh)
155
156			args := sesh.Command()
157
158			if len(args) == 0 {
159				_, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
160				return next(sesh)
161			}
162
163			userName := "public"
164
165			userNameAddition := ""
166
167			isAdmin := false
168			impersonate := false
169			if user != nil {
170				isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
171				if isAdmin && strings.HasPrefix(sesh.User(), "admin__") {
172					impersonate = true
173				}
174
175				userName = user.Name
176				if user.PublicKey != nil && user.PublicKey.Name != "" {
177					userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
178				}
179			}
180
181			pipeCtx, cancel := context.WithCancel(ctx)
182
183			go func() {
184				defer cancel()
185
186				for {
187					select {
188					case <-pipeCtx.Done():
189						return
190					default:
191						_, err := sesh.SendRequest("ping@pico.sh", false, nil)
192						if err != nil {
193							logger.Error("error sending ping", "err", err)
194							return
195						}
196
197						time.Sleep(5 * time.Second)
198					}
199				}
200			}()
201
202			cmd := strings.TrimSpace(args[0])
203			if cmd == "help" {
204				_, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
205				return next(sesh)
206			} else if cmd == "ls" {
207				if userName == "public" {
208					err := fmt.Errorf("access denied")
209					sesh.Fatal(err)
210					return err
211				}
212
213				topicFilter := fmt.Sprintf("%s/", userName)
214				if isAdmin {
215					topicFilter = ""
216					if len(args) > 1 {
217						topicFilter = args[1]
218					}
219				}
220
221				var channels []*psub.Channel
222				waitingChannels := map[string][]string{}
223
224				for topic, channel := range pubsub.GetChannels() {
225					if strings.HasPrefix(topic, topicFilter) {
226						channels = append(channels, channel)
227					}
228				}
229
230				for channel, clients := range handler.Waiters.Range {
231					if strings.HasPrefix(channel, topicFilter) {
232						waitingChannels[channel] = clients
233					}
234				}
235
236				if len(channels) == 0 && len(waitingChannels) == 0 {
237					_, _ = fmt.Fprintln(sesh, "no pubsub channels found")
238				} else {
239					var outputData string
240					if len(channels) > 0 || len(waitingChannels) > 0 {
241						outputData += "Channel Information\r\n"
242						for _, channel := range channels {
243							extraData := ""
244
245							if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
246								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
247							}
248
249							outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
250							outputData += "  Clients:\r\n"
251
252							var pubs []*psub.Client
253							var subs []*psub.Client
254							var pipes []*psub.Client
255
256							for _, client := range channel.GetClients() {
257								switch client.Direction {
258								case psub.ChannelDirectionInput:
259									pubs = append(pubs, client)
260								case psub.ChannelDirectionOutput:
261									subs = append(subs, client)
262								case psub.ChannelDirectionInputOutput:
263									pipes = append(pipes, client)
264								}
265							}
266							outputData += clientInfo(pubs, isAdmin, "Pubs")
267							outputData += clientInfo(subs, isAdmin, "Subs")
268							outputData += clientInfo(pipes, isAdmin, "Pipes")
269						}
270
271						for waitingChannel, channelPubs := range waitingChannels {
272							extraData := ""
273
274							if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
275								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
276							}
277
278							outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
279							outputData += "  Clients:\r\n"
280							outputData += fmt.Sprintf("    %s:\r\n", "Waiting Pubs")
281							for _, client := range channelPubs {
282								if strings.HasPrefix(client, "admin-") && !isAdmin {
283									continue
284								}
285								outputData += fmt.Sprintf("    - %s\r\n", client)
286							}
287						}
288					}
289
290					_, _ = sesh.Write([]byte(outputData))
291				}
292
293				return next(sesh)
294			}
295
296			topic := ""
297			cmdArgs := args[1:]
298			if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
299				topic = strings.TrimSpace(args[1])
300				cmdArgs = args[2:]
301			}
302
303			logger.Info(
304				"pubsub middleware detected command",
305				"args", args,
306				"cmd", cmd,
307				"topic", topic,
308				"cmdArgs", cmdArgs,
309			)
310
311			uuidStr := uuid.NewString()
312			if impersonate {
313				uuidStr = fmt.Sprintf("admin-%s", uuidStr)
314			}
315
316			clientID := fmt.Sprintf("%s (%s%s@%s)", uuidStr, userName, userNameAddition, sesh.RemoteAddr().String())
317
318			var err error
319
320			if cmd == "pub" {
321				pubCmd := flagSet("pub", sesh)
322				access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
323				empty := pubCmd.Bool("e", false, "Send an empty message to subs")
324				public := pubCmd.Bool("p", false, "Publish message to public topic")
325				block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
326				timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
327				clean := pubCmd.Bool("c", false, "Don't send status messages")
328
329				if !flagCheck(pubCmd, topic, cmdArgs) {
330					return err
331				}
332
333				if pubCmd.NArg() == 1 && topic == "" {
334					topic = pubCmd.Arg(0)
335				}
336
337				logger.Info(
338					"flags parsed",
339					"cmd", cmd,
340					"empty", *empty,
341					"public", *public,
342					"block", *block,
343					"timeout", *timeout,
344					"topic", topic,
345					"access", *access,
346					"clean", *clean,
347				)
348
349				var accessList []string
350
351				if *access != "" {
352					accessList = parseArgList(*access)
353				}
354
355				var rw io.ReadWriter
356				if *empty {
357					rw = bytes.NewBuffer(make([]byte, 1))
358				} else {
359					rw = sesh
360				}
361
362				if topic == "" {
363					topic = uuid.NewString()
364				}
365
366				var withoutUser string
367				var name string
368				msgFlag := ""
369
370				if isAdmin && strings.HasPrefix(topic, "/") {
371					name = strings.TrimPrefix(topic, "/")
372				} else {
373					name = toTopic(userName, topic)
374					if *public {
375						name = toPublicTopic(topic)
376						msgFlag = "-p "
377						withoutUser = name
378					} else {
379						withoutUser = topic
380					}
381				}
382
383				var accessListCreator bool
384
385				_, loaded := handler.Access.LoadOrStore(name, accessList)
386				if !loaded {
387					defer func() {
388						handler.Access.Delete(name)
389					}()
390
391					accessListCreator = true
392				}
393
394				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
395					if checkAccess(accessList, userName, sesh) || accessListCreator {
396						name = withoutUser
397					} else if !*public {
398						name = toTopic(userName, withoutUser)
399					} else {
400						topic = uuid.NewString()
401						name = toPublicTopic(topic)
402					}
403				}
404
405				if !*clean {
406					_, _ = fmt.Fprintf(
407						sesh,
408						"subscribe to this channel:\n  ssh %s sub %s%s\n",
409						toSshCmd(handler.Cfg),
410						msgFlag,
411						topic,
412					)
413				}
414
415				if *block {
416					count := 0
417					for topic, channel := range pubsub.GetChannels() {
418						if topic == name {
419							for _, client := range channel.GetClients() {
420								if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
421									count++
422								}
423							}
424							break
425						}
426					}
427
428					tt := *timeout
429					if count == 0 {
430						currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
431						handler.Waiters.Store(name, append(currentWaiters, clientID))
432
433						termMsg := "no subs found ... waiting"
434						if tt > 0 {
435							termMsg += " " + tt.String()
436						}
437
438						if !*clean {
439							_, _ = fmt.Fprintln(sesh, termMsg)
440						}
441
442						ready := make(chan struct{})
443
444						go func() {
445							for {
446								select {
447								case <-pipeCtx.Done():
448									cancel()
449									return
450								case <-time.After(1 * time.Millisecond):
451									count := 0
452									for topic, channel := range pubsub.GetChannels() {
453										if topic == name {
454											for _, client := range channel.GetClients() {
455												if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
456													count++
457												}
458											}
459											break
460										}
461									}
462
463									if count > 0 {
464										close(ready)
465										return
466									}
467								}
468							}
469						}()
470
471						select {
472						case <-ready:
473						case <-pipeCtx.Done():
474						case <-time.After(tt):
475							cancel()
476
477							if !*clean {
478								sesh.Fatal(fmt.Errorf("timeout reached, exiting"))
479							} else {
480								err = sesh.Exit(1)
481								if err != nil {
482									logger.Error("error exiting session", "err", err)
483								}
484
485								_ = sesh.Close()
486							}
487						}
488
489						newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
490						newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
491							return cl == clientID
492						})
493						handler.Waiters.Store(name, newWaiters)
494
495						var toDelete []string
496
497						for channel, clients := range handler.Waiters.Range {
498							if len(clients) == 0 {
499								toDelete = append(toDelete, channel)
500							}
501						}
502
503						for _, channel := range toDelete {
504							handler.Waiters.Delete(channel)
505						}
506					}
507				}
508
509				if !*clean {
510					_, _ = fmt.Fprintln(sesh, "sending msg ...")
511				}
512
513				err = pubsub.Pub(
514					pipeCtx,
515					clientID,
516					rw,
517					[]*psub.Channel{
518						psub.NewChannel(name),
519					},
520					*block,
521				)
522
523				if !*clean {
524					_, _ = fmt.Fprintln(sesh, "msg sent!")
525				}
526
527				if err != nil && !*clean {
528					_, _ = fmt.Fprintln(sesh.Stderr(), err)
529				}
530			} else if cmd == "sub" {
531				subCmd := flagSet("sub", sesh)
532				access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
533				public := subCmd.Bool("p", false, "Subscribe to a public topic")
534				keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
535				clean := subCmd.Bool("c", false, "Don't send status messages")
536
537				if !flagCheck(subCmd, topic, cmdArgs) {
538					return err
539				}
540
541				if subCmd.NArg() == 1 && topic == "" {
542					topic = subCmd.Arg(0)
543				}
544
545				logger.Info(
546					"flags parsed",
547					"cmd", cmd,
548					"public", *public,
549					"keepAlive", *keepAlive,
550					"topic", topic,
551					"clean", *clean,
552					"access", *access,
553				)
554
555				var accessList []string
556
557				if *access != "" {
558					accessList = parseArgList(*access)
559				}
560
561				var withoutUser string
562				var name string
563
564				if isAdmin && strings.HasPrefix(topic, "/") {
565					name = strings.TrimPrefix(topic, "/")
566				} else {
567					name = toTopic(userName, topic)
568					if *public {
569						name = toPublicTopic(topic)
570						withoutUser = name
571					} else {
572						withoutUser = topic
573					}
574				}
575
576				var accessListCreator bool
577
578				_, loaded := handler.Access.LoadOrStore(name, accessList)
579				if !loaded {
580					defer func() {
581						handler.Access.Delete(name)
582					}()
583
584					accessListCreator = true
585				}
586
587				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
588					if checkAccess(accessList, userName, sesh) || accessListCreator {
589						name = withoutUser
590					} else if !*public {
591						name = toTopic(userName, withoutUser)
592					} else {
593						_, _ = fmt.Fprintln(sesh.Stderr(), "access denied")
594						return err
595					}
596				}
597
598				err = pubsub.Sub(
599					pipeCtx,
600					clientID,
601					sesh,
602					[]*psub.Channel{
603						psub.NewChannel(name),
604					},
605					*keepAlive,
606				)
607
608				if err != nil && !*clean {
609					_, _ = fmt.Fprintln(sesh.Stderr(), err)
610				}
611			} else if cmd == "pipe" {
612				pipeCmd := flagSet("pipe", sesh)
613				access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
614				public := pipeCmd.Bool("p", false, "Pipe to a public topic")
615				replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
616				clean := pipeCmd.Bool("c", false, "Don't send status messages")
617
618				if !flagCheck(pipeCmd, topic, cmdArgs) {
619					return err
620				}
621
622				if pipeCmd.NArg() == 1 && topic == "" {
623					topic = pipeCmd.Arg(0)
624				}
625
626				logger.Info(
627					"flags parsed",
628					"cmd", cmd,
629					"public", *public,
630					"replay", *replay,
631					"topic", topic,
632					"access", *access,
633					"clean", *clean,
634				)
635
636				var accessList []string
637
638				if *access != "" {
639					accessList = parseArgList(*access)
640				}
641
642				isCreator := topic == ""
643				if isCreator {
644					topic = uuid.NewString()
645				}
646
647				var withoutUser string
648				var name string
649				flagMsg := ""
650
651				if isAdmin && strings.HasPrefix(topic, "/") {
652					name = strings.TrimPrefix(topic, "/")
653				} else {
654					name = toTopic(userName, topic)
655					if *public {
656						name = toPublicTopic(topic)
657						flagMsg = "-p "
658						withoutUser = name
659					} else {
660						withoutUser = topic
661					}
662				}
663
664				var accessListCreator bool
665
666				_, loaded := handler.Access.LoadOrStore(name, accessList)
667				if !loaded {
668					defer func() {
669						handler.Access.Delete(name)
670					}()
671
672					accessListCreator = true
673				}
674
675				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
676					if checkAccess(accessList, userName, sesh) || accessListCreator {
677						name = withoutUser
678					} else if !*public {
679						name = toTopic(userName, withoutUser)
680					} else {
681						topic = uuid.NewString()
682						name = toPublicTopic(topic)
683					}
684				}
685
686				if isCreator && !*clean {
687					_, _ = fmt.Fprintf(
688						sesh,
689						"subscribe to this topic:\n  ssh %s sub %s%s\n",
690						toSshCmd(handler.Cfg),
691						flagMsg,
692						topic,
693					)
694				}
695
696				readErr, writeErr := pubsub.Pipe(
697					pipeCtx,
698					clientID,
699					sesh,
700					[]*psub.Channel{
701						psub.NewChannel(name),
702					},
703					*replay,
704				)
705
706				if readErr != nil && !*clean {
707					_, _ = fmt.Fprintln(sesh.Stderr(), "error reading from pipe", readErr)
708				}
709
710				if writeErr != nil && !*clean {
711					_, _ = fmt.Fprintln(sesh.Stderr(), "error writing to pipe", writeErr)
712				}
713			}
714
715			return next(sesh)
716		}
717	}
718}