Antonio Mika
·
2025-04-16
cli.go
1package pico
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/picosh/pico/pkg/db"
12 "github.com/picosh/pico/pkg/pssh"
13 "github.com/picosh/pico/pkg/shared"
14 "github.com/picosh/utils"
15
16 pipeLogger "github.com/picosh/utils/pipe/log"
17)
18
19func getUser(s *pssh.SSHServerConnSession, dbpool db.DB) (*db.User, error) {
20 if s.PublicKey() == nil {
21 return nil, fmt.Errorf("key not found")
22 }
23
24 key := utils.KeyForKeyText(s.PublicKey())
25
26 user, err := dbpool.FindUserForKey(s.User(), key)
27 if err != nil {
28 return nil, err
29 }
30
31 if user.Name == "" {
32 return nil, fmt.Errorf("must have username set")
33 }
34
35 return user, nil
36}
37
38type Cmd struct {
39 User *db.User
40 SshSession *pssh.SSHServerConnSession
41 Session utils.CmdSession
42 Log *slog.Logger
43 Dbpool db.DB
44 Write bool
45}
46
47func (c *Cmd) output(out string) {
48 _, _ = c.Session.Write([]byte(out + "\r\n"))
49}
50
51func (c *Cmd) help() {
52 helpStr := "Commands: [help, pico+]\n"
53 c.output(helpStr)
54}
55
56func (c *Cmd) logs(ctx context.Context) error {
57 conn := shared.NewPicoPipeClient()
58 stdoutPipe, err := pipeLogger.ReadLogs(ctx, c.Log, conn)
59
60 if err != nil {
61 return err
62 }
63
64 logChan := make(chan string)
65 defer close(logChan)
66
67 go func() {
68 for {
69 select {
70 case <-ctx.Done():
71 return
72 case log, ok := <-logChan:
73 if log == "" {
74 continue
75 }
76 if !ok {
77 return
78 }
79 fmt.Fprintln(c.SshSession, log)
80 }
81 }
82 }()
83
84 scanner := bufio.NewScanner(stdoutPipe)
85 scanner.Buffer(make([]byte, 32*1024), 32*1024)
86 for scanner.Scan() {
87 line := scanner.Text()
88 parsedData := map[string]any{}
89
90 err := json.Unmarshal([]byte(line), &parsedData)
91 if err != nil {
92 c.Log.Error("json unmarshal", "err", err)
93 continue
94 }
95
96 user := utils.AnyToStr(parsedData, "user")
97 userId := utils.AnyToStr(parsedData, "userId")
98 if user == c.User.Name || userId == c.User.ID {
99 select {
100 case logChan <- line:
101 case <-ctx.Done():
102 return nil
103 default:
104 c.Log.Error("logChan is full, dropping log", "log", line)
105 continue
106 }
107 }
108 }
109 return scanner.Err()
110}
111
112type CliHandler struct {
113 DBPool db.DB
114 Logger *slog.Logger
115}
116
117func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
118 dbpool := handler.DBPool
119 log := handler.Logger
120
121 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
122 return func(sesh *pssh.SSHServerConnSession) error {
123 args := sesh.Command()
124 if len(args) == 0 {
125 return next(sesh)
126 }
127
128 user, err := getUser(sesh, dbpool)
129 if err != nil {
130 fmt.Fprintf(sesh.Stderr(), "detected ssh command: %s\n", args)
131 s := fmt.Errorf("error: you need to create an account before using the remote cli: %w", err)
132 sesh.Fatal(s)
133 return s
134 }
135
136 if len(args) > 0 && args[0] == "chat" {
137 _, _, hasPty := sesh.Pty()
138 if !hasPty {
139 err := fmt.Errorf(
140 "in order to render chat you need to enable PTY with the `ssh -t` flag",
141 )
142
143 sesh.Fatal(err)
144 return err
145 }
146
147 ff, err := dbpool.FindFeature(user.ID, "plus")
148 if err != nil {
149 handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
150 ff, err = dbpool.FindFeature(user.ID, "bouncer")
151 if err != nil {
152 handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
153 sesh.Fatal(err)
154 return err
155 }
156 }
157
158 if ff == nil {
159 err = fmt.Errorf("unable to find plus or bouncer feature flag")
160 sesh.Fatal(err)
161 return err
162 }
163
164 pass, err := dbpool.UpsertToken(user.ID, "pico-chat")
165 if err != nil {
166 sesh.Fatal(err)
167 return err
168 }
169 app, err := shared.NewSenpaiApp(sesh, user.Name, pass)
170 if err != nil {
171 sesh.Fatal(err)
172 return err
173 }
174 app.Run()
175 app.Close()
176 return err
177 }
178
179 opts := Cmd{
180 Session: sesh,
181 SshSession: sesh,
182 User: user,
183 Log: log,
184 Dbpool: dbpool,
185 Write: false,
186 }
187
188 cmd := strings.TrimSpace(args[0])
189 if len(args) == 1 {
190 if cmd == "help" {
191 opts.help()
192 return nil
193 } else if cmd == "logs" {
194 err = opts.logs(sesh.Context())
195 if err != nil {
196 sesh.Fatal(err)
197 }
198 return nil
199 } else {
200 return next(sesh)
201 }
202 }
203
204 return next(sesh)
205 }
206 }
207}