Eric Bower
·
2025-05-25
cli.go
1package pipe
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "io"
9 "log/slog"
10 "slices"
11 "strings"
12 "text/tabwriter"
13 "time"
14
15 "github.com/antoniomika/syncmap"
16 "github.com/google/uuid"
17 "github.com/picosh/pico/pkg/db"
18 "github.com/picosh/pico/pkg/pssh"
19 "github.com/picosh/pico/pkg/shared"
20 psub "github.com/picosh/pubsub"
21 gossh "golang.org/x/crypto/ssh"
22)
23
24func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
25 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
26 cmd.SetOutput(sesh)
27 cmd.Usage = func() {
28 _, _ = fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
29 cmd.PrintDefaults()
30 }
31 return cmd
32}
33
// flagCheck parses cmdArgs with cmd and reports whether the command should
// go on to execute. It returns false when parsing fails or when the
// positional argument is the literal "help"; in the latter case the usage
// text is printed explicitly (on a parse failure the flag package has
// already printed it).
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	// Always parse first so flag values are populated and parse errors are
	// reported even when help was requested.
	parseErr := cmd.Parse(cmdArgs)

	if posArg == "help" {
		cmd.Usage()
		return false
	}

	return parseErr == nil
}
45
// NewTabWriter returns a tabwriter that writes to out, padding each column
// with a single space and indenting with tabs (tabwriter.TabIndent).
func NewTabWriter(out io.Writer) *tabwriter.Writer {
	const (
		minWidth = 0
		tabWidth = 0
		padding  = 1
		padChar  = ' '
	)

	return tabwriter.NewWriter(out, minWidth, tabWidth, padding, padChar, tabwriter.TabIndent)
}
49
// toTopic scopes a topic to a user by prefixing it with the user's name,
// producing "<userName>/<topic>".
func toTopic(userName, topic string) string {
	return userName + "/" + topic
}
54
// toPublicTopic places a topic in the shared "public/" namespace.
func toPublicTopic(topic string) string {
	return "public/" + topic
}
58
59func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
60 if len(clients) == 0 {
61 return ""
62 }
63
64 outputData := fmt.Sprintf(" %s:\r\n", clientType)
65
66 for _, client := range clients {
67 if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
68 continue
69 }
70
71 outputData += fmt.Sprintf(" - %s\r\n", client.ID)
72 }
73
74 return outputData
75}
76
// helpStr renders the top-level help text. sshCmd is the host portion of
// the ssh invocation shown in the synopsis line. Line endings are converted
// to CRLF before the text is returned.
var helpStr = func(sshCmd string) string {
	const tmpl = `Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]

The simplest authenticated pubsub system. Send messages through
user-defined topics. Topics are private to the authenticated
ssh user. The default pubsub model is multicast with bidirectional
blocking, meaning a publisher ("pub") will send its message to all
subscribers ("sub"). Further, both "pub" and "sub" will wait for
at least one event to be sent or received. Pipe ("pipe") allows
for bidirectional messages to be sent between any clients connected
to a pipe.

Think of these different commands in terms of the direction the
data is being sent:

- pub => writes to client
- sub => reads from client
- pipe => read and write between clients
`

	return strings.ReplaceAll(fmt.Sprintf(tmpl, sshCmd), "\n", "\r\n")
}
101
102type CliHandler struct {
103 DBPool db.DB
104 Logger *slog.Logger
105 PubSub psub.PubSub
106 Cfg *shared.ConfigSite
107 Waiters *syncmap.Map[string, []string]
108 Access *syncmap.Map[string, []string]
109}
110
111func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
112 return h.Logger
113}
114
115func toSshCmd(cfg *shared.ConfigSite) string {
116 port := ""
117 if cfg.PortOverride != "22" {
118 port = fmt.Sprintf("-p %s ", cfg.PortOverride)
119 }
120 return fmt.Sprintf("%s%s", port, cfg.Domain)
121}
122
// parseArgList splits a comma separated argument into its entries and trims
// surrounding whitespace from each one. Empty entries are kept, so an empty
// input yields a single empty string.
func parseArgList(arg string) []string {
	parts := strings.Split(arg, ",")

	trimmed := make([]string, 0, len(parts))
	for _, part := range parts {
		trimmed = append(trimmed, strings.TrimSpace(part))
	}

	return trimmed
}
131
132// checkAccess checks if the user has access to a topic based on an access list.
133func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
134 for _, acc := range accessList {
135 if acc == userName {
136 return true
137 }
138
139 if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
140 return true
141 }
142 }
143
144 return false
145}
146
147func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
148 pubsub := handler.PubSub
149
150 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
151 return func(sesh *pssh.SSHServerConnSession) error {
152 ctx := sesh.Context()
153 logger := pssh.GetLogger(sesh)
154 user := pssh.GetUser(sesh)
155
156 args := sesh.Command()
157
158 if len(args) == 0 {
159 _, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
160 return next(sesh)
161 }
162
163 userName := "public"
164
165 userNameAddition := ""
166
167 isAdmin := false
168 impersonate := false
169 if user != nil {
170 isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
171 if isAdmin && strings.HasPrefix(sesh.User(), "admin__") {
172 impersonate = true
173 }
174
175 userName = user.Name
176 if user.PublicKey != nil && user.PublicKey.Name != "" {
177 userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
178 }
179 }
180
181 pipeCtx, cancel := context.WithCancel(ctx)
182
183 go func() {
184 defer cancel()
185
186 for {
187 select {
188 case <-pipeCtx.Done():
189 return
190 default:
191 _, err := sesh.SendRequest("ping@pico.sh", false, nil)
192 if err != nil {
193 logger.Error("error sending ping", "err", err)
194 return
195 }
196
197 time.Sleep(5 * time.Second)
198 }
199 }
200 }()
201
202 cmd := strings.TrimSpace(args[0])
203 if cmd == "help" {
204 _, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
205 return next(sesh)
206 } else if cmd == "ls" {
207 if userName == "public" {
208 err := fmt.Errorf("access denied")
209 sesh.Fatal(err)
210 return err
211 }
212
213 topicFilter := fmt.Sprintf("%s/", userName)
214 if isAdmin {
215 topicFilter = ""
216 if len(args) > 1 {
217 topicFilter = args[1]
218 }
219 }
220
221 var channels []*psub.Channel
222 waitingChannels := map[string][]string{}
223
224 for topic, channel := range pubsub.GetChannels() {
225 if strings.HasPrefix(topic, topicFilter) {
226 channels = append(channels, channel)
227 }
228 }
229
230 for channel, clients := range handler.Waiters.Range {
231 if strings.HasPrefix(channel, topicFilter) {
232 waitingChannels[channel] = clients
233 }
234 }
235
236 if len(channels) == 0 && len(waitingChannels) == 0 {
237 _, _ = fmt.Fprintln(sesh, "no pubsub channels found")
238 } else {
239 var outputData string
240 if len(channels) > 0 || len(waitingChannels) > 0 {
241 outputData += "Channel Information\r\n"
242 for _, channel := range channels {
243 extraData := ""
244
245 if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
246 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
247 }
248
249 outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
250 outputData += " Clients:\r\n"
251
252 var pubs []*psub.Client
253 var subs []*psub.Client
254 var pipes []*psub.Client
255
256 for _, client := range channel.GetClients() {
257 switch client.Direction {
258 case psub.ChannelDirectionInput:
259 pubs = append(pubs, client)
260 case psub.ChannelDirectionOutput:
261 subs = append(subs, client)
262 case psub.ChannelDirectionInputOutput:
263 pipes = append(pipes, client)
264 }
265 }
266 outputData += clientInfo(pubs, isAdmin, "Pubs")
267 outputData += clientInfo(subs, isAdmin, "Subs")
268 outputData += clientInfo(pipes, isAdmin, "Pipes")
269 }
270
271 for waitingChannel, channelPubs := range waitingChannels {
272 extraData := ""
273
274 if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
275 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
276 }
277
278 outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
279 outputData += " Clients:\r\n"
280 outputData += fmt.Sprintf(" %s:\r\n", "Waiting Pubs")
281 for _, client := range channelPubs {
282 if strings.HasPrefix(client, "admin-") && !isAdmin {
283 continue
284 }
285 outputData += fmt.Sprintf(" - %s\r\n", client)
286 }
287 }
288 }
289
290 _, _ = sesh.Write([]byte(outputData))
291 }
292
293 return next(sesh)
294 }
295
296 topic := ""
297 cmdArgs := args[1:]
298 if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
299 topic = strings.TrimSpace(args[1])
300 cmdArgs = args[2:]
301 }
302
303 logger.Info(
304 "pubsub middleware detected command",
305 "args", args,
306 "cmd", cmd,
307 "topic", topic,
308 "cmdArgs", cmdArgs,
309 )
310
311 uuidStr := uuid.NewString()
312 if impersonate {
313 uuidStr = fmt.Sprintf("admin-%s", uuidStr)
314 }
315
316 clientID := fmt.Sprintf("%s (%s%s@%s)", uuidStr, userName, userNameAddition, sesh.RemoteAddr().String())
317
318 var err error
319
320 if cmd == "pub" {
321 pubCmd := flagSet("pub", sesh)
322 access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
323 empty := pubCmd.Bool("e", false, "Send an empty message to subs")
324 public := pubCmd.Bool("p", false, "Publish message to public topic")
325 block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
326 timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
327 clean := pubCmd.Bool("c", false, "Don't send status messages")
328
329 if !flagCheck(pubCmd, topic, cmdArgs) {
330 return err
331 }
332
333 if pubCmd.NArg() == 1 && topic == "" {
334 topic = pubCmd.Arg(0)
335 }
336
337 logger.Info(
338 "flags parsed",
339 "cmd", cmd,
340 "empty", *empty,
341 "public", *public,
342 "block", *block,
343 "timeout", *timeout,
344 "topic", topic,
345 "access", *access,
346 "clean", *clean,
347 )
348
349 var accessList []string
350
351 if *access != "" {
352 accessList = parseArgList(*access)
353 }
354
355 var rw io.ReadWriter
356 if *empty {
357 rw = bytes.NewBuffer(make([]byte, 1))
358 } else {
359 rw = sesh
360 }
361
362 if topic == "" {
363 topic = uuid.NewString()
364 }
365
366 var withoutUser string
367 var name string
368 msgFlag := ""
369
370 if isAdmin && strings.HasPrefix(topic, "/") {
371 name = strings.TrimPrefix(topic, "/")
372 } else {
373 name = toTopic(userName, topic)
374 if *public {
375 name = toPublicTopic(topic)
376 msgFlag = "-p "
377 withoutUser = name
378 } else {
379 withoutUser = topic
380 }
381 }
382
383 var accessListCreator bool
384
385 _, loaded := handler.Access.LoadOrStore(name, accessList)
386 if !loaded {
387 defer func() {
388 handler.Access.Delete(name)
389 }()
390
391 accessListCreator = true
392 }
393
394 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
395 if checkAccess(accessList, userName, sesh) || accessListCreator {
396 name = withoutUser
397 } else if !*public {
398 name = toTopic(userName, withoutUser)
399 } else {
400 topic = uuid.NewString()
401 name = toPublicTopic(topic)
402 }
403 }
404
405 if !*clean {
406 _, _ = fmt.Fprintf(
407 sesh,
408 "subscribe to this channel:\n ssh %s sub %s%s\n",
409 toSshCmd(handler.Cfg),
410 msgFlag,
411 topic,
412 )
413 }
414
415 if *block {
416 count := 0
417 for topic, channel := range pubsub.GetChannels() {
418 if topic == name {
419 for _, client := range channel.GetClients() {
420 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
421 count++
422 }
423 }
424 break
425 }
426 }
427
428 tt := *timeout
429 if count == 0 {
430 currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
431 handler.Waiters.Store(name, append(currentWaiters, clientID))
432
433 termMsg := "no subs found ... waiting"
434 if tt > 0 {
435 termMsg += " " + tt.String()
436 }
437
438 if !*clean {
439 _, _ = fmt.Fprintln(sesh, termMsg)
440 }
441
442 ready := make(chan struct{})
443
444 go func() {
445 for {
446 select {
447 case <-pipeCtx.Done():
448 cancel()
449 return
450 case <-time.After(1 * time.Millisecond):
451 count := 0
452 for topic, channel := range pubsub.GetChannels() {
453 if topic == name {
454 for _, client := range channel.GetClients() {
455 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
456 count++
457 }
458 }
459 break
460 }
461 }
462
463 if count > 0 {
464 close(ready)
465 return
466 }
467 }
468 }
469 }()
470
471 select {
472 case <-ready:
473 case <-pipeCtx.Done():
474 case <-time.After(tt):
475 cancel()
476
477 if !*clean {
478 sesh.Fatal(fmt.Errorf("timeout reached, exiting"))
479 } else {
480 err = sesh.Exit(1)
481 if err != nil {
482 logger.Error("error exiting session", "err", err)
483 }
484
485 _ = sesh.Close()
486 }
487 }
488
489 newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
490 newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
491 return cl == clientID
492 })
493 handler.Waiters.Store(name, newWaiters)
494
495 var toDelete []string
496
497 for channel, clients := range handler.Waiters.Range {
498 if len(clients) == 0 {
499 toDelete = append(toDelete, channel)
500 }
501 }
502
503 for _, channel := range toDelete {
504 handler.Waiters.Delete(channel)
505 }
506 }
507 }
508
509 if !*clean {
510 _, _ = fmt.Fprintln(sesh, "sending msg ...")
511 }
512
513 err = pubsub.Pub(
514 pipeCtx,
515 clientID,
516 rw,
517 []*psub.Channel{
518 psub.NewChannel(name),
519 },
520 *block,
521 )
522
523 if !*clean {
524 _, _ = fmt.Fprintln(sesh, "msg sent!")
525 }
526
527 if err != nil && !*clean {
528 _, _ = fmt.Fprintln(sesh.Stderr(), err)
529 }
530 } else if cmd == "sub" {
531 subCmd := flagSet("sub", sesh)
532 access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
533 public := subCmd.Bool("p", false, "Subscribe to a public topic")
534 keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
535 clean := subCmd.Bool("c", false, "Don't send status messages")
536
537 if !flagCheck(subCmd, topic, cmdArgs) {
538 return err
539 }
540
541 if subCmd.NArg() == 1 && topic == "" {
542 topic = subCmd.Arg(0)
543 }
544
545 logger.Info(
546 "flags parsed",
547 "cmd", cmd,
548 "public", *public,
549 "keepAlive", *keepAlive,
550 "topic", topic,
551 "clean", *clean,
552 "access", *access,
553 )
554
555 var accessList []string
556
557 if *access != "" {
558 accessList = parseArgList(*access)
559 }
560
561 var withoutUser string
562 var name string
563
564 if isAdmin && strings.HasPrefix(topic, "/") {
565 name = strings.TrimPrefix(topic, "/")
566 } else {
567 name = toTopic(userName, topic)
568 if *public {
569 name = toPublicTopic(topic)
570 withoutUser = name
571 } else {
572 withoutUser = topic
573 }
574 }
575
576 var accessListCreator bool
577
578 _, loaded := handler.Access.LoadOrStore(name, accessList)
579 if !loaded {
580 defer func() {
581 handler.Access.Delete(name)
582 }()
583
584 accessListCreator = true
585 }
586
587 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
588 if checkAccess(accessList, userName, sesh) || accessListCreator {
589 name = withoutUser
590 } else if !*public {
591 name = toTopic(userName, withoutUser)
592 } else {
593 _, _ = fmt.Fprintln(sesh.Stderr(), "access denied")
594 return err
595 }
596 }
597
598 err = pubsub.Sub(
599 pipeCtx,
600 clientID,
601 sesh,
602 []*psub.Channel{
603 psub.NewChannel(name),
604 },
605 *keepAlive,
606 )
607
608 if err != nil && !*clean {
609 _, _ = fmt.Fprintln(sesh.Stderr(), err)
610 }
611 } else if cmd == "pipe" {
612 pipeCmd := flagSet("pipe", sesh)
613 access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
614 public := pipeCmd.Bool("p", false, "Pipe to a public topic")
615 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
616 clean := pipeCmd.Bool("c", false, "Don't send status messages")
617
618 if !flagCheck(pipeCmd, topic, cmdArgs) {
619 return err
620 }
621
622 if pipeCmd.NArg() == 1 && topic == "" {
623 topic = pipeCmd.Arg(0)
624 }
625
626 logger.Info(
627 "flags parsed",
628 "cmd", cmd,
629 "public", *public,
630 "replay", *replay,
631 "topic", topic,
632 "access", *access,
633 "clean", *clean,
634 )
635
636 var accessList []string
637
638 if *access != "" {
639 accessList = parseArgList(*access)
640 }
641
642 isCreator := topic == ""
643 if isCreator {
644 topic = uuid.NewString()
645 }
646
647 var withoutUser string
648 var name string
649 flagMsg := ""
650
651 if isAdmin && strings.HasPrefix(topic, "/") {
652 name = strings.TrimPrefix(topic, "/")
653 } else {
654 name = toTopic(userName, topic)
655 if *public {
656 name = toPublicTopic(topic)
657 flagMsg = "-p "
658 withoutUser = name
659 } else {
660 withoutUser = topic
661 }
662 }
663
664 var accessListCreator bool
665
666 _, loaded := handler.Access.LoadOrStore(name, accessList)
667 if !loaded {
668 defer func() {
669 handler.Access.Delete(name)
670 }()
671
672 accessListCreator = true
673 }
674
675 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
676 if checkAccess(accessList, userName, sesh) || accessListCreator {
677 name = withoutUser
678 } else if !*public {
679 name = toTopic(userName, withoutUser)
680 } else {
681 topic = uuid.NewString()
682 name = toPublicTopic(topic)
683 }
684 }
685
686 if isCreator && !*clean {
687 _, _ = fmt.Fprintf(
688 sesh,
689 "subscribe to this topic:\n ssh %s sub %s%s\n",
690 toSshCmd(handler.Cfg),
691 flagMsg,
692 topic,
693 )
694 }
695
696 readErr, writeErr := pubsub.Pipe(
697 pipeCtx,
698 clientID,
699 sesh,
700 []*psub.Channel{
701 psub.NewChannel(name),
702 },
703 *replay,
704 )
705
706 if readErr != nil && !*clean {
707 _, _ = fmt.Fprintln(sesh.Stderr(), "error reading from pipe", readErr)
708 }
709
710 if writeErr != nil && !*clean {
711 _, _ = fmt.Fprintln(sesh.Stderr(), "error writing to pipe", writeErr)
712 }
713 }
714
715 return next(sesh)
716 }
717 }
718}