repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2026-03-05

cli_middleware.go

  1package pgs
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"slices"
  7	"strings"
  8	"time"
  9
 10	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 11	"github.com/picosh/pico/pkg/db"
 12	"github.com/picosh/pico/pkg/pssh"
 13	sendutils "github.com/picosh/pico/pkg/send/utils"
 14)
 15
 16func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
 17	userID, ok := s.Conn.Permissions.Extensions["user_id"]
 18	if !ok {
 19		return nil, fmt.Errorf("`user_id` extension not found")
 20	}
 21	return dbpool.FindUser(userID)
 22}
 23
 24type arrayFlags []string
 25
 26func (i *arrayFlags) String() string {
 27	return "array flags"
 28}
 29
 30func (i *arrayFlags) Set(value string) error {
 31	*i = append(*i, value)
 32	return nil
 33}
 34
 35func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
 36	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 37	cmd.SetOutput(sesh)
 38	write := cmd.Bool("write", false, "apply changes")
 39	return cmd, write
 40}
 41
 42func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 43	_ = cmd.Parse(cmdArgs)
 44
 45	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 46		cmd.Usage()
 47		return false
 48	}
 49	return true
 50}
 51
 52func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
 53	dbpool := handler.Cfg.DB
 54	log := handler.Cfg.Logger
 55	cfg := handler.Cfg
 56	store := handler.Cfg.Storage
 57
 58	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
 59		return func(sesh *pssh.SSHServerConnSession) error {
 60			args := sesh.Command()
 61
 62			// default width and height when no pty
 63			width := 100
 64			height := 24
 65			pty, _, ok := sesh.Pty()
 66			if ok {
 67				width = pty.Window.Width
 68				height = pty.Window.Height
 69			}
 70
 71			opts := Cmd{
 72				Session: sesh,
 73				Store:   store,
 74				Log:     log,
 75				Dbpool:  dbpool,
 76				Write:   false,
 77				Width:   width,
 78				Height:  height,
 79				Cfg:     handler.Cfg,
 80			}
 81
 82			user, err := getUser(sesh, dbpool)
 83			if err != nil {
 84				sendutils.ErrorHandler(sesh, err)
 85				return err
 86			}
 87
 88			opts.User = user
 89
 90			if len(args) == 0 {
 91				opts.help()
 92				return nil
 93			}
 94
 95			cmd := strings.TrimSpace(args[0])
 96			if len(args) == 1 {
 97				switch cmd {
 98				case "help":
 99					opts.help()
100					return nil
101				case "stats":
102					err := opts.stats(cfg.MaxSize)
103					opts.bail(err)
104					return err
105				case "ls":
106					err := opts.ls()
107					opts.bail(err)
108					return err
109				case "cache-all":
110					opts.Write = true
111					err := opts.cacheAll()
112					opts.notice()
113					opts.bail(err)
114					return err
115				default:
116					return next(sesh)
117				}
118			}
119
120			projectName := strings.TrimSpace(args[1])
121			cmdArgs := args[2:]
122			log.Info(
123				"pgs middleware detected command",
124				"args", args,
125				"cmd", cmd,
126				"projectName", projectName,
127				"cmdArgs", cmdArgs,
128			)
129
130			switch cmd {
131			case "fzf":
132				err := opts.fzf(projectName)
133				opts.bail(err)
134				return err
135			case "link":
136				linkCmd, write := flagSet("link", sesh)
137				linkTo := linkCmd.String("to", "", "symbolic link to this project")
138				if !flagCheck(linkCmd, projectName, cmdArgs) {
139					return nil
140				}
141				opts.Write = *write
142
143				if *linkTo == "" {
144					err := fmt.Errorf(
145						"must provide `--to` flag",
146					)
147					opts.bail(err)
148					return err
149				}
150
151				err := opts.link(projectName, *linkTo)
152				opts.notice()
153				if err != nil {
154					opts.bail(err)
155				}
156				return err
157			case "unlink":
158				unlinkCmd, write := flagSet("unlink", sesh)
159				if !flagCheck(unlinkCmd, projectName, cmdArgs) {
160					return nil
161				}
162				opts.Write = *write
163
164				err := opts.unlink(projectName)
165				opts.notice()
166				opts.bail(err)
167				return err
168			case "depends":
169				err := opts.depends(projectName)
170				opts.bail(err)
171				return err
172			case "retain":
173				retainCmd, write := flagSet("retain", sesh)
174				retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
175				if !flagCheck(retainCmd, projectName, cmdArgs) {
176					return nil
177				}
178				opts.Write = *write
179
180				err := opts.prune(projectName, *retainNum)
181				opts.notice()
182				opts.bail(err)
183				return err
184			case "prune":
185				pruneCmd, write := flagSet("prune", sesh)
186				if !flagCheck(pruneCmd, projectName, cmdArgs) {
187					return nil
188				}
189				opts.Write = *write
190
191				err := opts.prune(projectName, 0)
192				opts.notice()
193				opts.bail(err)
194				return err
195			case "rm":
196				rmCmd, write := flagSet("rm", sesh)
197				if !flagCheck(rmCmd, projectName, cmdArgs) {
198					return nil
199				}
200				opts.Write = *write
201
202				err := opts.rm(projectName)
203				opts.notice()
204				opts.bail(err)
205				return err
206			case "cache":
207				cacheCmd, write := flagSet("cache", sesh)
208				if !flagCheck(cacheCmd, projectName, cmdArgs) {
209					return nil
210				}
211				opts.Write = *write
212
213				err := opts.cache(projectName)
214				opts.notice()
215				opts.bail(err)
216				return err
217			case "forms":
218				formName := projectName
219				formsCmd, write := flagSet("forms", sesh)
220				rmForm := formsCmd.Bool("rm", false, "delete form data")
221				if !flagCheck(formsCmd, formName, cmdArgs) {
222					return nil
223				}
224				opts.Write = *write
225
226				var err error
227				if formName == "ls" {
228					err = opts.formsLs()
229				} else {
230					if *rmForm {
231						err = opts.formRm(formName)
232						opts.notice()
233					} else {
234						err = opts.formData(formName)
235					}
236				}
237
238				opts.bail(err)
239				return err
240			case "acl":
241				aclCmd, write := flagSet("acl", sesh)
242				aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
243				var acls arrayFlags
244				aclCmd.Var(
245					&acls,
246					"acl",
247					"list of pico usernames or sha256 public keys, delimited by commas",
248				)
249				if !flagCheck(aclCmd, projectName, cmdArgs) {
250					return nil
251				}
252				opts.Write = *write
253
254				if !slices.Contains([]string{"public", "pubkeys", "pico", "http-pass"}, *aclType) {
255					err := fmt.Errorf(
256						"acl type must be one of the following: [public, pubkeys, pico], found %s",
257						*aclType,
258					)
259					opts.bail(err)
260					return err
261				}
262
263				hasPicoPlus := false
264				ff, _ := dbpool.FindFeature(user.ID, "plus")
265				if ff != nil {
266					if ff.ExpiresAt.After(time.Now()) {
267						hasPicoPlus = true
268					}
269				}
270
271				if !hasPicoPlus {
272					err = fmt.Errorf("setting acl on a project requires pico+")
273					opts.bail(err)
274					return err
275				}
276
277				if pgsdb.IsProjectPrivate(projectName) {
278					err = fmt.Errorf("projects prefixed with `private-` can *never* have their access changed; however you can symlink to it")
279					opts.bail(err)
280					return err
281				}
282
283				err := opts.acl(projectName, *aclType, acls)
284				opts.notice()
285				opts.bail(err)
286				return err
287			default:
288				return next(sesh)
289			}
290		}
291	}
292}