repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pico
Antonio Mika  ·  2025-04-16

cli.go

  1package pico
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/picosh/pico/pkg/db"
 12	"github.com/picosh/pico/pkg/pssh"
 13	"github.com/picosh/pico/pkg/shared"
 14	"github.com/picosh/utils"
 15
 16	pipeLogger "github.com/picosh/utils/pipe/log"
 17)
 18
 19func getUser(s *pssh.SSHServerConnSession, dbpool db.DB) (*db.User, error) {
 20	if s.PublicKey() == nil {
 21		return nil, fmt.Errorf("key not found")
 22	}
 23
 24	key := utils.KeyForKeyText(s.PublicKey())
 25
 26	user, err := dbpool.FindUserForKey(s.User(), key)
 27	if err != nil {
 28		return nil, err
 29	}
 30
 31	if user.Name == "" {
 32		return nil, fmt.Errorf("must have username set")
 33	}
 34
 35	return user, nil
 36}
 37
 38type Cmd struct {
 39	User       *db.User
 40	SshSession *pssh.SSHServerConnSession
 41	Session    utils.CmdSession
 42	Log        *slog.Logger
 43	Dbpool     db.DB
 44	Write      bool
 45}
 46
 47func (c *Cmd) output(out string) {
 48	_, _ = c.Session.Write([]byte(out + "\r\n"))
 49}
 50
 51func (c *Cmd) help() {
 52	helpStr := "Commands: [help, pico+]\n"
 53	c.output(helpStr)
 54}
 55
 56func (c *Cmd) logs(ctx context.Context) error {
 57	conn := shared.NewPicoPipeClient()
 58	stdoutPipe, err := pipeLogger.ReadLogs(ctx, c.Log, conn)
 59
 60	if err != nil {
 61		return err
 62	}
 63
 64	logChan := make(chan string)
 65	defer close(logChan)
 66
 67	go func() {
 68		for {
 69			select {
 70			case <-ctx.Done():
 71				return
 72			case log, ok := <-logChan:
 73				if log == "" {
 74					continue
 75				}
 76				if !ok {
 77					return
 78				}
 79				fmt.Fprintln(c.SshSession, log)
 80			}
 81		}
 82	}()
 83
 84	scanner := bufio.NewScanner(stdoutPipe)
 85	scanner.Buffer(make([]byte, 32*1024), 32*1024)
 86	for scanner.Scan() {
 87		line := scanner.Text()
 88		parsedData := map[string]any{}
 89
 90		err := json.Unmarshal([]byte(line), &parsedData)
 91		if err != nil {
 92			c.Log.Error("json unmarshal", "err", err)
 93			continue
 94		}
 95
 96		user := utils.AnyToStr(parsedData, "user")
 97		userId := utils.AnyToStr(parsedData, "userId")
 98		if user == c.User.Name || userId == c.User.ID {
 99			select {
100			case logChan <- line:
101			case <-ctx.Done():
102				return nil
103			default:
104				c.Log.Error("logChan is full, dropping log", "log", line)
105				continue
106			}
107		}
108	}
109	return scanner.Err()
110}
111
112type CliHandler struct {
113	DBPool db.DB
114	Logger *slog.Logger
115}
116
117func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
118	dbpool := handler.DBPool
119	log := handler.Logger
120
121	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
122		return func(sesh *pssh.SSHServerConnSession) error {
123			args := sesh.Command()
124			if len(args) == 0 {
125				return next(sesh)
126			}
127
128			user, err := getUser(sesh, dbpool)
129			if err != nil {
130				fmt.Fprintf(sesh.Stderr(), "detected ssh command: %s\n", args)
131				s := fmt.Errorf("error: you need to create an account before using the remote cli: %w", err)
132				sesh.Fatal(s)
133				return s
134			}
135
136			if len(args) > 0 && args[0] == "chat" {
137				_, _, hasPty := sesh.Pty()
138				if !hasPty {
139					err := fmt.Errorf(
140						"in order to render chat you need to enable PTY with the `ssh -t` flag",
141					)
142
143					sesh.Fatal(err)
144					return err
145				}
146
147				ff, err := dbpool.FindFeature(user.ID, "plus")
148				if err != nil {
149					handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
150					ff, err = dbpool.FindFeature(user.ID, "bouncer")
151					if err != nil {
152						handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
153						sesh.Fatal(err)
154						return err
155					}
156				}
157
158				if ff == nil {
159					err = fmt.Errorf("unable to find plus or bouncer feature flag")
160					sesh.Fatal(err)
161					return err
162				}
163
164				pass, err := dbpool.UpsertToken(user.ID, "pico-chat")
165				if err != nil {
166					sesh.Fatal(err)
167					return err
168				}
169				app, err := shared.NewSenpaiApp(sesh, user.Name, pass)
170				if err != nil {
171					sesh.Fatal(err)
172					return err
173				}
174				app.Run()
175				app.Close()
176				return err
177			}
178
179			opts := Cmd{
180				Session:    sesh,
181				SshSession: sesh,
182				User:       user,
183				Log:        log,
184				Dbpool:     dbpool,
185				Write:      false,
186			}
187
188			cmd := strings.TrimSpace(args[0])
189			if len(args) == 1 {
190				if cmd == "help" {
191					opts.help()
192					return nil
193				} else if cmd == "logs" {
194					err = opts.logs(sesh.Context())
195					if err != nil {
196						sesh.Fatal(err)
197					}
198					return nil
199				} else {
200					return next(sesh)
201				}
202			}
203
204			return next(sesh)
205		}
206	}
207}