Eric Bower
·
2025-11-28
cli.go
1package pipe
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "io"
9 "log/slog"
10 "slices"
11 "strings"
12 "text/tabwriter"
13 "time"
14
15 "github.com/antoniomika/syncmap"
16 "github.com/google/uuid"
17 "github.com/picosh/pico/pkg/db"
18 "github.com/picosh/pico/pkg/pssh"
19 "github.com/picosh/pico/pkg/shared"
20 psub "github.com/picosh/pubsub"
21 gossh "golang.org/x/crypto/ssh"
22)
23
24func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
25 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
26 cmd.SetOutput(sesh)
27 cmd.Usage = func() {
28 _, _ = fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
29 cmd.PrintDefaults()
30 }
31 return cmd
32}
33
// flagCheck parses cmdArgs into cmd and reports whether the subcommand
// should go ahead.
//
// It returns false in two cases:
//   - parsing failed; with flag.ContinueOnError the flag package has already
//     written the error (and usage) to cmd's output;
//   - posArg is "help"; the usage text is printed here.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	parseErr := cmd.Parse(cmdArgs)

	wantsHelp := posArg == "help"
	if wantsHelp {
		cmd.Usage()
	}

	return parseErr == nil && !wantsHelp
}
45
// NewTabWriter returns a tabwriter.Writer on out. Columns are sized to
// their content (no minimum width) and padded with a single space. Leading
// empty cells are indented with tabs (tabwriter.TabIndent). Callers must
// Flush the returned writer.
func NewTabWriter(out io.Writer) *tabwriter.Writer {
	const (
		minWidth = 0
		tabWidth = 0
		padding  = 1
		padChar  = ' '
	)
	return tabwriter.NewWriter(out, minWidth, tabWidth, padding, padChar, tabwriter.TabIndent)
}
49
// toTopic scopes topic to userName by prefixing "userName/". A topic that
// already carries that exact prefix is returned unchanged.
func toTopic(userName, topic string) string {
	prefix := fmt.Sprintf("%s/", userName)
	if !strings.HasPrefix(topic, prefix) {
		return prefix + topic
	}
	return topic
}
57
// toPublicTopic places topic in the shared "public/" namespace. A topic
// that is already there is returned unchanged.
func toPublicTopic(topic string) string {
	const publicPrefix = "public/"
	if !strings.HasPrefix(topic, publicPrefix) {
		return fmt.Sprintf("%s%s", publicPrefix, topic)
	}
	return topic
}
64
65func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
66 if len(clients) == 0 {
67 return ""
68 }
69
70 outputData := fmt.Sprintf(" %s:\r\n", clientType)
71
72 for _, client := range clients {
73 if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
74 continue
75 }
76
77 outputData += fmt.Sprintf(" - %s\r\n", client.ID)
78 }
79
80 return outputData
81}
82
// helpStr renders the top-level help text. sshCmd is the host part of the
// ssh command line (see toSshCmd) and is substituted into the "Command:"
// line. Every "\n" in the result is converted to "\r\n".
var helpStr = func(sshCmd string) string {
	const usage = `Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]

The simplest authenticated pubsub system. Send messages through
user-defined topics. Topics are private to the authenticated
ssh user. The default pubsub model is multicast with bidirectional
blocking, meaning a publisher ("pub") will send its message to all
subscribers ("sub"). Further, both "pub" and "sub" will wait for
at least one event to be sent or received. Pipe ("pipe") allows
for bidirectional messages to be sent between any clients connected
to a pipe.

Think of these different commands in terms of the direction the
data is being sent:

- pub => writes to client
- sub => reads from client
- pipe => read and write between clients
`

	rendered := fmt.Sprintf(usage, sshCmd)
	return strings.ReplaceAll(rendered, "\n", "\r\n")
}
107
// CliHandler holds the dependencies shared by all sessions served by
// Middleware.
type CliHandler struct {
	// DBPool is queried for per-user feature flags (Middleware checks "admin").
	DBPool db.DB
	// Logger is the handler-wide logger returned by GetLogger.
	Logger *slog.Logger
	// PubSub is the broker behind the ls, pub, sub and pipe commands.
	PubSub psub.PubSub
	// Cfg supplies the domain and SSH port shown in help text and
	// subscribe hints (see toSshCmd).
	Cfg *shared.ConfigSite
	// Waiters maps a channel name to the client IDs of blocking publishers
	// still waiting for a subscriber; "ls" reports them.
	Waiters *syncmap.Map[string, []string]
	// Access maps a channel name to its access list: pico user names or SSH
	// key fingerprints (see checkAccess). Middleware registers entries per
	// session and removes them with a deferred cleanup when the session's
	// handler returns.
	Access *syncmap.Map[string, []string]
}
116
117func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
118 return h.Logger
119}
120
121func toSshCmd(cfg *shared.ConfigSite) string {
122 port := ""
123 if cfg.PortOverride != "22" {
124 port = fmt.Sprintf("-p %s ", cfg.PortOverride)
125 }
126 return fmt.Sprintf("%s%s", port, cfg.Domain)
127}
128
// parseArgList splits a comma-separated argument on "," and trims
// surrounding whitespace from each element. Empty elements are kept, so ""
// yields [""] and "a,,b" yields ["a", "", "b"].
func parseArgList(arg string) []string {
	parts := strings.Split(arg, ",")
	trimmed := make([]string, 0, len(parts))
	for _, part := range parts {
		trimmed = append(trimmed, strings.TrimSpace(part))
	}
	return trimmed
}
137
138// checkAccess checks if the user has access to a topic based on an access list.
139func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
140 for _, acc := range accessList {
141 if acc == userName {
142 return true
143 }
144
145 if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
146 return true
147 }
148 }
149
150 return false
151}
152
153func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
154 pubsub := handler.PubSub
155
156 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
157 return func(sesh *pssh.SSHServerConnSession) error {
158 ctx := sesh.Context()
159 logger := pssh.GetLogger(sesh)
160 user := pssh.GetUser(sesh)
161
162 args := sesh.Command()
163
164 if len(args) == 0 {
165 _, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
166 return next(sesh)
167 }
168
169 userName := "public"
170
171 userNameAddition := ""
172
173 isAdmin := false
174 impersonate := false
175 if user != nil {
176 isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
177 if isAdmin && strings.HasPrefix(sesh.User(), "admin__") {
178 impersonate = true
179 }
180
181 userName = user.Name
182 if user.PublicKey != nil && user.PublicKey.Name != "" {
183 userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
184 }
185 }
186
187 pipeCtx, cancel := context.WithCancel(ctx)
188
189 go func() {
190 defer cancel()
191
192 for {
193 select {
194 case <-pipeCtx.Done():
195 return
196 default:
197 _, err := sesh.SendRequest("ping@pico.sh", false, nil)
198 if err != nil {
199 logger.Error("error sending ping", "err", err)
200 return
201 }
202
203 time.Sleep(5 * time.Second)
204 }
205 }
206 }()
207
208 cmd := strings.TrimSpace(args[0])
209 if cmd == "help" {
210 _, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
211 return next(sesh)
212 } else if cmd == "ls" {
213 if userName == "public" {
214 err := fmt.Errorf("access denied")
215 sesh.Fatal(err)
216 return err
217 }
218
219 topicFilter := fmt.Sprintf("%s/", userName)
220 if isAdmin {
221 topicFilter = ""
222 if len(args) > 1 {
223 topicFilter = args[1]
224 }
225 }
226
227 var channels []*psub.Channel
228 waitingChannels := map[string][]string{}
229
230 for topic, channel := range pubsub.GetChannels() {
231 if strings.HasPrefix(topic, topicFilter) {
232 channels = append(channels, channel)
233 }
234 }
235
236 for channel, clients := range handler.Waiters.Range {
237 if strings.HasPrefix(channel, topicFilter) {
238 waitingChannels[channel] = clients
239 }
240 }
241
242 if len(channels) == 0 && len(waitingChannels) == 0 {
243 _, _ = fmt.Fprintln(sesh, "no pubsub channels found")
244 } else {
245 var outputData string
246 if len(channels) > 0 || len(waitingChannels) > 0 {
247 outputData += "Channel Information\r\n"
248 for _, channel := range channels {
249 extraData := ""
250
251 if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
252 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
253 }
254
255 outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
256 outputData += " Clients:\r\n"
257
258 var pubs []*psub.Client
259 var subs []*psub.Client
260 var pipes []*psub.Client
261
262 for _, client := range channel.GetClients() {
263 switch client.Direction {
264 case psub.ChannelDirectionInput:
265 pubs = append(pubs, client)
266 case psub.ChannelDirectionOutput:
267 subs = append(subs, client)
268 case psub.ChannelDirectionInputOutput:
269 pipes = append(pipes, client)
270 }
271 }
272 outputData += clientInfo(pubs, isAdmin, "Pubs")
273 outputData += clientInfo(subs, isAdmin, "Subs")
274 outputData += clientInfo(pipes, isAdmin, "Pipes")
275 }
276
277 for waitingChannel, channelPubs := range waitingChannels {
278 extraData := ""
279
280 if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
281 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
282 }
283
284 outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
285 outputData += " Clients:\r\n"
286 outputData += fmt.Sprintf(" %s:\r\n", "Waiting Pubs")
287 for _, client := range channelPubs {
288 if strings.HasPrefix(client, "admin-") && !isAdmin {
289 continue
290 }
291 outputData += fmt.Sprintf(" - %s\r\n", client)
292 }
293 }
294 }
295
296 _, _ = sesh.Write([]byte(outputData))
297 }
298
299 return next(sesh)
300 }
301
302 topic := ""
303 cmdArgs := args[1:]
304 if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
305 topic = strings.TrimSpace(args[1])
306 cmdArgs = args[2:]
307 }
308
309 logger.Info(
310 "pubsub middleware detected command",
311 "args", args,
312 "cmd", cmd,
313 "topic", topic,
314 "cmdArgs", cmdArgs,
315 )
316
317 uuidStr := uuid.NewString()
318 if impersonate {
319 uuidStr = fmt.Sprintf("admin-%s", uuidStr)
320 }
321
322 clientID := fmt.Sprintf("%s (%s%s@%s)", uuidStr, userName, userNameAddition, sesh.RemoteAddr().String())
323
324 var err error
325
326 if cmd == "pub" {
327 pubCmd := flagSet("pub", sesh)
328 access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
329 empty := pubCmd.Bool("e", false, "Send an empty message to subs")
330 public := pubCmd.Bool("p", false, "Publish message to public topic")
331 block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
332 timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
333 clean := pubCmd.Bool("c", false, "Don't send status messages")
334
335 if !flagCheck(pubCmd, topic, cmdArgs) {
336 return err
337 }
338
339 if pubCmd.NArg() == 1 && topic == "" {
340 topic = pubCmd.Arg(0)
341 }
342
343 logger.Info(
344 "flags parsed",
345 "cmd", cmd,
346 "empty", *empty,
347 "public", *public,
348 "block", *block,
349 "timeout", *timeout,
350 "topic", topic,
351 "access", *access,
352 "clean", *clean,
353 )
354
355 var accessList []string
356
357 if *access != "" {
358 accessList = parseArgList(*access)
359 }
360
361 var rw io.ReadWriter
362 if *empty {
363 rw = bytes.NewBuffer(make([]byte, 1))
364 } else {
365 rw = sesh
366 }
367
368 if topic == "" {
369 topic = uuid.NewString()
370 }
371
372 var withoutUser string
373 var name string
374 msgFlag := ""
375
376 if isAdmin && strings.HasPrefix(topic, "/") {
377 name = strings.TrimPrefix(topic, "/")
378 } else {
379 name = toTopic(userName, topic)
380 if *public {
381 name = toPublicTopic(topic)
382 msgFlag = "-p "
383 withoutUser = name
384 } else {
385 withoutUser = topic
386 }
387 }
388
389 var accessListCreator bool
390
391 _, loaded := handler.Access.LoadOrStore(name, accessList)
392 if !loaded {
393 defer func() {
394 handler.Access.Delete(name)
395 }()
396
397 accessListCreator = true
398 }
399
400 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
401 if checkAccess(accessList, userName, sesh) || accessListCreator {
402 name = withoutUser
403 } else if !*public {
404 name = toTopic(userName, withoutUser)
405 } else {
406 topic = uuid.NewString()
407 name = toPublicTopic(topic)
408 }
409 }
410
411 if !*clean {
412 fmtTopic := topic
413 if *access != "" {
414 fmtTopic = fmt.Sprintf("%s/%s", userName, topic)
415 }
416
417 _, _ = fmt.Fprintf(
418 sesh,
419 "subscribe to this channel:\n ssh %s sub %s%s\n",
420 toSshCmd(handler.Cfg),
421 msgFlag,
422 fmtTopic,
423 )
424 }
425
426 if *block {
427 count := 0
428 for topic, channel := range pubsub.GetChannels() {
429 if topic == name {
430 for _, client := range channel.GetClients() {
431 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
432 count++
433 }
434 }
435 break
436 }
437 }
438
439 tt := *timeout
440 if count == 0 {
441 currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
442 handler.Waiters.Store(name, append(currentWaiters, clientID))
443
444 termMsg := "no subs found ... waiting"
445 if tt > 0 {
446 termMsg += " " + tt.String()
447 }
448
449 if !*clean {
450 _, _ = fmt.Fprintln(sesh, termMsg)
451 }
452
453 ready := make(chan struct{})
454
455 go func() {
456 for {
457 select {
458 case <-pipeCtx.Done():
459 cancel()
460 return
461 case <-time.After(1 * time.Millisecond):
462 count := 0
463 for topic, channel := range pubsub.GetChannels() {
464 if topic == name {
465 for _, client := range channel.GetClients() {
466 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
467 count++
468 }
469 }
470 break
471 }
472 }
473
474 if count > 0 {
475 close(ready)
476 return
477 }
478 }
479 }
480 }()
481
482 select {
483 case <-ready:
484 case <-pipeCtx.Done():
485 case <-time.After(tt):
486 cancel()
487
488 if !*clean {
489 sesh.Fatal(fmt.Errorf("timeout reached, exiting"))
490 } else {
491 err = sesh.Exit(1)
492 if err != nil {
493 logger.Error("error exiting session", "err", err)
494 }
495
496 _ = sesh.Close()
497 }
498 }
499
500 newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
501 newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
502 return cl == clientID
503 })
504 handler.Waiters.Store(name, newWaiters)
505
506 var toDelete []string
507
508 for channel, clients := range handler.Waiters.Range {
509 if len(clients) == 0 {
510 toDelete = append(toDelete, channel)
511 }
512 }
513
514 for _, channel := range toDelete {
515 handler.Waiters.Delete(channel)
516 }
517 }
518 }
519
520 if !*clean {
521 _, _ = fmt.Fprintln(sesh, "sending msg ...")
522 }
523
524 err = pubsub.Pub(
525 pipeCtx,
526 clientID,
527 rw,
528 []*psub.Channel{
529 psub.NewChannel(name),
530 },
531 *block,
532 )
533
534 if !*clean {
535 _, _ = fmt.Fprintln(sesh, "msg sent!")
536 }
537
538 if err != nil && !*clean {
539 _, _ = fmt.Fprintln(sesh.Stderr(), err)
540 }
541 } else if cmd == "sub" {
542 subCmd := flagSet("sub", sesh)
543 access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
544 public := subCmd.Bool("p", false, "Subscribe to a public topic")
545 keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
546 clean := subCmd.Bool("c", false, "Don't send status messages")
547
548 if !flagCheck(subCmd, topic, cmdArgs) {
549 return err
550 }
551
552 if subCmd.NArg() == 1 && topic == "" {
553 topic = subCmd.Arg(0)
554 }
555
556 logger.Info(
557 "flags parsed",
558 "cmd", cmd,
559 "public", *public,
560 "keepAlive", *keepAlive,
561 "topic", topic,
562 "clean", *clean,
563 "access", *access,
564 )
565
566 var accessList []string
567
568 if *access != "" {
569 accessList = parseArgList(*access)
570 }
571
572 var withoutUser string
573 var name string
574
575 if isAdmin && strings.HasPrefix(topic, "/") {
576 name = strings.TrimPrefix(topic, "/")
577 } else {
578 name = toTopic(userName, topic)
579 if *public {
580 name = toPublicTopic(topic)
581 withoutUser = name
582 } else {
583 withoutUser = topic
584 }
585 }
586
587 var accessListCreator bool
588
589 _, loaded := handler.Access.LoadOrStore(name, accessList)
590 if !loaded {
591 defer func() {
592 handler.Access.Delete(name)
593 }()
594
595 accessListCreator = true
596 }
597
598 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
599 if checkAccess(accessList, userName, sesh) || accessListCreator {
600 name = withoutUser
601 } else if !*public {
602 name = toTopic(userName, withoutUser)
603 } else {
604 _, _ = fmt.Fprintln(sesh.Stderr(), "access denied")
605 return err
606 }
607 }
608
609 err = pubsub.Sub(
610 pipeCtx,
611 clientID,
612 sesh,
613 []*psub.Channel{
614 psub.NewChannel(name),
615 },
616 *keepAlive,
617 )
618
619 if err != nil && !*clean {
620 _, _ = fmt.Fprintln(sesh.Stderr(), err)
621 }
622 } else if cmd == "pipe" {
623 pipeCmd := flagSet("pipe", sesh)
624 access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
625 public := pipeCmd.Bool("p", false, "Pipe to a public topic")
626 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
627 clean := pipeCmd.Bool("c", false, "Don't send status messages")
628
629 if !flagCheck(pipeCmd, topic, cmdArgs) {
630 return err
631 }
632
633 if pipeCmd.NArg() == 1 && topic == "" {
634 topic = pipeCmd.Arg(0)
635 }
636
637 logger.Info(
638 "flags parsed",
639 "cmd", cmd,
640 "public", *public,
641 "replay", *replay,
642 "topic", topic,
643 "access", *access,
644 "clean", *clean,
645 )
646
647 var accessList []string
648
649 if *access != "" {
650 accessList = parseArgList(*access)
651 }
652
653 isCreator := topic == ""
654 if isCreator {
655 topic = uuid.NewString()
656 }
657
658 var withoutUser string
659 var name string
660 flagMsg := ""
661
662 if isAdmin && strings.HasPrefix(topic, "/") {
663 name = strings.TrimPrefix(topic, "/")
664 } else {
665 name = toTopic(userName, topic)
666 if *public {
667 name = toPublicTopic(topic)
668 flagMsg = "-p "
669 withoutUser = name
670 } else {
671 withoutUser = topic
672 }
673 }
674
675 var accessListCreator bool
676
677 _, loaded := handler.Access.LoadOrStore(name, accessList)
678 if !loaded {
679 defer func() {
680 handler.Access.Delete(name)
681 }()
682
683 accessListCreator = true
684 }
685
686 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
687 if checkAccess(accessList, userName, sesh) || accessListCreator {
688 name = withoutUser
689 } else if !*public {
690 name = toTopic(userName, withoutUser)
691 } else {
692 topic = uuid.NewString()
693 name = toPublicTopic(topic)
694 }
695 }
696
697 if isCreator && !*clean {
698 fmtTopic := topic
699 if *access != "" {
700 fmtTopic = fmt.Sprintf("%s/%s", userName, topic)
701 }
702
703 _, _ = fmt.Fprintf(
704 sesh,
705 "subscribe to this topic:\n ssh %s sub %s%s\n",
706 toSshCmd(handler.Cfg),
707 flagMsg,
708 fmtTopic,
709 )
710 }
711
712 readErr, writeErr := pubsub.Pipe(
713 pipeCtx,
714 clientID,
715 sesh,
716 []*psub.Channel{
717 psub.NewChannel(name),
718 },
719 *replay,
720 )
721
722 if readErr != nil && !*clean {
723 _, _ = fmt.Fprintln(sesh.Stderr(), "error reading from pipe", readErr)
724 }
725
726 if writeErr != nil && !*clean {
727 _, _ = fmt.Fprintln(sesh.Stderr(), "error writing to pipe", writeErr)
728 }
729 }
730
731 return next(sesh)
732 }
733 }
734}