Eric Bower
·
2026-01-25
cli.go
1package pico
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10 "time"
11
12 "github.com/picosh/pico/pkg/db"
13 "github.com/picosh/pico/pkg/pssh"
14 "github.com/picosh/pico/pkg/shared"
15
16 pipeLogger "github.com/picosh/utils/pipe/log"
17)
18
19func getUser(s *pssh.SSHServerConnSession, dbpool db.DB) (*db.User, error) {
20 if s.PublicKey() == nil {
21 return nil, fmt.Errorf("key not found")
22 }
23
24 key := shared.KeyForKeyText(s.PublicKey())
25
26 user, err := dbpool.FindUserByKey(s.User(), key)
27 if err != nil {
28 return nil, err
29 }
30
31 if user.Name == "" {
32 return nil, fmt.Errorf("must have username set")
33 }
34
35 return user, nil
36}
37
38type Cmd struct {
39 User *db.User
40 SshSession *pssh.SSHServerConnSession
41 Session shared.CmdSession
42 Log *slog.Logger
43 Dbpool db.DB
44 Write bool
45}
46
47func (c *Cmd) output(out string) {
48 _, _ = c.Session.Write([]byte(out + "\r\n"))
49}
50
51func (c *Cmd) help() {
52 helpStr := "Commands: [help, user, logs, access_logs, chat, not-found]\n"
53 helpStr += "help - this message\n"
54 helpStr += "user - display user information (returns name, id, account created, pico+ expiration)\n"
55 helpStr += "logs - stream user logs\n"
56 helpStr += "access_logs - fetch access logs from the last 30 days\n"
57 helpStr += "chat - IRC chat (must enable pty with `-t` to the SSH command)\n"
58 helpStr += "not-found - return all status 404 requests for a host (hostname.com [year|month])\n"
59 c.output(helpStr)
60}
61
62func (c *Cmd) user() {
63 plus := ""
64 ff, _ := c.Dbpool.FindFeature(c.User.ID, "plus")
65 if ff != nil {
66 plus = ff.ExpiresAt.Format(time.RFC3339)
67 }
68 helpStr := fmt.Sprintf(
69 "%s\n%s\n%s\n%s",
70 c.User.Name,
71 c.User.ID,
72 c.User.CreatedAt.Format(time.RFC3339),
73 plus,
74 )
75 c.output(helpStr)
76}
77
78func (c *Cmd) notFound(host, interval string) error {
79 origin := shared.StartOfYear()
80 if interval == "month" {
81 origin = shared.StartOfMonth()
82 }
83 c.output(fmt.Sprintf("starting from: %s\n", origin.Format(time.RFC3339)))
84 urls, err := c.Dbpool.VisitUrlNotFound(&db.SummaryOpts{
85 Host: host,
86 UserID: c.User.ID,
87 Limit: 100,
88 Origin: origin,
89 })
90 if err != nil {
91 return err
92 }
93 for _, url := range urls {
94 c.output(fmt.Sprintf("%d %s", url.Count, url.Url))
95 }
96 return nil
97}
98
99func (c *Cmd) access_logs(ctx context.Context) error {
100 fromDate := time.Now().AddDate(0, 0, -30)
101 logs, err := c.Dbpool.FindAccessLogs(c.User.ID, &fromDate)
102 if err != nil {
103 return err
104 }
105
106 for _, log := range logs {
107 jsonData, err := json.Marshal(log)
108 if err != nil {
109 c.Log.Error("json marshall", "err", err)
110 continue
111 }
112 c.output(string(jsonData))
113 }
114 return nil
115}
116
117func (c *Cmd) logs(ctx context.Context) error {
118 conn := shared.NewPicoPipeClient()
119 stdoutPipe, err := pipeLogger.ReadLogs(ctx, c.Log, conn)
120
121 if err != nil {
122 return err
123 }
124
125 logChan := make(chan string)
126 defer close(logChan)
127
128 go func() {
129 for {
130 select {
131 case <-ctx.Done():
132 return
133 case log, ok := <-logChan:
134 if log == "" {
135 continue
136 }
137 if !ok {
138 return
139 }
140 _, _ = fmt.Fprintln(c.SshSession, log)
141 }
142 }
143 }()
144
145 scanner := bufio.NewScanner(stdoutPipe)
146 scanner.Buffer(make([]byte, 32*1024), 32*1024)
147 for scanner.Scan() {
148 line := scanner.Text()
149 parsedData := map[string]any{}
150
151 err := json.Unmarshal([]byte(line), &parsedData)
152 if err != nil {
153 c.Log.Error("json unmarshal", "err", err, "line", line, "hidden", true)
154 continue
155 }
156
157 user := shared.AnyToStr(parsedData, "user")
158 userId := shared.AnyToStr(parsedData, "userId")
159
160 hidden := shared.AnyToBool(parsedData, "hidden")
161
162 if !hidden && (user == c.User.Name || userId == c.User.ID) {
163 select {
164 case logChan <- line:
165 case <-ctx.Done():
166 return nil
167 default:
168 c.Log.Error("logChan is full, dropping log", "log", line)
169 continue
170 }
171 }
172 }
173 return scanner.Err()
174}
175
176type CliHandler struct {
177 DBPool db.DB
178 Logger *slog.Logger
179}
180
181func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
182 dbpool := handler.DBPool
183 log := handler.Logger
184
185 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
186 return func(sesh *pssh.SSHServerConnSession) error {
187 args := sesh.Command()
188 if len(args) == 0 {
189 return next(sesh)
190 }
191
192 user, err := getUser(sesh, dbpool)
193 if err != nil {
194 _, _ = fmt.Fprintf(sesh.Stderr(), "detected ssh command: %s\n", args)
195 s := fmt.Errorf("error: you need to create an account before using the remote cli: %w", err)
196 sesh.Fatal(s)
197 return s
198 }
199
200 if len(args) > 0 && args[0] == "chat" {
201 _, _, hasPty := sesh.Pty()
202 if !hasPty {
203 err := fmt.Errorf(
204 "in order to render chat you need to enable PTY with the `ssh -t` flag",
205 )
206
207 sesh.Fatal(err)
208 return err
209 }
210
211 ff, err := dbpool.FindFeature(user.ID, "plus")
212 if err != nil {
213 handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
214 ff, err = dbpool.FindFeature(user.ID, "bouncer")
215 if err != nil {
216 handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
217 sesh.Fatal(err)
218 return err
219 }
220 }
221
222 if ff == nil {
223 err = fmt.Errorf("unable to find plus or bouncer feature flag")
224 sesh.Fatal(err)
225 return err
226 }
227
228 pass, err := dbpool.UpsertToken(user.ID, "pico-chat")
229 if err != nil {
230 sesh.Fatal(err)
231 return err
232 }
233 app, err := shared.NewSenpaiApp(sesh, user.Name, pass)
234 if err != nil {
235 sesh.Fatal(err)
236 return err
237 }
238 app.Run()
239 app.Close()
240 return err
241 }
242
243 opts := Cmd{
244 Session: sesh,
245 SshSession: sesh,
246 User: user,
247 Log: log,
248 Dbpool: dbpool,
249 Write: false,
250 }
251
252 cmd := strings.TrimSpace(args[0])
253 if len(args) == 1 {
254 switch cmd {
255 case "help":
256 opts.help()
257 return nil
258 case "user":
259 opts.user()
260 return nil
261 case "logs":
262 err = opts.logs(sesh.Context())
263 if err != nil {
264 sesh.Fatal(err)
265 }
266 return nil
267 case "access_logs":
268 err = opts.access_logs(sesh.Context())
269 if err != nil {
270 sesh.Fatal(err)
271 }
272 return nil
273 default:
274 return next(sesh)
275 }
276 }
277
278 if cmd == "not-found" {
279 if len(args) < 3 {
280 sesh.Fatal(fmt.Errorf("must provide host name and interval (`month` or `year`)"))
281 return nil
282 }
283 err = opts.notFound(args[1], args[2])
284 if err != nil {
285 sesh.Fatal(err)
286 }
287 return nil
288 }
289
290 return next(sesh)
291 }
292 }
293}