repos / pico

pico services monorepo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pico
Eric Bower  ·  2025-06-28

cli.go

  1package pico
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10	"time"
 11
 12	"github.com/picosh/pico/pkg/db"
 13	"github.com/picosh/pico/pkg/pssh"
 14	"github.com/picosh/pico/pkg/shared"
 15	"github.com/picosh/utils"
 16
 17	pipeLogger "github.com/picosh/utils/pipe/log"
 18)
 19
 20func getUser(s *pssh.SSHServerConnSession, dbpool db.DB) (*db.User, error) {
 21	if s.PublicKey() == nil {
 22		return nil, fmt.Errorf("key not found")
 23	}
 24
 25	key := utils.KeyForKeyText(s.PublicKey())
 26
 27	user, err := dbpool.FindUserForKey(s.User(), key)
 28	if err != nil {
 29		return nil, err
 30	}
 31
 32	if user.Name == "" {
 33		return nil, fmt.Errorf("must have username set")
 34	}
 35
 36	return user, nil
 37}
 38
// Cmd bundles the state needed to run a single remote CLI command on behalf
// of an authenticated user.
type Cmd struct {
	// User is the account resolved for the SSH session (Middleware fills it
	// from getUser).
	User       *db.User
	// SshSession is the raw SSH session; the logs command writes lines to it
	// directly.
	SshSession *pssh.SSHServerConnSession
	// Session receives ordinary command output via output().
	Session    utils.CmdSession
	// Log receives diagnostic messages emitted while running commands.
	Log        *slog.Logger
	// Dbpool is the database handle used to look up features and visit stats.
	Dbpool     db.DB
	// Write is not read by any code visible in this file; Middleware sets it
	// to false. NOTE(review): confirm whether other files depend on it.
	Write      bool
}
 47
 48func (c *Cmd) output(out string) {
 49	_, _ = c.Session.Write([]byte(out + "\r\n"))
 50}
 51
 52func (c *Cmd) help() {
 53	helpStr := "Commands: [help, user, logs, chat, not-found]\n"
 54	helpStr += "help - this message\n"
 55	helpStr += "user - display user information (returns name, id, account created, pico+ expiration)\n"
 56	helpStr += "logs - stream user logs\n"
 57	helpStr += "chat - IRC chat (must enable pty with `-t` to the SSH command)\n"
 58	helpStr += "not-found - return all status 404 requests for a host (hostname.com [year|month])\n"
 59	c.output(helpStr)
 60}
 61
 62func (c *Cmd) user() {
 63	plus := ""
 64	ff, _ := c.Dbpool.FindFeature(c.User.ID, "plus")
 65	if ff != nil {
 66		plus = ff.ExpiresAt.Format(time.RFC3339)
 67	}
 68	helpStr := fmt.Sprintf(
 69		"%s\n%s\n%s\n%s",
 70		c.User.Name,
 71		c.User.ID,
 72		c.User.CreatedAt.Format(time.RFC3339),
 73		plus,
 74	)
 75	c.output(helpStr)
 76}
 77
 78func (c *Cmd) notFound(host, interval string) error {
 79	origin := utils.StartOfYear()
 80	if interval == "month" {
 81		origin = utils.StartOfMonth()
 82	}
 83	c.output(fmt.Sprintf("starting from: %s\n", origin.Format(time.RFC3339)))
 84	urls, err := c.Dbpool.VisitUrlNotFound(&db.SummaryOpts{
 85		Host:   host,
 86		UserID: c.User.ID,
 87		Limit:  100,
 88		Origin: origin,
 89	})
 90	if err != nil {
 91		return err
 92	}
 93	for _, url := range urls {
 94		c.output(fmt.Sprintf("%d %s", url.Count, url.Url))
 95	}
 96	return nil
 97}
 98
 99func (c *Cmd) logs(ctx context.Context) error {
100	conn := shared.NewPicoPipeClient()
101	stdoutPipe, err := pipeLogger.ReadLogs(ctx, c.Log, conn)
102
103	if err != nil {
104		return err
105	}
106
107	logChan := make(chan string)
108	defer close(logChan)
109
110	go func() {
111		for {
112			select {
113			case <-ctx.Done():
114				return
115			case log, ok := <-logChan:
116				if log == "" {
117					continue
118				}
119				if !ok {
120					return
121				}
122				_, _ = fmt.Fprintln(c.SshSession, log)
123			}
124		}
125	}()
126
127	scanner := bufio.NewScanner(stdoutPipe)
128	scanner.Buffer(make([]byte, 32*1024), 32*1024)
129	for scanner.Scan() {
130		line := scanner.Text()
131		parsedData := map[string]any{}
132
133		err := json.Unmarshal([]byte(line), &parsedData)
134		if err != nil {
135			c.Log.Error("json unmarshal", "err", err, "line", line, "hidden", true)
136			continue
137		}
138
139		user := utils.AnyToStr(parsedData, "user")
140		userId := utils.AnyToStr(parsedData, "userId")
141
142		hidden := utils.AnyToBool(parsedData, "hidden")
143
144		if !hidden && (user == c.User.Name || userId == c.User.ID) {
145			select {
146			case logChan <- line:
147			case <-ctx.Done():
148				return nil
149			default:
150				c.Log.Error("logChan is full, dropping log", "log", line)
151				continue
152			}
153		}
154	}
155	return scanner.Err()
156}
157
// CliHandler holds the dependencies Middleware needs to serve the remote CLI.
type CliHandler struct {
	// DBPool is used to resolve users, feature flags, chat tokens and visit
	// statistics.
	DBPool db.DB
	// Logger receives diagnostic output from Middleware and the CLI commands.
	Logger *slog.Logger
}
162
163func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
164	dbpool := handler.DBPool
165	log := handler.Logger
166
167	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
168		return func(sesh *pssh.SSHServerConnSession) error {
169			args := sesh.Command()
170			if len(args) == 0 {
171				return next(sesh)
172			}
173
174			user, err := getUser(sesh, dbpool)
175			if err != nil {
176				_, _ = fmt.Fprintf(sesh.Stderr(), "detected ssh command: %s\n", args)
177				s := fmt.Errorf("error: you need to create an account before using the remote cli: %w", err)
178				sesh.Fatal(s)
179				return s
180			}
181
182			if len(args) > 0 && args[0] == "chat" {
183				_, _, hasPty := sesh.Pty()
184				if !hasPty {
185					err := fmt.Errorf(
186						"in order to render chat you need to enable PTY with the `ssh -t` flag",
187					)
188
189					sesh.Fatal(err)
190					return err
191				}
192
193				ff, err := dbpool.FindFeature(user.ID, "plus")
194				if err != nil {
195					handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
196					ff, err = dbpool.FindFeature(user.ID, "bouncer")
197					if err != nil {
198						handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
199						sesh.Fatal(err)
200						return err
201					}
202				}
203
204				if ff == nil {
205					err = fmt.Errorf("unable to find plus or bouncer feature flag")
206					sesh.Fatal(err)
207					return err
208				}
209
210				pass, err := dbpool.UpsertToken(user.ID, "pico-chat")
211				if err != nil {
212					sesh.Fatal(err)
213					return err
214				}
215				app, err := shared.NewSenpaiApp(sesh, user.Name, pass)
216				if err != nil {
217					sesh.Fatal(err)
218					return err
219				}
220				app.Run()
221				app.Close()
222				return err
223			}
224
225			opts := Cmd{
226				Session:    sesh,
227				SshSession: sesh,
228				User:       user,
229				Log:        log,
230				Dbpool:     dbpool,
231				Write:      false,
232			}
233
234			cmd := strings.TrimSpace(args[0])
235			if len(args) == 1 {
236				switch cmd {
237				case "help":
238					opts.help()
239					return nil
240				case "user":
241					opts.user()
242					return nil
243				case "logs":
244					err = opts.logs(sesh.Context())
245					if err != nil {
246						sesh.Fatal(err)
247					}
248					return nil
249				default:
250					return next(sesh)
251				}
252			}
253
254			if cmd == "not-found" {
255				if len(args) < 3 {
256					sesh.Fatal(fmt.Errorf("must provide host name and interval (`month` or `year`)"))
257					return nil
258				}
259				err = opts.notFound(args[1], args[2])
260				if err != nil {
261					sesh.Fatal(err)
262				}
263				return nil
264			}
265
266			return next(sesh)
267		}
268	}
269}