repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2025-05-24

cli_middleware.go

  1package pgs
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 10	"github.com/picosh/pico/pkg/db"
 11	"github.com/picosh/pico/pkg/pssh"
 12	sendutils "github.com/picosh/pico/pkg/send/utils"
 13)
 14
 15func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
 16	userID, ok := s.Conn.Permissions.Extensions["user_id"]
 17	if !ok {
 18		return nil, fmt.Errorf("`user_id` extension not found")
 19	}
 20	return dbpool.FindUser(userID)
 21}
 22
 23type arrayFlags []string
 24
 25func (i *arrayFlags) String() string {
 26	return "array flags"
 27}
 28
 29func (i *arrayFlags) Set(value string) error {
 30	*i = append(*i, value)
 31	return nil
 32}
 33
 34func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
 35	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 36	cmd.SetOutput(sesh)
 37	write := cmd.Bool("write", false, "apply changes")
 38	return cmd, write
 39}
 40
 41func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 42	_ = cmd.Parse(cmdArgs)
 43
 44	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 45		cmd.Usage()
 46		return false
 47	}
 48	return true
 49}
 50
 51func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
 52	dbpool := handler.Cfg.DB
 53	log := handler.Cfg.Logger
 54	cfg := handler.Cfg
 55	store := handler.Cfg.Storage
 56
 57	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
 58		return func(sesh *pssh.SSHServerConnSession) error {
 59			args := sesh.Command()
 60
 61			// default width and height when no pty
 62			width := 100
 63			height := 24
 64			pty, _, ok := sesh.Pty()
 65			if ok {
 66				width = pty.Window.Width
 67				height = pty.Window.Height
 68			}
 69
 70			user, err := getUser(sesh, dbpool)
 71			if err != nil {
 72				sendutils.ErrorHandler(sesh, err)
 73				return err
 74			}
 75
 76			// renderer := bm.MakeRenderer(sesh)
 77			// renderer.SetColorProfile(termenv.TrueColor)
 78			// styles := common.DefaultStyles(renderer)
 79
 80			opts := Cmd{
 81				Session: sesh,
 82				User:    user,
 83				Store:   store,
 84				Log:     log,
 85				Dbpool:  dbpool,
 86				Write:   false,
 87				// Styles:  styles,
 88				Width:  width,
 89				Height: height,
 90				Cfg:    handler.Cfg,
 91			}
 92
 93			if len(args) == 0 {
 94				opts.help()
 95				return nil
 96			}
 97
 98			cmd := strings.TrimSpace(args[0])
 99			if len(args) == 1 {
100				switch cmd {
101				case "help":
102					opts.help()
103					return nil
104				case "stats":
105					err := opts.stats(cfg.MaxSize)
106					opts.bail(err)
107					return err
108				case "ls":
109					err := opts.ls()
110					opts.bail(err)
111					return err
112				case "cache-all":
113					opts.Write = true
114					err := opts.cacheAll()
115					opts.notice()
116					opts.bail(err)
117					return err
118				default:
119					return next(sesh)
120				}
121			}
122
123			projectName := strings.TrimSpace(args[1])
124			cmdArgs := args[2:]
125			log.Info(
126				"pgs middleware detected command",
127				"args", args,
128				"cmd", cmd,
129				"projectName", projectName,
130				"cmdArgs", cmdArgs,
131			)
132
133			switch cmd {
134			case "fzf":
135				err := opts.fzf(projectName)
136				opts.bail(err)
137				return err
138			case "link":
139				linkCmd, write := flagSet("link", sesh)
140				linkTo := linkCmd.String("to", "", "symbolic link to this project")
141				if !flagCheck(linkCmd, projectName, cmdArgs) {
142					return nil
143				}
144				opts.Write = *write
145
146				if *linkTo == "" {
147					err := fmt.Errorf(
148						"must provide `--to` flag",
149					)
150					opts.bail(err)
151					return err
152				}
153
154				err := opts.link(projectName, *linkTo)
155				opts.notice()
156				if err != nil {
157					opts.bail(err)
158				}
159				return err
160			case "unlink":
161				unlinkCmd, write := flagSet("unlink", sesh)
162				if !flagCheck(unlinkCmd, projectName, cmdArgs) {
163					return nil
164				}
165				opts.Write = *write
166
167				err := opts.unlink(projectName)
168				opts.notice()
169				opts.bail(err)
170				return err
171			case "depends":
172				err := opts.depends(projectName)
173				opts.bail(err)
174				return err
175			case "retain":
176				retainCmd, write := flagSet("retain", sesh)
177				retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
178				if !flagCheck(retainCmd, projectName, cmdArgs) {
179					return nil
180				}
181				opts.Write = *write
182
183				err := opts.prune(projectName, *retainNum)
184				opts.notice()
185				opts.bail(err)
186				return err
187			case "prune":
188				pruneCmd, write := flagSet("prune", sesh)
189				if !flagCheck(pruneCmd, projectName, cmdArgs) {
190					return nil
191				}
192				opts.Write = *write
193
194				err := opts.prune(projectName, 0)
195				opts.notice()
196				opts.bail(err)
197				return err
198			case "rm":
199				rmCmd, write := flagSet("rm", sesh)
200				if !flagCheck(rmCmd, projectName, cmdArgs) {
201					return nil
202				}
203				opts.Write = *write
204
205				err := opts.rm(projectName)
206				opts.notice()
207				opts.bail(err)
208				return err
209			case "cache":
210				cacheCmd, write := flagSet("cache", sesh)
211				if !flagCheck(cacheCmd, projectName, cmdArgs) {
212					return nil
213				}
214				opts.Write = *write
215
216				err := opts.cache(projectName)
217				opts.notice()
218				opts.bail(err)
219				return err
220			case "acl":
221				aclCmd, write := flagSet("acl", sesh)
222				aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
223				var acls arrayFlags
224				aclCmd.Var(
225					&acls,
226					"acl",
227					"list of pico usernames or sha256 public keys, delimited by commas",
228				)
229				if !flagCheck(aclCmd, projectName, cmdArgs) {
230					return nil
231				}
232				opts.Write = *write
233
234				if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
235					err := fmt.Errorf(
236						"acl type must be one of the following: [public, pubkeys, pico], found %s",
237						*aclType,
238					)
239					opts.bail(err)
240					return err
241				}
242
243				err := opts.acl(projectName, *aclType, acls)
244				opts.notice()
245				opts.bail(err)
246				return err
247			default:
248				return next(sesh)
249			}
250		}
251	}
252}