- commit
- 1d500f9
- parent
- 9658c7d
- author
- Eric Bower
- date
- 2025-12-25 23:08:50 -0500 EST
refactor(pipe): cli code organization Use a struct with methods to abstract cli middleware code into discrete functions
1 files changed,
+611,
-568
+611,
-568
1@@ -21,138 +21,7 @@ import (
2 gossh "golang.org/x/crypto/ssh"
3 )
4
5-func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
6- cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
7- cmd.SetOutput(sesh)
8- cmd.Usage = func() {
9- _, _ = fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
10- cmd.PrintDefaults()
11- }
12- return cmd
13-}
14-
15-func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
16- err := cmd.Parse(cmdArgs)
17-
18- if err != nil || posArg == "help" {
19- if posArg == "help" {
20- cmd.Usage()
21- }
22- return false
23- }
24- return true
25-}
26-
27-func NewTabWriter(out io.Writer) *tabwriter.Writer {
28- return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
29-}
30-
31-// scope topic to user by prefixing name.
32-func toTopic(userName, topic string) string {
33- if strings.HasPrefix(topic, userName+"/") {
34- return topic
35- }
36- return fmt.Sprintf("%s/%s", userName, topic)
37-}
38-
39-func toPublicTopic(topic string) string {
40- if strings.HasPrefix(topic, "public/") {
41- return topic
42- }
43- return fmt.Sprintf("public/%s", topic)
44-}
45-
46-func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
47- if len(clients) == 0 {
48- return ""
49- }
50-
51- outputData := fmt.Sprintf(" %s:\r\n", clientType)
52-
53- for _, client := range clients {
54- if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
55- continue
56- }
57-
58- outputData += fmt.Sprintf(" - %s\r\n", client.ID)
59- }
60-
61- return outputData
62-}
63-
64-var helpStr = func(sshCmd string) string {
65- data := fmt.Sprintf(`Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]
66-
67-The simplest authenticated pubsub system. Send messages through
68-user-defined topics. Topics are private to the authenticated
69-ssh user. The default pubsub model is multicast with bidirectional
70-blocking, meaning a publisher ("pub") will send its message to all
71-subscribers ("sub"). Further, both "pub" and "sub" will wait for
72-at least one event to be sent or received. Pipe ("pipe") allows
73-for bidirectional messages to be sent between any clients connected
74-to a pipe.
75-
76-Think of these different commands in terms of the direction the
77-data is being sent:
78-
79-- pub => writes to client
80-- sub => reads from client
81-- pipe => read and write between clients
82-`, sshCmd)
83-
84- data = strings.ReplaceAll(data, "\n", "\r\n")
85-
86- return data
87-}
88-
89-type CliHandler struct {
90- DBPool db.DB
91- Logger *slog.Logger
92- PubSub psub.PubSub
93- Cfg *shared.ConfigSite
94- Waiters *syncmap.Map[string, []string]
95- Access *syncmap.Map[string, []string]
96-}
97-
98-func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
99- return h.Logger
100-}
101-
102-func toSshCmd(cfg *shared.ConfigSite) string {
103- port := ""
104- if cfg.PortOverride != "22" {
105- port = fmt.Sprintf("-p %s ", cfg.PortOverride)
106- }
107- return fmt.Sprintf("%s%s", port, cfg.Domain)
108-}
109-
110-// parseArgList parses a comma separated list of arguments.
111-func parseArgList(arg string) []string {
112- argList := strings.Split(arg, ",")
113- for i, acc := range argList {
114- argList[i] = strings.TrimSpace(acc)
115- }
116- return argList
117-}
118-
119-// checkAccess checks if the user has access to a topic based on an access list.
120-func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
121- for _, acc := range accessList {
122- if acc == userName {
123- return true
124- }
125-
126- if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
127- return true
128- }
129- }
130-
131- return false
132-}
133-
134 func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
135- pubsub := handler.PubSub
136-
137 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
138 return func(sesh *pssh.SSHServerConnSession) error {
139 ctx := sesh.Context()
140@@ -160,22 +29,19 @@ func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
141 user := pssh.GetUser(sesh)
142
143 args := sesh.Command()
144-
145 if len(args) == 0 {
146- _, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
147+ help(handler.Cfg, sesh)
148 return next(sesh)
149 }
150
151 userName := "public"
152-
153 userNameAddition := ""
154-
155+ uuidStr := uuid.NewString()
156 isAdmin := false
157- impersonate := false
158 if user != nil {
159 isAdmin = handler.DBPool.HasFeatureByUser(user.ID, "admin")
160 if isAdmin && strings.HasPrefix(sesh.User(), "admin__") {
161- impersonate = true
162+ uuidStr = fmt.Sprintf("admin-%s", uuidStr)
163 }
164
165 userName = user.Name
166@@ -205,97 +71,25 @@ func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
167 }
168 }()
169
170+ cliCmd := &CliCmd{
171+ sesh: sesh,
172+ args: args,
173+ userName: userName,
174+ isAdmin: isAdmin,
175+ pipeCtx: pipeCtx,
176+ cancel: cancel,
177+ }
178+
179 cmd := strings.TrimSpace(args[0])
180- if cmd == "help" {
181- _, _ = fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
182+ switch cmd {
183+ case "help":
184+ help(handler.Cfg, sesh)
185 return next(sesh)
186- } else if cmd == "ls" {
187- if userName == "public" {
188- err := fmt.Errorf("access denied")
189+ case "ls":
190+ err := handler.ls(cliCmd)
191+ if err != nil {
192 sesh.Fatal(err)
193- return err
194- }
195-
196- topicFilter := fmt.Sprintf("%s/", userName)
197- if isAdmin {
198- topicFilter = ""
199- if len(args) > 1 {
200- topicFilter = args[1]
201- }
202- }
203-
204- var channels []*psub.Channel
205- waitingChannels := map[string][]string{}
206-
207- for topic, channel := range pubsub.GetChannels() {
208- if strings.HasPrefix(topic, topicFilter) {
209- channels = append(channels, channel)
210- }
211 }
212-
213- for channel, clients := range handler.Waiters.Range {
214- if strings.HasPrefix(channel, topicFilter) {
215- waitingChannels[channel] = clients
216- }
217- }
218-
219- if len(channels) == 0 && len(waitingChannels) == 0 {
220- _, _ = fmt.Fprintln(sesh, "no pubsub channels found")
221- } else {
222- var outputData string
223- if len(channels) > 0 || len(waitingChannels) > 0 {
224- outputData += "Channel Information\r\n"
225- for _, channel := range channels {
226- extraData := ""
227-
228- if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
229- extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
230- }
231-
232- outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
233- outputData += " Clients:\r\n"
234-
235- var pubs []*psub.Client
236- var subs []*psub.Client
237- var pipes []*psub.Client
238-
239- for _, client := range channel.GetClients() {
240- switch client.Direction {
241- case psub.ChannelDirectionInput:
242- pubs = append(pubs, client)
243- case psub.ChannelDirectionOutput:
244- subs = append(subs, client)
245- case psub.ChannelDirectionInputOutput:
246- pipes = append(pipes, client)
247- }
248- }
249- outputData += clientInfo(pubs, isAdmin, "Pubs")
250- outputData += clientInfo(subs, isAdmin, "Subs")
251- outputData += clientInfo(pipes, isAdmin, "Pipes")
252- }
253-
254- for waitingChannel, channelPubs := range waitingChannels {
255- extraData := ""
256-
257- if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
258- extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
259- }
260-
261- outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
262- outputData += " Clients:\r\n"
263- outputData += fmt.Sprintf(" %s:\r\n", "Waiting Pubs")
264- for _, client := range channelPubs {
265- if strings.HasPrefix(client, "admin-") && !isAdmin {
266- continue
267- }
268- outputData += fmt.Sprintf(" - %s\r\n", client)
269- }
270- }
271- }
272-
273- _, _ = sesh.Write([]byte(outputData))
274- }
275-
276 return next(sesh)
277 }
278
279@@ -305,6 +99,8 @@ func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
280 topic = strings.TrimSpace(args[1])
281 cmdArgs = args[2:]
282 }
283+ // sub commands after this line expect clipped args
284+ cliCmd.args = cmdArgs
285
286 logger.Info(
287 "pubsub middleware detected command",
288@@ -314,421 +110,668 @@ func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
289 "cmdArgs", cmdArgs,
290 )
291
292- uuidStr := uuid.NewString()
293- if impersonate {
294- uuidStr = fmt.Sprintf("admin-%s", uuidStr)
295+ clientID := fmt.Sprintf(
296+ "%s (%s%s@%s)",
297+ uuidStr,
298+ userName,
299+ userNameAddition,
300+ sesh.RemoteAddr().String(),
301+ )
302+
303+ switch cmd {
304+ case "pub":
305+ err := handler.pub(cliCmd, topic, clientID)
306+ if err != nil {
307+ sesh.Fatal(err)
308+ }
309+ case "sub":
310+ err := handler.sub(cliCmd, topic, clientID)
311+ if err != nil {
312+ sesh.Fatal(err)
313+ }
314+ case "pipe":
315+ err := handler.pipe(cliCmd, topic, clientID)
316+ if err != nil {
317+ sesh.Fatal(err)
318+ }
319 }
320
321- clientID := fmt.Sprintf("%s (%s%s@%s)", uuidStr, userName, userNameAddition, sesh.RemoteAddr().String())
322+ return next(sesh)
323+ }
324+ }
325+}
326
327- var err error
328+type CliHandler struct {
329+ DBPool db.DB
330+ Logger *slog.Logger
331+ PubSub psub.PubSub
332+ Cfg *shared.ConfigSite
333+ Waiters *syncmap.Map[string, []string]
334+ Access *syncmap.Map[string, []string]
335+}
336
337- if cmd == "pub" {
338- pubCmd := flagSet("pub", sesh)
339- access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
340- empty := pubCmd.Bool("e", false, "Send an empty message to subs")
341- public := pubCmd.Bool("p", false, "Publish message to public topic")
342- block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
343- timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
344- clean := pubCmd.Bool("c", false, "Don't send status messages")
345+func (h *CliHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
346+ return h.Logger
347+}
348
349- if !flagCheck(pubCmd, topic, cmdArgs) {
350- return err
351- }
352+type CliCmd struct {
353+ sesh *pssh.SSHServerConnSession
354+ args []string
355+ userName string
356+ isAdmin bool
357+ pipeCtx context.Context
358+ cancel context.CancelFunc
359+}
360
361- if pubCmd.NArg() == 1 && topic == "" {
362- topic = pubCmd.Arg(0)
363- }
364+func help(cfg *shared.ConfigSite, sesh *pssh.SSHServerConnSession) {
365+ data := fmt.Sprintf(`Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]
366
367- logger.Info(
368- "flags parsed",
369- "cmd", cmd,
370- "empty", *empty,
371- "public", *public,
372- "block", *block,
373- "timeout", *timeout,
374- "topic", topic,
375- "access", *access,
376- "clean", *clean,
377- )
378-
379- var accessList []string
380-
381- if *access != "" {
382- accessList = parseArgList(*access)
383- }
384+The simplest authenticated pubsub system. Send messages through
385+user-defined topics. Topics are private to the authenticated
386+ssh user. The default pubsub model is multicast with bidirectional
387+blocking, meaning a publisher ("pub") will send its message to all
388+subscribers ("sub"). Further, both "pub" and "sub" will wait for
389+at least one event to be sent or received. Pipe ("pipe") allows
390+for bidirectional messages to be sent between any clients connected
391+to a pipe.
392
393- var rw io.ReadWriter
394- if *empty {
395- rw = bytes.NewBuffer(make([]byte, 1))
396- } else {
397- rw = sesh
398- }
399+Think of these different commands in terms of the direction the
400+data is being sent:
401
402- if topic == "" {
403- topic = uuid.NewString()
404- }
405+- pub => writes to client
406+- sub => reads from client
407+- pipe => read and write between clients
408+`, toSshCmd(cfg))
409
410- var withoutUser string
411- var name string
412- msgFlag := ""
413+ data = strings.ReplaceAll(data, "\n", "\r\n")
414+ _, _ = fmt.Fprintln(sesh, data)
415+}
416
417- if isAdmin && strings.HasPrefix(topic, "/") {
418- name = strings.TrimPrefix(topic, "/")
419- } else {
420- name = toTopic(userName, topic)
421- if *public {
422- name = toPublicTopic(topic)
423- msgFlag = "-p "
424- withoutUser = name
425- } else {
426- withoutUser = topic
427- }
428- }
429+func (handler *CliHandler) ls(cmd *CliCmd) error {
430+ if cmd.userName == "public" {
431+ err := fmt.Errorf("access denied")
432+ return err
433+ }
434
435- var accessListCreator bool
436+ topicFilter := fmt.Sprintf("%s/", cmd.userName)
437+ if cmd.isAdmin {
438+ topicFilter = ""
439+ if len(cmd.args) > 1 {
440+ topicFilter = cmd.args[1]
441+ }
442+ }
443
444- _, loaded := handler.Access.LoadOrStore(name, accessList)
445- if !loaded {
446- defer func() {
447- handler.Access.Delete(name)
448- }()
449+ var channels []*psub.Channel
450+ waitingChannels := map[string][]string{}
451
452- accessListCreator = true
453- }
454+ for topic, channel := range handler.PubSub.GetChannels() {
455+ if strings.HasPrefix(topic, topicFilter) {
456+ channels = append(channels, channel)
457+ }
458+ }
459
460- if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
461- if checkAccess(accessList, userName, sesh) || accessListCreator {
462- name = withoutUser
463- } else if !*public {
464- name = toTopic(userName, withoutUser)
465- } else {
466- topic = uuid.NewString()
467- name = toPublicTopic(topic)
468+ for channel, clients := range handler.Waiters.Range {
469+ if strings.HasPrefix(channel, topicFilter) {
470+ waitingChannels[channel] = clients
471+ }
472+ }
473+
474+ if len(channels) == 0 && len(waitingChannels) == 0 {
475+ _, _ = fmt.Fprintln(cmd.sesh, "no pubsub channels found")
476+ } else {
477+ var outputData string
478+ if len(channels) > 0 || len(waitingChannels) > 0 {
479+ outputData += "Channel Information\r\n"
480+ for _, channel := range channels {
481+ extraData := ""
482+
483+ if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
484+ extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
485+ }
486+
487+ outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
488+
489+ var pubs []*psub.Client
490+ var subs []*psub.Client
491+ var pipes []*psub.Client
492+
493+ for _, client := range channel.GetClients() {
494+ switch client.Direction {
495+ case psub.ChannelDirectionInput:
496+ pubs = append(pubs, client)
497+ case psub.ChannelDirectionOutput:
498+ subs = append(subs, client)
499+ case psub.ChannelDirectionInputOutput:
500+ pipes = append(pipes, client)
501 }
502 }
503+ outputData += clientInfo(pubs, cmd.isAdmin, "Pubs")
504+ outputData += clientInfo(subs, cmd.isAdmin, "Subs")
505+ outputData += clientInfo(pipes, cmd.isAdmin, "Pipes")
506+ }
507
508- if !*clean {
509- fmtTopic := topic
510- if *access != "" {
511- fmtTopic = fmt.Sprintf("%s/%s", userName, topic)
512- }
513+ for waitingChannel, channelPubs := range waitingChannels {
514+ extraData := ""
515
516- _, _ = fmt.Fprintf(
517- sesh,
518- "subscribe to this channel:\n ssh %s sub %s%s\n",
519- toSshCmd(handler.Cfg),
520- msgFlag,
521- fmtTopic,
522- )
523+ if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
524+ extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
525 }
526
527- if *block {
528- count := 0
529- for topic, channel := range pubsub.GetChannels() {
530- if topic == name {
531- for _, client := range channel.GetClients() {
532- if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
533- count++
534- }
535- }
536- break
537- }
538+ outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
539+ outputData += fmt.Sprintf(" %s:\r\n", "Waiting Pubs")
540+ for _, client := range channelPubs {
541+ if strings.HasPrefix(client, "admin-") && !cmd.isAdmin {
542+ continue
543 }
544+ outputData += fmt.Sprintf(" - %s\r\n", client)
545+ }
546+ }
547+ }
548
549- tt := *timeout
550- if count == 0 {
551- currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
552- handler.Waiters.Store(name, append(currentWaiters, clientID))
553+ _, _ = cmd.sesh.Write([]byte(outputData))
554+ }
555
556- termMsg := "no subs found ... waiting"
557- if tt > 0 {
558- termMsg += " " + tt.String()
559- }
560+ return nil
561+}
562
563- if !*clean {
564- _, _ = fmt.Fprintln(sesh, termMsg)
565- }
566+func (handler *CliHandler) pub(cmd *CliCmd, topic string, clientID string) error {
567+ pubCmd := flagSet("pub", cmd.sesh)
568+ access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
569+ empty := pubCmd.Bool("e", false, "Send an empty message to subs")
570+ public := pubCmd.Bool("p", false, "Publish message to public topic")
571+ block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
572+ timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
573+ clean := pubCmd.Bool("c", false, "Don't send status messages")
574+
575+ if !flagCheck(pubCmd, topic, cmd.args) {
576+ return fmt.Errorf("invalid cmd args")
577+ }
578
579- ready := make(chan struct{})
580-
581- go func() {
582- for {
583- select {
584- case <-pipeCtx.Done():
585- cancel()
586- return
587- case <-time.After(1 * time.Millisecond):
588- count := 0
589- for topic, channel := range pubsub.GetChannels() {
590- if topic == name {
591- for _, client := range channel.GetClients() {
592- if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
593- count++
594- }
595- }
596- break
597- }
598- }
599+ if pubCmd.NArg() == 1 && topic == "" {
600+ topic = pubCmd.Arg(0)
601+ }
602
603- if count > 0 {
604- close(ready)
605- return
606- }
607- }
608- }
609- }()
610-
611- select {
612- case <-ready:
613- case <-pipeCtx.Done():
614- case <-time.After(tt):
615- cancel()
616-
617- if !*clean {
618- sesh.Fatal(fmt.Errorf("timeout reached, exiting"))
619- } else {
620- err = sesh.Exit(1)
621- if err != nil {
622- logger.Error("error exiting session", "err", err)
623- }
624+ handler.Logger.Info(
625+ "flags parsed",
626+ "cmd", "pub",
627+ "empty", *empty,
628+ "public", *public,
629+ "block", *block,
630+ "timeout", *timeout,
631+ "topic", topic,
632+ "access", *access,
633+ "clean", *clean,
634+ )
635+
636+ var accessList []string
637+
638+ if *access != "" {
639+ accessList = parseArgList(*access)
640+ }
641
642- _ = sesh.Close()
643- }
644- }
645+ var rw io.ReadWriter
646+ if *empty {
647+ rw = bytes.NewBuffer(make([]byte, 1))
648+ } else {
649+ rw = cmd.sesh
650+ }
651+
652+ if topic == "" {
653+ topic = uuid.NewString()
654+ }
655+
656+ var withoutUser string
657+ var name string
658+ msgFlag := ""
659+
660+ if cmd.isAdmin && strings.HasPrefix(topic, "/") {
661+ name = strings.TrimPrefix(topic, "/")
662+ } else {
663+ name = toTopic(cmd.userName, topic)
664+ if *public {
665+ name = toPublicTopic(topic)
666+ msgFlag = "-p "
667+ withoutUser = name
668+ } else {
669+ withoutUser = topic
670+ }
671+ }
672+
673+ var accessListCreator bool
674+ _, loaded := handler.Access.LoadOrStore(name, accessList)
675+ if !loaded {
676+ defer func() {
677+ handler.Access.Delete(name)
678+ }()
679+
680+ accessListCreator = true
681+ }
682
683- newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
684- newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
685- return cl == clientID
686- })
687- handler.Waiters.Store(name, newWaiters)
688+ if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !cmd.isAdmin {
689+ if checkAccess(accessList, cmd.userName, cmd.sesh) || accessListCreator {
690+ name = withoutUser
691+ } else if !*public {
692+ name = toTopic(cmd.userName, withoutUser)
693+ } else {
694+ topic = uuid.NewString()
695+ name = toPublicTopic(topic)
696+ }
697+ }
698
699- var toDelete []string
700+ if !*clean {
701+ fmtTopic := topic
702+ if *access != "" {
703+ fmtTopic = fmt.Sprintf("%s/%s", cmd.userName, topic)
704+ }
705
706- for channel, clients := range handler.Waiters.Range {
707- if len(clients) == 0 {
708- toDelete = append(toDelete, channel)
709+ _, _ = fmt.Fprintf(
710+ cmd.sesh,
711+ "subscribe to this channel:\n ssh %s sub %s%s\n",
712+ toSshCmd(handler.Cfg),
713+ msgFlag,
714+ fmtTopic,
715+ )
716+ }
717+
718+ if *block {
719+ count := 0
720+ for topic, channel := range handler.PubSub.GetChannels() {
721+ if topic == name {
722+ for _, client := range channel.GetClients() {
723+ if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
724+ count++
725+ }
726+ }
727+ break
728+ }
729+ }
730+
731+ tt := *timeout
732+ if count == 0 {
733+ currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
734+ handler.Waiters.Store(name, append(currentWaiters, clientID))
735+
736+ termMsg := "no subs found ... waiting"
737+ if tt > 0 {
738+ termMsg += " " + tt.String()
739+ }
740+
741+ if !*clean {
742+ _, _ = fmt.Fprintln(cmd.sesh, termMsg)
743+ }
744+
745+ ready := make(chan struct{})
746+
747+ go func() {
748+ for {
749+ select {
750+ case <-cmd.pipeCtx.Done():
751+ cmd.cancel()
752+ return
753+ case <-time.After(1 * time.Millisecond):
754+ count := 0
755+ for topic, channel := range handler.PubSub.GetChannels() {
756+ if topic == name {
757+ for _, client := range channel.GetClients() {
758+ if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
759+ count++
760+ }
761+ }
762+ break
763 }
764 }
765
766- for _, channel := range toDelete {
767- handler.Waiters.Delete(channel)
768+ if count > 0 {
769+ close(ready)
770+ return
771 }
772 }
773 }
774+ }()
775+
776+ select {
777+ case <-ready:
778+ case <-cmd.pipeCtx.Done():
779+ case <-time.After(tt):
780+ cmd.cancel()
781
782 if !*clean {
783- _, _ = fmt.Fprintln(sesh, "sending msg ...")
784+ return fmt.Errorf("timeout reached, exiting")
785+ } else {
786+ err := cmd.sesh.Exit(1)
787+ if err != nil {
788+ handler.Logger.Error("error exiting session", "err", err)
789+ }
790+
791+ _ = cmd.sesh.Close()
792 }
793+ }
794
795- err = pubsub.Pub(
796- pipeCtx,
797- clientID,
798- rw,
799- []*psub.Channel{
800- psub.NewChannel(name),
801- },
802- *block,
803- )
804+ newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
805+ newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
806+ return cl == clientID
807+ })
808+ handler.Waiters.Store(name, newWaiters)
809
810- if !*clean {
811- _, _ = fmt.Fprintln(sesh, "msg sent!")
812- }
813+ var toDelete []string
814
815- if err != nil && !*clean {
816- _, _ = fmt.Fprintln(sesh.Stderr(), err)
817- }
818- } else if cmd == "sub" {
819- subCmd := flagSet("sub", sesh)
820- access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
821- public := subCmd.Bool("p", false, "Subscribe to a public topic")
822- keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
823- clean := subCmd.Bool("c", false, "Don't send status messages")
824-
825- if !flagCheck(subCmd, topic, cmdArgs) {
826- return err
827+ for channel, clients := range handler.Waiters.Range {
828+ if len(clients) == 0 {
829+ toDelete = append(toDelete, channel)
830 }
831+ }
832
833- if subCmd.NArg() == 1 && topic == "" {
834- topic = subCmd.Arg(0)
835- }
836+ for _, channel := range toDelete {
837+ handler.Waiters.Delete(channel)
838+ }
839+ }
840+ }
841
842- logger.Info(
843- "flags parsed",
844- "cmd", cmd,
845- "public", *public,
846- "keepAlive", *keepAlive,
847- "topic", topic,
848- "clean", *clean,
849- "access", *access,
850- )
851+ if !*clean {
852+ _, _ = fmt.Fprintln(cmd.sesh, "sending msg ...")
853+ }
854
855- var accessList []string
856+ err := handler.PubSub.Pub(
857+ cmd.pipeCtx,
858+ clientID,
859+ rw,
860+ []*psub.Channel{
861+ psub.NewChannel(name),
862+ },
863+ *block,
864+ )
865+
866+ if !*clean {
867+ _, _ = fmt.Fprintln(cmd.sesh, "msg sent!")
868+ }
869
870- if *access != "" {
871- accessList = parseArgList(*access)
872- }
873+ if err != nil && !*clean {
874+ return err
875+ }
876
877- var withoutUser string
878- var name string
879+ return nil
880+}
881
882- if isAdmin && strings.HasPrefix(topic, "/") {
883- name = strings.TrimPrefix(topic, "/")
884- } else {
885- name = toTopic(userName, topic)
886- if *public {
887- name = toPublicTopic(topic)
888- withoutUser = name
889- } else {
890- withoutUser = topic
891- }
892- }
893+func (handler *CliHandler) sub(cmd *CliCmd, topic string, clientID string) error {
894+ subCmd := flagSet("sub", cmd.sesh)
895+ access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
896+ public := subCmd.Bool("p", false, "Subscribe to a public topic")
897+ keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
898+ clean := subCmd.Bool("c", false, "Don't send status messages")
899
900- var accessListCreator bool
901+ if !flagCheck(subCmd, topic, cmd.args) {
902+ return fmt.Errorf("invalid cmd args")
903+ }
904
905- _, loaded := handler.Access.LoadOrStore(name, accessList)
906- if !loaded {
907- defer func() {
908- handler.Access.Delete(name)
909- }()
910+ if subCmd.NArg() == 1 && topic == "" {
911+ topic = subCmd.Arg(0)
912+ }
913
914- accessListCreator = true
915- }
916+ handler.Logger.Info(
917+ "flags parsed",
918+ "cmd", cmd,
919+ "public", *public,
920+ "keepAlive", *keepAlive,
921+ "topic", topic,
922+ "clean", *clean,
923+ "access", *access,
924+ )
925
926- if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
927- if checkAccess(accessList, userName, sesh) || accessListCreator {
928- name = withoutUser
929- } else if !*public {
930- name = toTopic(userName, withoutUser)
931- } else {
932- _, _ = fmt.Fprintln(sesh.Stderr(), "access denied")
933- return err
934- }
935- }
936+ var accessList []string
937
938- err = pubsub.Sub(
939- pipeCtx,
940- clientID,
941- sesh,
942- []*psub.Channel{
943- psub.NewChannel(name),
944- },
945- *keepAlive,
946- )
947-
948- if err != nil && !*clean {
949- _, _ = fmt.Fprintln(sesh.Stderr(), err)
950- }
951- } else if cmd == "pipe" {
952- pipeCmd := flagSet("pipe", sesh)
953- access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
954- public := pipeCmd.Bool("p", false, "Pipe to a public topic")
955- replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
956- clean := pipeCmd.Bool("c", false, "Don't send status messages")
957-
958- if !flagCheck(pipeCmd, topic, cmdArgs) {
959- return err
960- }
961+ if *access != "" {
962+ accessList = parseArgList(*access)
963+ }
964
965- if pipeCmd.NArg() == 1 && topic == "" {
966- topic = pipeCmd.Arg(0)
967- }
968+ var withoutUser string
969+ var name string
970+
971+ if cmd.isAdmin && strings.HasPrefix(topic, "/") {
972+ name = strings.TrimPrefix(topic, "/")
973+ } else {
974+ name = toTopic(cmd.userName, topic)
975+ if *public {
976+ name = toPublicTopic(topic)
977+ withoutUser = name
978+ } else {
979+ withoutUser = topic
980+ }
981+ }
982
983- logger.Info(
984- "flags parsed",
985- "cmd", cmd,
986- "public", *public,
987- "replay", *replay,
988- "topic", topic,
989- "access", *access,
990- "clean", *clean,
991- )
992+ var accessListCreator bool
993
994- var accessList []string
995+ _, loaded := handler.Access.LoadOrStore(name, accessList)
996+ if !loaded {
997+ defer func() {
998+ handler.Access.Delete(name)
999+ }()
1000+ accessListCreator = true
1001+ }
1002
1003- if *access != "" {
1004- accessList = parseArgList(*access)
1005- }
1006+ if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !cmd.isAdmin {
1007+ if checkAccess(accessList, cmd.userName, cmd.sesh) || accessListCreator {
1008+ name = withoutUser
1009+ } else if !*public {
1010+ name = toTopic(cmd.userName, withoutUser)
1011+ } else {
1012+ return fmt.Errorf("access denied")
1013+ }
1014+ }
1015
1016- isCreator := topic == ""
1017- if isCreator {
1018- topic = uuid.NewString()
1019- }
1020+ err := handler.PubSub.Sub(
1021+ cmd.pipeCtx,
1022+ clientID,
1023+ cmd.sesh,
1024+ []*psub.Channel{
1025+ psub.NewChannel(name),
1026+ },
1027+ *keepAlive,
1028+ )
1029+
1030+ if err != nil && !*clean {
1031+ return err
1032+ }
1033
1034- var withoutUser string
1035- var name string
1036- flagMsg := ""
1037+ return nil
1038+}
1039
1040- if isAdmin && strings.HasPrefix(topic, "/") {
1041- name = strings.TrimPrefix(topic, "/")
1042- } else {
1043- name = toTopic(userName, topic)
1044- if *public {
1045- name = toPublicTopic(topic)
1046- flagMsg = "-p "
1047- withoutUser = name
1048- } else {
1049- withoutUser = topic
1050- }
1051- }
1052+func (handler *CliHandler) pipe(cmd *CliCmd, topic string, clientID string) error {
1053+ pipeCmd := flagSet("pipe", cmd.sesh)
1054+ access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
1055+ public := pipeCmd.Bool("p", false, "Pipe to a public topic")
1056+ replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
1057+ clean := pipeCmd.Bool("c", false, "Don't send status messages")
1058
1059- var accessListCreator bool
1060+ if !flagCheck(pipeCmd, topic, cmd.args) {
1061+ return fmt.Errorf("invalid cmd args")
1062+ }
1063
1064- _, loaded := handler.Access.LoadOrStore(name, accessList)
1065- if !loaded {
1066- defer func() {
1067- handler.Access.Delete(name)
1068- }()
1069+ if pipeCmd.NArg() == 1 && topic == "" {
1070+ topic = pipeCmd.Arg(0)
1071+ }
1072
1073- accessListCreator = true
1074- }
1075+ handler.Logger.Info(
1076+ "flags parsed",
1077+ "cmd", cmd,
1078+ "public", *public,
1079+ "replay", *replay,
1080+ "topic", topic,
1081+ "access", *access,
1082+ "clean", *clean,
1083+ )
1084
1085- if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
1086- if checkAccess(accessList, userName, sesh) || accessListCreator {
1087- name = withoutUser
1088- } else if !*public {
1089- name = toTopic(userName, withoutUser)
1090- } else {
1091- topic = uuid.NewString()
1092- name = toPublicTopic(topic)
1093- }
1094- }
1095+ var accessList []string
1096
1097- if isCreator && !*clean {
1098- fmtTopic := topic
1099- if *access != "" {
1100- fmtTopic = fmt.Sprintf("%s/%s", userName, topic)
1101- }
1102+ if *access != "" {
1103+ accessList = parseArgList(*access)
1104+ }
1105
1106- _, _ = fmt.Fprintf(
1107- sesh,
1108- "subscribe to this topic:\n ssh %s sub %s%s\n",
1109- toSshCmd(handler.Cfg),
1110- flagMsg,
1111- fmtTopic,
1112- )
1113- }
1114+ isCreator := topic == ""
1115+ if isCreator {
1116+ topic = uuid.NewString()
1117+ }
1118
1119- readErr, writeErr := pubsub.Pipe(
1120- pipeCtx,
1121- clientID,
1122- sesh,
1123- []*psub.Channel{
1124- psub.NewChannel(name),
1125- },
1126- *replay,
1127- )
1128-
1129- if readErr != nil && !*clean {
1130- _, _ = fmt.Fprintln(sesh.Stderr(), "error reading from pipe", readErr)
1131- }
1132+ var withoutUser string
1133+ var name string
1134+ flagMsg := ""
1135+
1136+ if cmd.isAdmin && strings.HasPrefix(topic, "/") {
1137+ name = strings.TrimPrefix(topic, "/")
1138+ } else {
1139+ name = toTopic(cmd.userName, topic)
1140+ if *public {
1141+ name = toPublicTopic(topic)
1142+ flagMsg = "-p "
1143+ withoutUser = name
1144+ } else {
1145+ withoutUser = topic
1146+ }
1147+ }
1148
1149- if writeErr != nil && !*clean {
1150- _, _ = fmt.Fprintln(sesh.Stderr(), "error writing to pipe", writeErr)
1151- }
1152- }
1153+ var accessListCreator bool
1154
1155- return next(sesh)
1156+ _, loaded := handler.Access.LoadOrStore(name, accessList)
1157+ if !loaded {
1158+ defer func() {
1159+ handler.Access.Delete(name)
1160+ }()
1161+ accessListCreator = true
1162+ }
1163+
1164+ if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !cmd.isAdmin {
1165+ if checkAccess(accessList, cmd.userName, cmd.sesh) || accessListCreator {
1166+ name = withoutUser
1167+ } else if !*public {
1168+ name = toTopic(cmd.userName, withoutUser)
1169+ } else {
1170+ topic = uuid.NewString()
1171+ name = toPublicTopic(topic)
1172+ }
1173+ }
1174+
1175+ if isCreator && !*clean {
1176+ fmtTopic := topic
1177+ if *access != "" {
1178+ fmtTopic = fmt.Sprintf("%s/%s", cmd.userName, topic)
1179+ }
1180+
1181+ _, _ = fmt.Fprintf(
1182+ cmd.sesh,
1183+ "subscribe to this topic:\n ssh %s sub %s%s\n",
1184+ toSshCmd(handler.Cfg),
1185+ flagMsg,
1186+ fmtTopic,
1187+ )
1188+ }
1189+
1190+ readErr, writeErr := handler.PubSub.Pipe(
1191+ cmd.pipeCtx,
1192+ clientID,
1193+ cmd.sesh,
1194+ []*psub.Channel{
1195+ psub.NewChannel(name),
1196+ },
1197+ *replay,
1198+ )
1199+
1200+ if readErr != nil && !*clean {
1201+ return readErr
1202+ }
1203+
1204+ if writeErr != nil && !*clean {
1205+ return writeErr
1206+ }
1207+
1208+ return nil
1209+}
1210+
1211+func toSshCmd(cfg *shared.ConfigSite) string {
1212+ port := ""
1213+ if cfg.PortOverride != "22" {
1214+ port = fmt.Sprintf("-p %s ", cfg.PortOverride)
1215+ }
1216+ return fmt.Sprintf("%s%s", port, cfg.Domain)
1217+}
1218+
1219+// parseArgList parses a comma separated list of arguments.
1220+func parseArgList(arg string) []string {
1221+ argList := strings.Split(arg, ",")
1222+ for i, acc := range argList {
1223+ argList[i] = strings.TrimSpace(acc)
1224+ }
1225+ return argList
1226+}
1227+
1228+// checkAccess checks if the user has access to a topic based on an access list.
1229+func checkAccess(accessList []string, userName string, sesh *pssh.SSHServerConnSession) bool {
1230+ for _, acc := range accessList {
1231+ if acc == userName {
1232+ return true
1233+ }
1234+
1235+ if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
1236+ return true
1237 }
1238 }
1239+
1240+ return false
1241+}
1242+
1243+func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) *flag.FlagSet {
1244+ cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
1245+ cmd.SetOutput(sesh)
1246+ cmd.Usage = func() {
1247+ _, _ = fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
1248+ cmd.PrintDefaults()
1249+ }
1250+ return cmd
1251+}
1252+
1253+func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
1254+ err := cmd.Parse(cmdArgs)
1255+
1256+ if err != nil || posArg == "help" {
1257+ if posArg == "help" {
1258+ cmd.Usage()
1259+ }
1260+ return false
1261+ }
1262+ return true
1263+}
1264+
1265+func NewTabWriter(out io.Writer) *tabwriter.Writer {
1266+ return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
1267+}
1268+
1269+// scope topic to user by prefixing name.
1270+func toTopic(userName, topic string) string {
1271+ if strings.HasPrefix(topic, userName+"/") {
1272+ return topic
1273+ }
1274+ return fmt.Sprintf("%s/%s", userName, topic)
1275+}
1276+
1277+func toPublicTopic(topic string) string {
1278+ if strings.HasPrefix(topic, "public/") {
1279+ return topic
1280+ }
1281+ return fmt.Sprintf("public/%s", topic)
1282+}
1283+
1284+func clientInfo(clients []*psub.Client, isAdmin bool, clientType string) string {
1285+ if len(clients) == 0 {
1286+ return ""
1287+ }
1288+
1289+ outputData := fmt.Sprintf(" %s:\r\n", clientType)
1290+
1291+ for _, client := range clients {
1292+ if strings.HasPrefix(client.ID, "admin-") && !isAdmin {
1293+ continue
1294+ }
1295+
1296+ outputData += fmt.Sprintf(" - %s\r\n", client.ID)
1297+ }
1298+
1299+ return outputData
1300 }