repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pico
Eric Bower  ·  2026-01-25

cli.go

  1package pico
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10	"time"
 11
 12	"github.com/picosh/pico/pkg/db"
 13	"github.com/picosh/pico/pkg/pssh"
 14	"github.com/picosh/pico/pkg/shared"
 15
 16	pipeLogger "github.com/picosh/utils/pipe/log"
 17)
 18
 19func getUser(s *pssh.SSHServerConnSession, dbpool db.DB) (*db.User, error) {
 20	if s.PublicKey() == nil {
 21		return nil, fmt.Errorf("key not found")
 22	}
 23
 24	key := shared.KeyForKeyText(s.PublicKey())
 25
 26	user, err := dbpool.FindUserByKey(s.User(), key)
 27	if err != nil {
 28		return nil, err
 29	}
 30
 31	if user.Name == "" {
 32		return nil, fmt.Errorf("must have username set")
 33	}
 34
 35	return user, nil
 36}
 37
 38type Cmd struct {
 39	User       *db.User
 40	SshSession *pssh.SSHServerConnSession
 41	Session    shared.CmdSession
 42	Log        *slog.Logger
 43	Dbpool     db.DB
 44	Write      bool
 45}
 46
 47func (c *Cmd) output(out string) {
 48	_, _ = c.Session.Write([]byte(out + "\r\n"))
 49}
 50
 51func (c *Cmd) help() {
 52	helpStr := "Commands: [help, user, logs, access_logs, chat, not-found]\n"
 53	helpStr += "help - this message\n"
 54	helpStr += "user - display user information (returns name, id, account created, pico+ expiration)\n"
 55	helpStr += "logs - stream user logs\n"
 56	helpStr += "access_logs - fetch access logs from the last 30 days\n"
 57	helpStr += "chat - IRC chat (must enable pty with `-t` to the SSH command)\n"
 58	helpStr += "not-found - return all status 404 requests for a host (hostname.com [year|month])\n"
 59	c.output(helpStr)
 60}
 61
 62func (c *Cmd) user() {
 63	plus := ""
 64	ff, _ := c.Dbpool.FindFeature(c.User.ID, "plus")
 65	if ff != nil {
 66		plus = ff.ExpiresAt.Format(time.RFC3339)
 67	}
 68	helpStr := fmt.Sprintf(
 69		"%s\n%s\n%s\n%s",
 70		c.User.Name,
 71		c.User.ID,
 72		c.User.CreatedAt.Format(time.RFC3339),
 73		plus,
 74	)
 75	c.output(helpStr)
 76}
 77
 78func (c *Cmd) notFound(host, interval string) error {
 79	origin := shared.StartOfYear()
 80	if interval == "month" {
 81		origin = shared.StartOfMonth()
 82	}
 83	c.output(fmt.Sprintf("starting from: %s\n", origin.Format(time.RFC3339)))
 84	urls, err := c.Dbpool.VisitUrlNotFound(&db.SummaryOpts{
 85		Host:   host,
 86		UserID: c.User.ID,
 87		Limit:  100,
 88		Origin: origin,
 89	})
 90	if err != nil {
 91		return err
 92	}
 93	for _, url := range urls {
 94		c.output(fmt.Sprintf("%d %s", url.Count, url.Url))
 95	}
 96	return nil
 97}
 98
 99func (c *Cmd) access_logs(ctx context.Context) error {
100	fromDate := time.Now().AddDate(0, 0, -30)
101	logs, err := c.Dbpool.FindAccessLogs(c.User.ID, &fromDate)
102	if err != nil {
103		return err
104	}
105
106	for _, log := range logs {
107		jsonData, err := json.Marshal(log)
108		if err != nil {
109			c.Log.Error("json marshall", "err", err)
110			continue
111		}
112		c.output(string(jsonData))
113	}
114	return nil
115}
116
117func (c *Cmd) logs(ctx context.Context) error {
118	conn := shared.NewPicoPipeClient()
119	stdoutPipe, err := pipeLogger.ReadLogs(ctx, c.Log, conn)
120
121	if err != nil {
122		return err
123	}
124
125	logChan := make(chan string)
126	defer close(logChan)
127
128	go func() {
129		for {
130			select {
131			case <-ctx.Done():
132				return
133			case log, ok := <-logChan:
134				if log == "" {
135					continue
136				}
137				if !ok {
138					return
139				}
140				_, _ = fmt.Fprintln(c.SshSession, log)
141			}
142		}
143	}()
144
145	scanner := bufio.NewScanner(stdoutPipe)
146	scanner.Buffer(make([]byte, 32*1024), 32*1024)
147	for scanner.Scan() {
148		line := scanner.Text()
149		parsedData := map[string]any{}
150
151		err := json.Unmarshal([]byte(line), &parsedData)
152		if err != nil {
153			c.Log.Error("json unmarshal", "err", err, "line", line, "hidden", true)
154			continue
155		}
156
157		user := shared.AnyToStr(parsedData, "user")
158		userId := shared.AnyToStr(parsedData, "userId")
159
160		hidden := shared.AnyToBool(parsedData, "hidden")
161
162		if !hidden && (user == c.User.Name || userId == c.User.ID) {
163			select {
164			case logChan <- line:
165			case <-ctx.Done():
166				return nil
167			default:
168				c.Log.Error("logChan is full, dropping log", "log", line)
169				continue
170			}
171		}
172	}
173	return scanner.Err()
174}
175
176type CliHandler struct {
177	DBPool db.DB
178	Logger *slog.Logger
179}
180
181func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
182	dbpool := handler.DBPool
183	log := handler.Logger
184
185	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
186		return func(sesh *pssh.SSHServerConnSession) error {
187			args := sesh.Command()
188			if len(args) == 0 {
189				return next(sesh)
190			}
191
192			user, err := getUser(sesh, dbpool)
193			if err != nil {
194				_, _ = fmt.Fprintf(sesh.Stderr(), "detected ssh command: %s\n", args)
195				s := fmt.Errorf("error: you need to create an account before using the remote cli: %w", err)
196				sesh.Fatal(s)
197				return s
198			}
199
200			if len(args) > 0 && args[0] == "chat" {
201				_, _, hasPty := sesh.Pty()
202				if !hasPty {
203					err := fmt.Errorf(
204						"in order to render chat you need to enable PTY with the `ssh -t` flag",
205					)
206
207					sesh.Fatal(err)
208					return err
209				}
210
211				ff, err := dbpool.FindFeature(user.ID, "plus")
212				if err != nil {
213					handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
214					ff, err = dbpool.FindFeature(user.ID, "bouncer")
215					if err != nil {
216						handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
217						sesh.Fatal(err)
218						return err
219					}
220				}
221
222				if ff == nil {
223					err = fmt.Errorf("unable to find plus or bouncer feature flag")
224					sesh.Fatal(err)
225					return err
226				}
227
228				pass, err := dbpool.UpsertToken(user.ID, "pico-chat")
229				if err != nil {
230					sesh.Fatal(err)
231					return err
232				}
233				app, err := shared.NewSenpaiApp(sesh, user.Name, pass)
234				if err != nil {
235					sesh.Fatal(err)
236					return err
237				}
238				app.Run()
239				app.Close()
240				return err
241			}
242
243			opts := Cmd{
244				Session:    sesh,
245				SshSession: sesh,
246				User:       user,
247				Log:        log,
248				Dbpool:     dbpool,
249				Write:      false,
250			}
251
252			cmd := strings.TrimSpace(args[0])
253			if len(args) == 1 {
254				switch cmd {
255				case "help":
256					opts.help()
257					return nil
258				case "user":
259					opts.user()
260					return nil
261				case "logs":
262					err = opts.logs(sesh.Context())
263					if err != nil {
264						sesh.Fatal(err)
265					}
266					return nil
267				case "access_logs":
268					err = opts.access_logs(sesh.Context())
269					if err != nil {
270						sesh.Fatal(err)
271					}
272					return nil
273				default:
274					return next(sesh)
275				}
276			}
277
278			if cmd == "not-found" {
279				if len(args) < 3 {
280					sesh.Fatal(fmt.Errorf("must provide host name and interval (`month` or `year`)"))
281					return nil
282				}
283				err = opts.notFound(args[1], args[2])
284				if err != nil {
285					sesh.Fatal(err)
286				}
287				return nil
288			}
289
290			return next(sesh)
291		}
292	}
293}