Eric Bower
·
2025-06-28
cli.go
1package pico
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10 "time"
11
12 "github.com/picosh/pico/pkg/db"
13 "github.com/picosh/pico/pkg/pssh"
14 "github.com/picosh/pico/pkg/shared"
15 "github.com/picosh/utils"
16
17 pipeLogger "github.com/picosh/utils/pipe/log"
18)
19
20func getUser(s *pssh.SSHServerConnSession, dbpool db.DB) (*db.User, error) {
21 if s.PublicKey() == nil {
22 return nil, fmt.Errorf("key not found")
23 }
24
25 key := utils.KeyForKeyText(s.PublicKey())
26
27 user, err := dbpool.FindUserForKey(s.User(), key)
28 if err != nil {
29 return nil, err
30 }
31
32 if user.Name == "" {
33 return nil, fmt.Errorf("must have username set")
34 }
35
36 return user, nil
37}
38
39type Cmd struct {
40 User *db.User
41 SshSession *pssh.SSHServerConnSession
42 Session utils.CmdSession
43 Log *slog.Logger
44 Dbpool db.DB
45 Write bool
46}
47
48func (c *Cmd) output(out string) {
49 _, _ = c.Session.Write([]byte(out + "\r\n"))
50}
51
52func (c *Cmd) help() {
53 helpStr := "Commands: [help, user, logs, chat, not-found]\n"
54 helpStr += "help - this message\n"
55 helpStr += "user - display user information (returns name, id, account created, pico+ expiration)\n"
56 helpStr += "logs - stream user logs\n"
57 helpStr += "chat - IRC chat (must enable pty with `-t` to the SSH command)\n"
58 helpStr += "not-found - return all status 404 requests for a host (hostname.com [year|month])\n"
59 c.output(helpStr)
60}
61
62func (c *Cmd) user() {
63 plus := ""
64 ff, _ := c.Dbpool.FindFeature(c.User.ID, "plus")
65 if ff != nil {
66 plus = ff.ExpiresAt.Format(time.RFC3339)
67 }
68 helpStr := fmt.Sprintf(
69 "%s\n%s\n%s\n%s",
70 c.User.Name,
71 c.User.ID,
72 c.User.CreatedAt.Format(time.RFC3339),
73 plus,
74 )
75 c.output(helpStr)
76}
77
78func (c *Cmd) notFound(host, interval string) error {
79 origin := utils.StartOfYear()
80 if interval == "month" {
81 origin = utils.StartOfMonth()
82 }
83 c.output(fmt.Sprintf("starting from: %s\n", origin.Format(time.RFC3339)))
84 urls, err := c.Dbpool.VisitUrlNotFound(&db.SummaryOpts{
85 Host: host,
86 UserID: c.User.ID,
87 Limit: 100,
88 Origin: origin,
89 })
90 if err != nil {
91 return err
92 }
93 for _, url := range urls {
94 c.output(fmt.Sprintf("%d %s", url.Count, url.Url))
95 }
96 return nil
97}
98
99func (c *Cmd) logs(ctx context.Context) error {
100 conn := shared.NewPicoPipeClient()
101 stdoutPipe, err := pipeLogger.ReadLogs(ctx, c.Log, conn)
102
103 if err != nil {
104 return err
105 }
106
107 logChan := make(chan string)
108 defer close(logChan)
109
110 go func() {
111 for {
112 select {
113 case <-ctx.Done():
114 return
115 case log, ok := <-logChan:
116 if log == "" {
117 continue
118 }
119 if !ok {
120 return
121 }
122 _, _ = fmt.Fprintln(c.SshSession, log)
123 }
124 }
125 }()
126
127 scanner := bufio.NewScanner(stdoutPipe)
128 scanner.Buffer(make([]byte, 32*1024), 32*1024)
129 for scanner.Scan() {
130 line := scanner.Text()
131 parsedData := map[string]any{}
132
133 err := json.Unmarshal([]byte(line), &parsedData)
134 if err != nil {
135 c.Log.Error("json unmarshal", "err", err, "line", line, "hidden", true)
136 continue
137 }
138
139 user := utils.AnyToStr(parsedData, "user")
140 userId := utils.AnyToStr(parsedData, "userId")
141
142 hidden := utils.AnyToBool(parsedData, "hidden")
143
144 if !hidden && (user == c.User.Name || userId == c.User.ID) {
145 select {
146 case logChan <- line:
147 case <-ctx.Done():
148 return nil
149 default:
150 c.Log.Error("logChan is full, dropping log", "log", line)
151 continue
152 }
153 }
154 }
155 return scanner.Err()
156}
157
158type CliHandler struct {
159 DBPool db.DB
160 Logger *slog.Logger
161}
162
163func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
164 dbpool := handler.DBPool
165 log := handler.Logger
166
167 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
168 return func(sesh *pssh.SSHServerConnSession) error {
169 args := sesh.Command()
170 if len(args) == 0 {
171 return next(sesh)
172 }
173
174 user, err := getUser(sesh, dbpool)
175 if err != nil {
176 _, _ = fmt.Fprintf(sesh.Stderr(), "detected ssh command: %s\n", args)
177 s := fmt.Errorf("error: you need to create an account before using the remote cli: %w", err)
178 sesh.Fatal(s)
179 return s
180 }
181
182 if len(args) > 0 && args[0] == "chat" {
183 _, _, hasPty := sesh.Pty()
184 if !hasPty {
185 err := fmt.Errorf(
186 "in order to render chat you need to enable PTY with the `ssh -t` flag",
187 )
188
189 sesh.Fatal(err)
190 return err
191 }
192
193 ff, err := dbpool.FindFeature(user.ID, "plus")
194 if err != nil {
195 handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
196 ff, err = dbpool.FindFeature(user.ID, "bouncer")
197 if err != nil {
198 handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
199 sesh.Fatal(err)
200 return err
201 }
202 }
203
204 if ff == nil {
205 err = fmt.Errorf("unable to find plus or bouncer feature flag")
206 sesh.Fatal(err)
207 return err
208 }
209
210 pass, err := dbpool.UpsertToken(user.ID, "pico-chat")
211 if err != nil {
212 sesh.Fatal(err)
213 return err
214 }
215 app, err := shared.NewSenpaiApp(sesh, user.Name, pass)
216 if err != nil {
217 sesh.Fatal(err)
218 return err
219 }
220 app.Run()
221 app.Close()
222 return err
223 }
224
225 opts := Cmd{
226 Session: sesh,
227 SshSession: sesh,
228 User: user,
229 Log: log,
230 Dbpool: dbpool,
231 Write: false,
232 }
233
234 cmd := strings.TrimSpace(args[0])
235 if len(args) == 1 {
236 switch cmd {
237 case "help":
238 opts.help()
239 return nil
240 case "user":
241 opts.user()
242 return nil
243 case "logs":
244 err = opts.logs(sesh.Context())
245 if err != nil {
246 sesh.Fatal(err)
247 }
248 return nil
249 default:
250 return next(sesh)
251 }
252 }
253
254 if cmd == "not-found" {
255 if len(args) < 3 {
256 sesh.Fatal(fmt.Errorf("must provide host name and interval (`month` or `year`)"))
257 return nil
258 }
259 err = opts.notFound(args[1], args[2])
260 if err != nil {
261 sesh.Fatal(err)
262 }
263 return nil
264 }
265
266 return next(sesh)
267 }
268 }
269}