repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2025-07-13

cli_middleware.go

  1package pgs
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 10	"github.com/picosh/pico/pkg/db"
 11	"github.com/picosh/pico/pkg/pssh"
 12	sendutils "github.com/picosh/pico/pkg/send/utils"
 13)
 14
 15func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
 16	userID, ok := s.Conn.Permissions.Extensions["user_id"]
 17	if !ok {
 18		return nil, fmt.Errorf("`user_id` extension not found")
 19	}
 20	return dbpool.FindUser(userID)
 21}
 22
 23type arrayFlags []string
 24
 25func (i *arrayFlags) String() string {
 26	return "array flags"
 27}
 28
 29func (i *arrayFlags) Set(value string) error {
 30	*i = append(*i, value)
 31	return nil
 32}
 33
 34func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
 35	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 36	cmd.SetOutput(sesh)
 37	write := cmd.Bool("write", false, "apply changes")
 38	return cmd, write
 39}
 40
 41func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 42	_ = cmd.Parse(cmdArgs)
 43
 44	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 45		cmd.Usage()
 46		return false
 47	}
 48	return true
 49}
 50
 51func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
 52	dbpool := handler.Cfg.DB
 53	log := handler.Cfg.Logger
 54	cfg := handler.Cfg
 55	store := handler.Cfg.Storage
 56
 57	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
 58		return func(sesh *pssh.SSHServerConnSession) error {
 59			args := sesh.Command()
 60
 61			// default width and height when no pty
 62			width := 100
 63			height := 24
 64			pty, _, ok := sesh.Pty()
 65			if ok {
 66				width = pty.Window.Width
 67				height = pty.Window.Height
 68			}
 69
 70			opts := Cmd{
 71				Session: sesh,
 72				Store:   store,
 73				Log:     log,
 74				Dbpool:  dbpool,
 75				Write:   false,
 76				Width:   width,
 77				Height:  height,
 78				Cfg:     handler.Cfg,
 79			}
 80
 81			user, err := getUser(sesh, dbpool)
 82			if err != nil {
 83				sendutils.ErrorHandler(sesh, err)
 84				return err
 85			}
 86
 87			opts.User = user
 88
 89			if len(args) == 0 {
 90				opts.help()
 91				return nil
 92			}
 93
 94			cmd := strings.TrimSpace(args[0])
 95			if len(args) == 1 {
 96				switch cmd {
 97				case "help":
 98					opts.help()
 99					return nil
100				case "stats":
101					err := opts.stats(cfg.MaxSize)
102					opts.bail(err)
103					return err
104				case "ls":
105					err := opts.ls()
106					opts.bail(err)
107					return err
108				case "cache-all":
109					opts.Write = true
110					err := opts.cacheAll()
111					opts.notice()
112					opts.bail(err)
113					return err
114				default:
115					return next(sesh)
116				}
117			}
118
119			projectName := strings.TrimSpace(args[1])
120			cmdArgs := args[2:]
121			log.Info(
122				"pgs middleware detected command",
123				"args", args,
124				"cmd", cmd,
125				"projectName", projectName,
126				"cmdArgs", cmdArgs,
127			)
128
129			switch cmd {
130			case "fzf":
131				err := opts.fzf(projectName)
132				opts.bail(err)
133				return err
134			case "link":
135				linkCmd, write := flagSet("link", sesh)
136				linkTo := linkCmd.String("to", "", "symbolic link to this project")
137				if !flagCheck(linkCmd, projectName, cmdArgs) {
138					return nil
139				}
140				opts.Write = *write
141
142				if *linkTo == "" {
143					err := fmt.Errorf(
144						"must provide `--to` flag",
145					)
146					opts.bail(err)
147					return err
148				}
149
150				err := opts.link(projectName, *linkTo)
151				opts.notice()
152				if err != nil {
153					opts.bail(err)
154				}
155				return err
156			case "unlink":
157				unlinkCmd, write := flagSet("unlink", sesh)
158				if !flagCheck(unlinkCmd, projectName, cmdArgs) {
159					return nil
160				}
161				opts.Write = *write
162
163				err := opts.unlink(projectName)
164				opts.notice()
165				opts.bail(err)
166				return err
167			case "depends":
168				err := opts.depends(projectName)
169				opts.bail(err)
170				return err
171			case "retain":
172				retainCmd, write := flagSet("retain", sesh)
173				retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
174				if !flagCheck(retainCmd, projectName, cmdArgs) {
175					return nil
176				}
177				opts.Write = *write
178
179				err := opts.prune(projectName, *retainNum)
180				opts.notice()
181				opts.bail(err)
182				return err
183			case "prune":
184				pruneCmd, write := flagSet("prune", sesh)
185				if !flagCheck(pruneCmd, projectName, cmdArgs) {
186					return nil
187				}
188				opts.Write = *write
189
190				err := opts.prune(projectName, 0)
191				opts.notice()
192				opts.bail(err)
193				return err
194			case "rm":
195				rmCmd, write := flagSet("rm", sesh)
196				if !flagCheck(rmCmd, projectName, cmdArgs) {
197					return nil
198				}
199				opts.Write = *write
200
201				err := opts.rm(projectName)
202				opts.notice()
203				opts.bail(err)
204				return err
205			case "cache":
206				cacheCmd, write := flagSet("cache", sesh)
207				if !flagCheck(cacheCmd, projectName, cmdArgs) {
208					return nil
209				}
210				opts.Write = *write
211
212				err := opts.cache(projectName)
213				opts.notice()
214				opts.bail(err)
215				return err
216			case "acl":
217				aclCmd, write := flagSet("acl", sesh)
218				aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
219				var acls arrayFlags
220				aclCmd.Var(
221					&acls,
222					"acl",
223					"list of pico usernames or sha256 public keys, delimited by commas",
224				)
225				if !flagCheck(aclCmd, projectName, cmdArgs) {
226					return nil
227				}
228				opts.Write = *write
229
230				if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
231					err := fmt.Errorf(
232						"acl type must be one of the following: [public, pubkeys, pico], found %s",
233						*aclType,
234					)
235					opts.bail(err)
236					return err
237				}
238
239				err := opts.acl(projectName, *aclType, acls)
240				opts.notice()
241				opts.bail(err)
242				return err
243			default:
244				return next(sesh)
245			}
246		}
247	}
248}