repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pipe
Eric Bower  ·  2025-11-28

cli.go

  1package pipe
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"slices"
 11	"strings"
 12	"text/tabwriter"
 13	"time"
 14
 15	"github.com/antoniomika/syncmap"
 16	"github.com/google/uuid"
 17	"github.com/picosh/pico/pkg/db"
 18	"github.com/picosh/pico/pkg/pssh"
 19	"github.com/picosh/pico/pkg/shared"
 20	psub "github.com/picosh/pubsub"
 21	gossh "golang.org/x/crypto/ssh"
 22)
 23
 24func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
 25	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 26	cmd.SetOutput(sesh)
 27	cmd.Usage = func() {
 28		_, _ = fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
 29		cmd.PrintDefaults()
 30	}
 31	return cmd
 32}
 33
 34func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 35	err := cmd.Parse(cmdArgs)
 36
 37	if err != nil || posArg == "help" {
 38		if posArg == "help" {
 39			cmd.Usage()
 40		}
 41		return false
 42	}
 43	return true
 44}
 45
 46func NewTabWriter(out io.Writer) *tabwriter.Writer {
 47	return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
 48}
 49
 50// scope topic to user by prefixing name.
 51func toTopic(userName, topic string) string {
 52	if strings.HasPrefix(topic, userName+"/") {
 53		return topic
 54	}
 55	return fmt.Sprintf("%s/%s", userName, topic)
 56}
 57
 58func toPublicTopic(topic string) string {
 59	if strings.HasPrefix(topic, "public/") {
 60		return topic
 61	}
 62	return fmt.Sprintf("public/%s", topic)
 63}
 64
 65func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
 66	if len(clients) == 0 {
 67		return ""
 68	}
 69
 70	outputData := fmt.Sprintf("    %s:\r\n", clientType)
 71
 72	for _, client := range clients {
 73		if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
 74			continue
 75		}
 76
 77		outputData += fmt.Sprintf("    - %s\r\n", client.ID)
 78	}
 79
 80	return outputData
 81}
 82
 83var helpStr = func(sshCmd string) string {
 84	data := fmt.Sprintf(`Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]
 85
 86The simplest authenticated pubsub system.  Send messages through
 87user-defined topics.  Topics are private to the authenticated
 88ssh user.  The default pubsub model is multicast with bidirectional
 89blocking, meaning a publisher ("pub") will send its message to all
 90subscribers ("sub").  Further, both "pub" and "sub" will wait for
 91at least one event to be sent or received. Pipe ("pipe") allows
 92for bidirectional messages to be sent between any clients connected
 93to a pipe.
 94
 95Think of these different commands in terms of the direction the
 96data is being sent:
 97
 98- pub => writes to client
 99- sub => reads from client
100- pipe => read and write between clients
101`, sshCmd)
102
103	data = strings.ReplaceAll(data, "\n", "\r\n")
104
105	return data
106}
107
108type CliHandler struct {
109	DBPool  db.DB
110	Logger  *slog.Logger
111	PubSub  psub.PubSub
112	Cfg     *shared.ConfigSite
113	Waiters *syncmap.Map[string, []string]
114	Access  *syncmap.Map[string, []string]
115}
116
117func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
118	return h.Logger
119}
120
121func toSshCmd(cfg *shared.ConfigSite) string {
122	port := ""
123	if cfg.PortOverride != "22" {
124		port = fmt.Sprintf("-p %s ", cfg.PortOverride)
125	}
126	return fmt.Sprintf("%s%s", port, cfg.Domain)
127}
128
129// parseArgList parses a comma separated list of arguments.
130func parseArgList(arg string) []string {
131	argList := strings.Split(arg, ",")
132	for i, acc := range argList {
133		argList[i] = strings.TrimSpace(acc)
134	}
135	return argList
136}
137
138// checkAccess checks if the user has access to a topic based on an access list.
139func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
140	for _, acc := range accessList {
141		if acc == userName {
142			return true
143		}
144
145		if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
146			return true
147		}
148	}
149
150	return false
151}
152
153func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
154	pubsub := handler.PubSub
155
156	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
157		return func(sesh *pssh.SSHServerConnSession) error {
158			ctx := sesh.Context()
159			logger := pssh.GetLogger(sesh)
160			user := pssh.GetUser(sesh)
161
162			args := sesh.Command()
163
164			if len(args) == 0 {
165				_, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
166				return next(sesh)
167			}
168
169			userName := "public"
170
171			userNameAddition := ""
172
173			isAdmin := false
174			impersonate := false
175			if user != nil {
176				isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
177				if isAdmin && strings.HasPrefix(sesh.User(), "admin__") {
178					impersonate = true
179				}
180
181				userName = user.Name
182				if user.PublicKey != nil && user.PublicKey.Name != "" {
183					userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
184				}
185			}
186
187			pipeCtx, cancel := context.WithCancel(ctx)
188
189			go func() {
190				defer cancel()
191
192				for {
193					select {
194					case <-pipeCtx.Done():
195						return
196					default:
197						_, err := sesh.SendRequest("ping@pico.sh", false, nil)
198						if err != nil {
199							logger.Error("error sending ping", "err", err)
200							return
201						}
202
203						time.Sleep(5 * time.Second)
204					}
205				}
206			}()
207
208			cmd := strings.TrimSpace(args[0])
209			if cmd == "help" {
210				_, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
211				return next(sesh)
212			} else if cmd == "ls" {
213				if userName == "public" {
214					err := fmt.Errorf("access denied")
215					sesh.Fatal(err)
216					return err
217				}
218
219				topicFilter := fmt.Sprintf("%s/", userName)
220				if isAdmin {
221					topicFilter = ""
222					if len(args) > 1 {
223						topicFilter = args[1]
224					}
225				}
226
227				var channels []*psub.Channel
228				waitingChannels := map[string][]string{}
229
230				for topic, channel := range pubsub.GetChannels() {
231					if strings.HasPrefix(topic, topicFilter) {
232						channels = append(channels, channel)
233					}
234				}
235
236				for channel, clients := range handler.Waiters.Range {
237					if strings.HasPrefix(channel, topicFilter) {
238						waitingChannels[channel] = clients
239					}
240				}
241
242				if len(channels) == 0 && len(waitingChannels) == 0 {
243					_, _ = fmt.Fprintln(sesh, "no pubsub channels found")
244				} else {
245					var outputData string
246					if len(channels) > 0 || len(waitingChannels) > 0 {
247						outputData += "Channel Information\r\n"
248						for _, channel := range channels {
249							extraData := ""
250
251							if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
252								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
253							}
254
255							outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
256							outputData += "  Clients:\r\n"
257
258							var pubs []*psub.Client
259							var subs []*psub.Client
260							var pipes []*psub.Client
261
262							for _, client := range channel.GetClients() {
263								switch client.Direction {
264								case psub.ChannelDirectionInput:
265									pubs = append(pubs, client)
266								case psub.ChannelDirectionOutput:
267									subs = append(subs, client)
268								case psub.ChannelDirectionInputOutput:
269									pipes = append(pipes, client)
270								}
271							}
272							outputData += clientInfo(pubs, isAdmin, "Pubs")
273							outputData += clientInfo(subs, isAdmin, "Subs")
274							outputData += clientInfo(pipes, isAdmin, "Pipes")
275						}
276
277						for waitingChannel, channelPubs := range waitingChannels {
278							extraData := ""
279
280							if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
281								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
282							}
283
284							outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
285							outputData += "  Clients:\r\n"
286							outputData += fmt.Sprintf("    %s:\r\n", "Waiting Pubs")
287							for _, client := range channelPubs {
288								if strings.HasPrefix(client, "admin-") && !isAdmin {
289									continue
290								}
291								outputData += fmt.Sprintf("    - %s\r\n", client)
292							}
293						}
294					}
295
296					_, _ = sesh.Write([]byte(outputData))
297				}
298
299				return next(sesh)
300			}
301
302			topic := ""
303			cmdArgs := args[1:]
304			if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
305				topic = strings.TrimSpace(args[1])
306				cmdArgs = args[2:]
307			}
308
309			logger.Info(
310				"pubsub middleware detected command",
311				"args", args,
312				"cmd", cmd,
313				"topic", topic,
314				"cmdArgs", cmdArgs,
315			)
316
317			uuidStr := uuid.NewString()
318			if impersonate {
319				uuidStr = fmt.Sprintf("admin-%s", uuidStr)
320			}
321
322			clientID := fmt.Sprintf("%s (%s%s@%s)", uuidStr, userName, userNameAddition, sesh.RemoteAddr().String())
323
324			var err error
325
326			if cmd == "pub" {
327				pubCmd := flagSet("pub", sesh)
328				access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
329				empty := pubCmd.Bool("e", false, "Send an empty message to subs")
330				public := pubCmd.Bool("p", false, "Publish message to public topic")
331				block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
332				timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
333				clean := pubCmd.Bool("c", false, "Don't send status messages")
334
335				if !flagCheck(pubCmd, topic, cmdArgs) {
336					return err
337				}
338
339				if pubCmd.NArg() == 1 && topic == "" {
340					topic = pubCmd.Arg(0)
341				}
342
343				logger.Info(
344					"flags parsed",
345					"cmd", cmd,
346					"empty", *empty,
347					"public", *public,
348					"block", *block,
349					"timeout", *timeout,
350					"topic", topic,
351					"access", *access,
352					"clean", *clean,
353				)
354
355				var accessList []string
356
357				if *access != "" {
358					accessList = parseArgList(*access)
359				}
360
361				var rw io.ReadWriter
362				if *empty {
363					rw = bytes.NewBuffer(make([]byte, 1))
364				} else {
365					rw = sesh
366				}
367
368				if topic == "" {
369					topic = uuid.NewString()
370				}
371
372				var withoutUser string
373				var name string
374				msgFlag := ""
375
376				if isAdmin && strings.HasPrefix(topic, "/") {
377					name = strings.TrimPrefix(topic, "/")
378				} else {
379					name = toTopic(userName, topic)
380					if *public {
381						name = toPublicTopic(topic)
382						msgFlag = "-p "
383						withoutUser = name
384					} else {
385						withoutUser = topic
386					}
387				}
388
389				var accessListCreator bool
390
391				_, loaded := handler.Access.LoadOrStore(name, accessList)
392				if !loaded {
393					defer func() {
394						handler.Access.Delete(name)
395					}()
396
397					accessListCreator = true
398				}
399
400				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
401					if checkAccess(accessList, userName, sesh) || accessListCreator {
402						name = withoutUser
403					} else if !*public {
404						name = toTopic(userName, withoutUser)
405					} else {
406						topic = uuid.NewString()
407						name = toPublicTopic(topic)
408					}
409				}
410
411				if !*clean {
412					fmtTopic := topic
413					if *access != "" {
414						fmtTopic = fmt.Sprintf("%s/%s", userName, topic)
415					}
416
417					_, _ = fmt.Fprintf(
418						sesh,
419						"subscribe to this channel:\n  ssh %s sub %s%s\n",
420						toSshCmd(handler.Cfg),
421						msgFlag,
422						fmtTopic,
423					)
424				}
425
426				if *block {
427					count := 0
428					for topic, channel := range pubsub.GetChannels() {
429						if topic == name {
430							for _, client := range channel.GetClients() {
431								if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
432									count++
433								}
434							}
435							break
436						}
437					}
438
439					tt := *timeout
440					if count == 0 {
441						currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
442						handler.Waiters.Store(name, append(currentWaiters, clientID))
443
444						termMsg := "no subs found ... waiting"
445						if tt > 0 {
446							termMsg += " " + tt.String()
447						}
448
449						if !*clean {
450							_, _ = fmt.Fprintln(sesh, termMsg)
451						}
452
453						ready := make(chan struct{})
454
455						go func() {
456							for {
457								select {
458								case <-pipeCtx.Done():
459									cancel()
460									return
461								case <-time.After(1 * time.Millisecond):
462									count := 0
463									for topic, channel := range pubsub.GetChannels() {
464										if topic == name {
465											for _, client := range channel.GetClients() {
466												if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
467													count++
468												}
469											}
470											break
471										}
472									}
473
474									if count > 0 {
475										close(ready)
476										return
477									}
478								}
479							}
480						}()
481
482						select {
483						case <-ready:
484						case <-pipeCtx.Done():
485						case <-time.After(tt):
486							cancel()
487
488							if !*clean {
489								sesh.Fatal(fmt.Errorf("timeout reached, exiting"))
490							} else {
491								err = sesh.Exit(1)
492								if err != nil {
493									logger.Error("error exiting session", "err", err)
494								}
495
496								_ = sesh.Close()
497							}
498						}
499
500						newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
501						newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
502							return cl == clientID
503						})
504						handler.Waiters.Store(name, newWaiters)
505
506						var toDelete []string
507
508						for channel, clients := range handler.Waiters.Range {
509							if len(clients) == 0 {
510								toDelete = append(toDelete, channel)
511							}
512						}
513
514						for _, channel := range toDelete {
515							handler.Waiters.Delete(channel)
516						}
517					}
518				}
519
520				if !*clean {
521					_, _ = fmt.Fprintln(sesh, "sending msg ...")
522				}
523
524				err = pubsub.Pub(
525					pipeCtx,
526					clientID,
527					rw,
528					[]*psub.Channel{
529						psub.NewChannel(name),
530					},
531					*block,
532				)
533
534				if !*clean {
535					_, _ = fmt.Fprintln(sesh, "msg sent!")
536				}
537
538				if err != nil && !*clean {
539					_, _ = fmt.Fprintln(sesh.Stderr(), err)
540				}
541			} else if cmd == "sub" {
542				subCmd := flagSet("sub", sesh)
543				access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
544				public := subCmd.Bool("p", false, "Subscribe to a public topic")
545				keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
546				clean := subCmd.Bool("c", false, "Don't send status messages")
547
548				if !flagCheck(subCmd, topic, cmdArgs) {
549					return err
550				}
551
552				if subCmd.NArg() == 1 && topic == "" {
553					topic = subCmd.Arg(0)
554				}
555
556				logger.Info(
557					"flags parsed",
558					"cmd", cmd,
559					"public", *public,
560					"keepAlive", *keepAlive,
561					"topic", topic,
562					"clean", *clean,
563					"access", *access,
564				)
565
566				var accessList []string
567
568				if *access != "" {
569					accessList = parseArgList(*access)
570				}
571
572				var withoutUser string
573				var name string
574
575				if isAdmin && strings.HasPrefix(topic, "/") {
576					name = strings.TrimPrefix(topic, "/")
577				} else {
578					name = toTopic(userName, topic)
579					if *public {
580						name = toPublicTopic(topic)
581						withoutUser = name
582					} else {
583						withoutUser = topic
584					}
585				}
586
587				var accessListCreator bool
588
589				_, loaded := handler.Access.LoadOrStore(name, accessList)
590				if !loaded {
591					defer func() {
592						handler.Access.Delete(name)
593					}()
594
595					accessListCreator = true
596				}
597
598				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
599					if checkAccess(accessList, userName, sesh) || accessListCreator {
600						name = withoutUser
601					} else if !*public {
602						name = toTopic(userName, withoutUser)
603					} else {
604						_, _ = fmt.Fprintln(sesh.Stderr(), "access denied")
605						return err
606					}
607				}
608
609				err = pubsub.Sub(
610					pipeCtx,
611					clientID,
612					sesh,
613					[]*psub.Channel{
614						psub.NewChannel(name),
615					},
616					*keepAlive,
617				)
618
619				if err != nil && !*clean {
620					_, _ = fmt.Fprintln(sesh.Stderr(), err)
621				}
622			} else if cmd == "pipe" {
623				pipeCmd := flagSet("pipe", sesh)
624				access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
625				public := pipeCmd.Bool("p", false, "Pipe to a public topic")
626				replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
627				clean := pipeCmd.Bool("c", false, "Don't send status messages")
628
629				if !flagCheck(pipeCmd, topic, cmdArgs) {
630					return err
631				}
632
633				if pipeCmd.NArg() == 1 && topic == "" {
634					topic = pipeCmd.Arg(0)
635				}
636
637				logger.Info(
638					"flags parsed",
639					"cmd", cmd,
640					"public", *public,
641					"replay", *replay,
642					"topic", topic,
643					"access", *access,
644					"clean", *clean,
645				)
646
647				var accessList []string
648
649				if *access != "" {
650					accessList = parseArgList(*access)
651				}
652
653				isCreator := topic == ""
654				if isCreator {
655					topic = uuid.NewString()
656				}
657
658				var withoutUser string
659				var name string
660				flagMsg := ""
661
662				if isAdmin && strings.HasPrefix(topic, "/") {
663					name = strings.TrimPrefix(topic, "/")
664				} else {
665					name = toTopic(userName, topic)
666					if *public {
667						name = toPublicTopic(topic)
668						flagMsg = "-p "
669						withoutUser = name
670					} else {
671						withoutUser = topic
672					}
673				}
674
675				var accessListCreator bool
676
677				_, loaded := handler.Access.LoadOrStore(name, accessList)
678				if !loaded {
679					defer func() {
680						handler.Access.Delete(name)
681					}()
682
683					accessListCreator = true
684				}
685
686				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
687					if checkAccess(accessList, userName, sesh) || accessListCreator {
688						name = withoutUser
689					} else if !*public {
690						name = toTopic(userName, withoutUser)
691					} else {
692						topic = uuid.NewString()
693						name = toPublicTopic(topic)
694					}
695				}
696
697				if isCreator && !*clean {
698					fmtTopic := topic
699					if *access != "" {
700						fmtTopic = fmt.Sprintf("%s/%s", userName, topic)
701					}
702
703					_, _ = fmt.Fprintf(
704						sesh,
705						"subscribe to this topic:\n  ssh %s sub %s%s\n",
706						toSshCmd(handler.Cfg),
707						flagMsg,
708						fmtTopic,
709					)
710				}
711
712				readErr, writeErr := pubsub.Pipe(
713					pipeCtx,
714					clientID,
715					sesh,
716					[]*psub.Channel{
717						psub.NewChannel(name),
718					},
719					*replay,
720				)
721
722				if readErr != nil && !*clean {
723					_, _ = fmt.Fprintln(sesh.Stderr(), "error reading from pipe", readErr)
724				}
725
726				if writeErr != nil && !*clean {
727					_, _ = fmt.Fprintln(sesh.Stderr(), "error writing to pipe", writeErr)
728				}
729			}
730
731			return next(sesh)
732		}
733	}
734}