Antonio Mika
·
2025-03-15
cli.go
1package pipe
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "io"
9 "log/slog"
10 "slices"
11 "strings"
12 "text/tabwriter"
13 "time"
14
15 "github.com/antoniomika/syncmap"
16 "github.com/google/uuid"
17 "github.com/picosh/pico/pkg/db"
18 "github.com/picosh/pico/pkg/pssh"
19 "github.com/picosh/pico/pkg/shared"
20 psub "github.com/picosh/pubsub"
21 gossh "golang.org/x/crypto/ssh"
22)
23
24func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
25 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
26 cmd.SetOutput(sesh)
27 cmd.Usage = func() {
28 fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
29 cmd.PrintDefaults()
30 }
31 return cmd
32}
33
// flagCheck parses cmdArgs into cmd and reports whether the command should
// proceed. It returns false when parsing fails (the flag package has already
// reported the error) or when the positional argument is literally "help",
// in which case the usage text is printed first.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	parseErr := cmd.Parse(cmdArgs)

	wantsHelp := posArg == "help"
	if wantsHelp {
		cmd.Usage()
	}

	return parseErr == nil && !wantsHelp
}
45
// NewTabWriter returns a tabwriter.Writer on out that aligns tab-terminated
// columns with single-space padding. Leading empty cells are still indented
// with tabs (tabwriter.TabIndent). Callers must Flush the writer.
func NewTabWriter(out io.Writer) *tabwriter.Writer {
	const (
		minWidth = 0
		tabWidth = 0
		padding  = 1
	)
	return new(tabwriter.Writer).Init(out, minWidth, tabWidth, padding, ' ', tabwriter.TabIndent)
}
49
// toTopic scopes topic to a user by prefixing it with "<userName>/".
func toTopic(userName, topic string) string {
	return userName + "/" + topic
}
54
// toPublicTopic places topic in the shared "public/" namespace.
func toPublicTopic(topic string) string {
	const publicPrefix = "public/"
	return publicPrefix + topic
}
58
59func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
60 if len(clients) == 0 {
61 return ""
62 }
63
64 outputData := fmt.Sprintf(" %s:\r\n", clientType)
65
66 for _, client := range clients {
67 if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
68 continue
69 }
70
71 outputData += fmt.Sprintf(" - %s\r\n", client.ID)
72 }
73
74 return outputData
75}
76
// helpStr renders the top-level usage text, with sshCmd substituted as the
// host portion of the ssh invocation (see toSshCmd). Every "\n" in the
// rendered text is converted to "\r\n", matching the CRLF line endings used
// elsewhere for output written to the session.
var helpStr = func(sshCmd string) string {
	const tmpl = `Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]

The simplest authenticated pubsub system. Send messages through
user-defined topics. Topics are private to the authenticated
ssh user. The default pubsub model is multicast with bidirectional
blocking, meaning a publisher ("pub") will send its message to all
subscribers ("sub"). Further, both "pub" and "sub" will wait for
at least one event to be sent or received. Pipe ("pipe") allows
for bidirectional messages to be sent between any clients connected
to a pipe.

Think of these different commands in terms of the direction the
data is being sent:

- pub => writes to client
- sub => reads from client
- pipe => read and write between clients
`

	rendered := fmt.Sprintf(tmpl, sshCmd)
	return strings.ReplaceAll(rendered, "\n", "\r\n")
}
101
// CliHandler holds the dependencies shared by every session served by the
// pipe CLI middleware.
type CliHandler struct {
	// DBPool is queried for per-user features (the "admin" feature gates
	// admin-only behaviour such as raw "/topic" names and unfiltered ls).
	DBPool db.DB
	// Logger is the handler-wide logger returned by GetLogger.
	Logger *slog.Logger
	// PubSub is the broker that pub, sub and pipe sessions attach to.
	PubSub psub.PubSub
	// Cfg supplies the domain and port override used to build the ssh
	// commands printed to users.
	Cfg *shared.ConfigSite
	// Waiters maps a channel name to the client IDs of publishers currently
	// blocked waiting for a subscriber on that channel.
	Waiters *syncmap.Map[string, []string]
	// Access maps a channel name to the usernames or SSH key fingerprints
	// allowed to join it; an empty list means no restriction is recorded.
	Access *syncmap.Map[string, []string]
}
110
111func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
112 return h.Logger
113}
114
115func toSshCmd(cfg *shared.ConfigSite) string {
116 port := ""
117 if cfg.PortOverride != "22" {
118 port = fmt.Sprintf("-p %s ", cfg.PortOverride)
119 }
120 return fmt.Sprintf("%s%s", port, cfg.Domain)
121}
122
// parseArgList splits a comma separated argument into its entries and trims
// surrounding whitespace from each one. Empty fields are kept, so the result
// always has exactly one element per comma-separated field ("" yields [""]).
func parseArgList(arg string) []string {
	fields := strings.Split(arg, ",")
	trimmed := make([]string, 0, len(fields))
	for _, field := range fields {
		trimmed = append(trimmed, strings.TrimSpace(field))
	}
	return trimmed
}
131
132// checkAccess checks if the user has access to a topic based on an access list.
133func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
134 for _, acc := range accessList {
135 if acc == userName {
136 return true
137 }
138
139 if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
140 return true
141 }
142 }
143
144 return false
145}
146
147func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
148 pubsub := handler.PubSub
149
150 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
151 return func(sesh *pssh.SSHServerConnSession) error {
152 ctx := sesh.Context()
153 logger := pssh.GetLogger(sesh)
154 user := pssh.GetUser(sesh)
155
156 args := sesh.Command()
157
158 if len(args) == 0 {
159 fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
160 return next(sesh)
161 }
162
163 userName := "public"
164
165 userNameAddition := ""
166
167 isAdmin := false
168 impersonate := false
169 if user != nil {
170 isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
171 if isAdmin && strings.HasPrefix(sesh.User(), "admin__") {
172 impersonate = true
173 }
174
175 userName = user.Name
176 if user.PublicKey != nil && user.PublicKey.Name != "" {
177 userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
178 }
179 }
180
181 pipeCtx, cancel := context.WithCancel(ctx)
182
183 go func() {
184 defer cancel()
185
186 for {
187 select {
188 case <-pipeCtx.Done():
189 return
190 default:
191 _, err := sesh.SendRequest("ping@pico.sh", false, nil)
192 if err != nil {
193 logger.Error("error sending ping", "err", err)
194 return
195 }
196
197 time.Sleep(5 * time.Second)
198 }
199 }
200 }()
201
202 cmd := strings.TrimSpace(args[0])
203 if cmd == "help" {
204 fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
205 return next(sesh)
206 } else if cmd == "ls" {
207 if userName == "public" {
208 err := fmt.Errorf("access denied")
209 sesh.Fatal(err)
210 return err
211 }
212
213 topicFilter := fmt.Sprintf("%s/", userName)
214 if isAdmin {
215 topicFilter = ""
216 if len(args) > 1 {
217 topicFilter = args[1]
218 }
219 }
220
221 var channels []*psub.Channel
222 waitingChannels := map[string][]string{}
223
224 for topic, channel := range pubsub.GetChannels() {
225 if strings.HasPrefix(topic, topicFilter) {
226 channels = append(channels, channel)
227 }
228 }
229
230 for channel, clients := range handler.Waiters.Range {
231 if strings.HasPrefix(channel, topicFilter) {
232 waitingChannels[channel] = clients
233 }
234 }
235
236 if len(channels) == 0 && len(waitingChannels) == 0 {
237 fmt.Fprintln(sesh, "no pubsub channels found")
238 } else {
239 var outputData string
240 if len(channels) > 0 || len(waitingChannels) > 0 {
241 outputData += "Channel Information\r\n"
242 for _, channel := range channels {
243 extraData := ""
244
245 if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
246 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
247 }
248
249 outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
250 outputData += " Clients:\r\n"
251
252 var pubs []*psub.Client
253 var subs []*psub.Client
254 var pipes []*psub.Client
255
256 for _, client := range channel.GetClients() {
257 if client.Direction == psub.ChannelDirectionInput {
258 pubs = append(pubs, client)
259 } else if client.Direction == psub.ChannelDirectionOutput {
260 subs = append(subs, client)
261 } else if client.Direction == psub.ChannelDirectionInputOutput {
262 pipes = append(pipes, client)
263 }
264 }
265 outputData += clientInfo(pubs, isAdmin, "Pubs")
266 outputData += clientInfo(subs, isAdmin, "Subs")
267 outputData += clientInfo(pipes, isAdmin, "Pipes")
268 }
269
270 for waitingChannel, channelPubs := range waitingChannels {
271 extraData := ""
272
273 if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
274 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
275 }
276
277 outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
278 outputData += " Clients:\r\n"
279 outputData += fmt.Sprintf(" %s:\r\n", "Waiting Pubs")
280 for _, client := range channelPubs {
281 if strings.HasPrefix(client, "admin-") && !isAdmin {
282 continue
283 }
284 outputData += fmt.Sprintf(" - %s\r\n", client)
285 }
286 }
287 }
288
289 _, _ = sesh.Write([]byte(outputData))
290 }
291
292 return next(sesh)
293 }
294
295 topic := ""
296 cmdArgs := args[1:]
297 if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
298 topic = strings.TrimSpace(args[1])
299 cmdArgs = args[2:]
300 }
301
302 logger.Info(
303 "pubsub middleware detected command",
304 "args", args,
305 "cmd", cmd,
306 "topic", topic,
307 "cmdArgs", cmdArgs,
308 )
309
310 uuidStr := uuid.NewString()
311 if impersonate {
312 uuidStr = fmt.Sprintf("admin-%s", uuidStr)
313 }
314
315 clientID := fmt.Sprintf("%s (%s%s@%s)", uuidStr, userName, userNameAddition, sesh.RemoteAddr().String())
316
317 var err error
318
319 if cmd == "pub" {
320 pubCmd := flagSet("pub", sesh)
321 access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
322 empty := pubCmd.Bool("e", false, "Send an empty message to subs")
323 public := pubCmd.Bool("p", false, "Publish message to public topic")
324 block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
325 timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
326 clean := pubCmd.Bool("c", false, "Don't send status messages")
327
328 if !flagCheck(pubCmd, topic, cmdArgs) {
329 return err
330 }
331
332 if pubCmd.NArg() == 1 && topic == "" {
333 topic = pubCmd.Arg(0)
334 }
335
336 logger.Info(
337 "flags parsed",
338 "cmd", cmd,
339 "empty", *empty,
340 "public", *public,
341 "block", *block,
342 "timeout", *timeout,
343 "topic", topic,
344 "access", *access,
345 "clean", *clean,
346 )
347
348 var accessList []string
349
350 if *access != "" {
351 accessList = parseArgList(*access)
352 }
353
354 var rw io.ReadWriter
355 if *empty {
356 rw = bytes.NewBuffer(make([]byte, 1))
357 } else {
358 rw = sesh
359 }
360
361 if topic == "" {
362 topic = uuid.NewString()
363 }
364
365 var withoutUser string
366 var name string
367 msgFlag := ""
368
369 if isAdmin && strings.HasPrefix(topic, "/") {
370 name = strings.TrimPrefix(topic, "/")
371 } else {
372 name = toTopic(userName, topic)
373 if *public {
374 name = toPublicTopic(topic)
375 msgFlag = "-p "
376 withoutUser = name
377 } else {
378 withoutUser = topic
379 }
380 }
381
382 var accessListCreator bool
383
384 _, loaded := handler.Access.LoadOrStore(name, accessList)
385 if !loaded {
386 defer func() {
387 handler.Access.Delete(name)
388 }()
389
390 accessListCreator = true
391 }
392
393 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
394 if checkAccess(accessList, userName, sesh) || accessListCreator {
395 name = withoutUser
396 } else if !*public {
397 name = toTopic(userName, withoutUser)
398 } else {
399 topic = uuid.NewString()
400 name = toPublicTopic(topic)
401 }
402 }
403
404 if !*clean {
405 fmt.Fprintf(
406 sesh,
407 "subscribe to this channel:\n ssh %s sub %s%s\n",
408 toSshCmd(handler.Cfg),
409 msgFlag,
410 topic,
411 )
412 }
413
414 if *block {
415 count := 0
416 for topic, channel := range pubsub.GetChannels() {
417 if topic == name {
418 for _, client := range channel.GetClients() {
419 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
420 count++
421 }
422 }
423 break
424 }
425 }
426
427 tt := *timeout
428 if count == 0 {
429 currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
430 handler.Waiters.Store(name, append(currentWaiters, clientID))
431
432 termMsg := "no subs found ... waiting"
433 if tt > 0 {
434 termMsg += " " + tt.String()
435 }
436
437 if !*clean {
438 fmt.Fprintln(sesh, termMsg)
439 }
440
441 ready := make(chan struct{})
442
443 go func() {
444 for {
445 select {
446 case <-pipeCtx.Done():
447 cancel()
448 return
449 case <-time.After(1 * time.Millisecond):
450 count := 0
451 for topic, channel := range pubsub.GetChannels() {
452 if topic == name {
453 for _, client := range channel.GetClients() {
454 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
455 count++
456 }
457 }
458 break
459 }
460 }
461
462 if count > 0 {
463 close(ready)
464 return
465 }
466 }
467 }
468 }()
469
470 select {
471 case <-ready:
472 case <-pipeCtx.Done():
473 case <-time.After(tt):
474 cancel()
475
476 if !*clean {
477 sesh.Fatal(fmt.Errorf("timeout reached, exiting"))
478 } else {
479 err = sesh.Exit(1)
480 if err != nil {
481 logger.Error("error exiting session", "err", err)
482 }
483
484 sesh.Close()
485 }
486 }
487
488 newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
489 newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
490 return cl == clientID
491 })
492 handler.Waiters.Store(name, newWaiters)
493
494 var toDelete []string
495
496 for channel, clients := range handler.Waiters.Range {
497 if len(clients) == 0 {
498 toDelete = append(toDelete, channel)
499 }
500 }
501
502 for _, channel := range toDelete {
503 handler.Waiters.Delete(channel)
504 }
505 }
506 }
507
508 if !*clean {
509 fmt.Fprintln(sesh, "sending msg ...")
510 }
511
512 err = pubsub.Pub(
513 pipeCtx,
514 clientID,
515 rw,
516 []*psub.Channel{
517 psub.NewChannel(name),
518 },
519 *block,
520 )
521
522 if !*clean {
523 fmt.Fprintln(sesh, "msg sent!")
524 }
525
526 if err != nil && !*clean {
527 fmt.Fprintln(sesh.Stderr(), err)
528 }
529 } else if cmd == "sub" {
530 subCmd := flagSet("sub", sesh)
531 access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
532 public := subCmd.Bool("p", false, "Subscribe to a public topic")
533 keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
534 clean := subCmd.Bool("c", false, "Don't send status messages")
535
536 if !flagCheck(subCmd, topic, cmdArgs) {
537 return err
538 }
539
540 if subCmd.NArg() == 1 && topic == "" {
541 topic = subCmd.Arg(0)
542 }
543
544 logger.Info(
545 "flags parsed",
546 "cmd", cmd,
547 "public", *public,
548 "keepAlive", *keepAlive,
549 "topic", topic,
550 "clean", *clean,
551 "access", *access,
552 )
553
554 var accessList []string
555
556 if *access != "" {
557 accessList = parseArgList(*access)
558 }
559
560 var withoutUser string
561 var name string
562
563 if isAdmin && strings.HasPrefix(topic, "/") {
564 name = strings.TrimPrefix(topic, "/")
565 } else {
566 name = toTopic(userName, topic)
567 if *public {
568 name = toPublicTopic(topic)
569 withoutUser = name
570 } else {
571 withoutUser = topic
572 }
573 }
574
575 var accessListCreator bool
576
577 _, loaded := handler.Access.LoadOrStore(name, accessList)
578 if !loaded {
579 defer func() {
580 handler.Access.Delete(name)
581 }()
582
583 accessListCreator = true
584 }
585
586 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
587 if checkAccess(accessList, userName, sesh) || accessListCreator {
588 name = withoutUser
589 } else if !*public {
590 name = toTopic(userName, withoutUser)
591 } else {
592 fmt.Fprintln(sesh.Stderr(), "access denied")
593 return err
594 }
595 }
596
597 err = pubsub.Sub(
598 pipeCtx,
599 clientID,
600 sesh,
601 []*psub.Channel{
602 psub.NewChannel(name),
603 },
604 *keepAlive,
605 )
606
607 if err != nil && !*clean {
608 fmt.Fprintln(sesh.Stderr(), err)
609 }
610 } else if cmd == "pipe" {
611 pipeCmd := flagSet("pipe", sesh)
612 access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
613 public := pipeCmd.Bool("p", false, "Pipe to a public topic")
614 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
615 clean := pipeCmd.Bool("c", false, "Don't send status messages")
616
617 if !flagCheck(pipeCmd, topic, cmdArgs) {
618 return err
619 }
620
621 if pipeCmd.NArg() == 1 && topic == "" {
622 topic = pipeCmd.Arg(0)
623 }
624
625 logger.Info(
626 "flags parsed",
627 "cmd", cmd,
628 "public", *public,
629 "replay", *replay,
630 "topic", topic,
631 "access", *access,
632 "clean", *clean,
633 )
634
635 var accessList []string
636
637 if *access != "" {
638 accessList = parseArgList(*access)
639 }
640
641 isCreator := topic == ""
642 if isCreator {
643 topic = uuid.NewString()
644 }
645
646 var withoutUser string
647 var name string
648 flagMsg := ""
649
650 if isAdmin && strings.HasPrefix(topic, "/") {
651 name = strings.TrimPrefix(topic, "/")
652 } else {
653 name = toTopic(userName, topic)
654 if *public {
655 name = toPublicTopic(topic)
656 flagMsg = "-p "
657 withoutUser = name
658 } else {
659 withoutUser = topic
660 }
661 }
662
663 var accessListCreator bool
664
665 _, loaded := handler.Access.LoadOrStore(name, accessList)
666 if !loaded {
667 defer func() {
668 handler.Access.Delete(name)
669 }()
670
671 accessListCreator = true
672 }
673
674 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
675 if checkAccess(accessList, userName, sesh) || accessListCreator {
676 name = withoutUser
677 } else if !*public {
678 name = toTopic(userName, withoutUser)
679 } else {
680 topic = uuid.NewString()
681 name = toPublicTopic(topic)
682 }
683 }
684
685 if isCreator && !*clean {
686 fmt.Fprintf(
687 sesh,
688 "subscribe to this topic:\n ssh %s sub %s%s\n",
689 toSshCmd(handler.Cfg),
690 flagMsg,
691 topic,
692 )
693 }
694
695 readErr, writeErr := pubsub.Pipe(
696 pipeCtx,
697 clientID,
698 sesh,
699 []*psub.Channel{
700 psub.NewChannel(name),
701 },
702 *replay,
703 )
704
705 if readErr != nil && !*clean {
706 fmt.Fprintln(sesh.Stderr(), "error reading from pipe", readErr)
707 }
708
709 if writeErr != nil && !*clean {
710 fmt.Fprintln(sesh.Stderr(), "error writing to pipe", writeErr)
711 }
712 }
713
714 return next(sesh)
715 }
716 }
717}