repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower · 2025-03-28

cli_middleware.go

  1package pgs
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 10	"github.com/picosh/pico/pkg/db"
 11	"github.com/picosh/pico/pkg/pssh"
 12	sendutils "github.com/picosh/pico/pkg/send/utils"
 13)
 14
 15func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
 16	userID, ok := s.Conn.Permissions.Extensions["user_id"]
 17	if !ok {
 18		return nil, fmt.Errorf("`user_id` extension not found")
 19	}
 20	return dbpool.FindUser(userID)
 21}
 22
 23type arrayFlags []string
 24
 25func (i *arrayFlags) String() string {
 26	return "array flags"
 27}
 28
 29func (i *arrayFlags) Set(value string) error {
 30	*i = append(*i, value)
 31	return nil
 32}
 33
 34func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
 35	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 36	cmd.SetOutput(sesh)
 37	write := cmd.Bool("write", false, "apply changes")
 38	return cmd, write
 39}
 40
 41func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 42	_ = cmd.Parse(cmdArgs)
 43
 44	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 45		cmd.Usage()
 46		return false
 47	}
 48	return true
 49}
 50
 51func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
 52	dbpool := handler.Cfg.DB
 53	log := handler.Cfg.Logger
 54	cfg := handler.Cfg
 55	store := handler.Cfg.Storage
 56
 57	return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
 58		return func(sesh *pssh.SSHServerConnSession) error {
 59			args := sesh.Command()
 60
 61			// default width and height when no pty
 62			width := 100
 63			height := 24
 64			pty, _, ok := sesh.Pty()
 65			if ok {
 66				width = pty.Window.Width
 67				height = pty.Window.Height
 68			}
 69
 70			user, err := getUser(sesh, dbpool)
 71			if err != nil {
 72				sendutils.ErrorHandler(sesh, err)
 73				return err
 74			}
 75
 76			// renderer := bm.MakeRenderer(sesh)
 77			// renderer.SetColorProfile(termenv.TrueColor)
 78			// styles := common.DefaultStyles(renderer)
 79
 80			opts := Cmd{
 81				Session: sesh,
 82				User:    user,
 83				Store:   store,
 84				Log:     log,
 85				Dbpool:  dbpool,
 86				Write:   false,
 87				// Styles:  styles,
 88				Width:  width,
 89				Height: height,
 90				Cfg:    handler.Cfg,
 91			}
 92
 93			if len(args) == 0 {
 94				opts.help()
 95				return nil
 96			}
 97
 98			cmd := strings.TrimSpace(args[0])
 99			if len(args) == 1 {
100				if cmd == "help" {
101					opts.help()
102					return nil
103				} else if cmd == "stats" {
104					err := opts.stats(cfg.MaxSize)
105					opts.bail(err)
106					return err
107				} else if cmd == "ls" {
108					err := opts.ls()
109					opts.bail(err)
110					return err
111				} else if cmd == "cache-all" {
112					opts.Write = true
113					err := opts.cacheAll()
114					opts.notice()
115					opts.bail(err)
116					return err
117				} else {
118					return next(sesh)
119				}
120			}
121
122			projectName := strings.TrimSpace(args[1])
123			cmdArgs := args[2:]
124			log.Info(
125				"pgs middleware detected command",
126				"args", args,
127				"cmd", cmd,
128				"projectName", projectName,
129				"cmdArgs", cmdArgs,
130			)
131
132			if cmd == "fzf" {
133				err := opts.fzf(projectName)
134				opts.bail(err)
135				return err
136			} else if cmd == "link" {
137				linkCmd, write := flagSet("link", sesh)
138				linkTo := linkCmd.String("to", "", "symbolic link to this project")
139				if !flagCheck(linkCmd, projectName, cmdArgs) {
140					return nil
141				}
142				opts.Write = *write
143
144				if *linkTo == "" {
145					err := fmt.Errorf(
146						"must provide `--to` flag",
147					)
148					opts.bail(err)
149					return err
150				}
151
152				err := opts.link(projectName, *linkTo)
153				opts.notice()
154				if err != nil {
155					opts.bail(err)
156				}
157				return err
158			} else if cmd == "unlink" {
159				unlinkCmd, write := flagSet("unlink", sesh)
160				if !flagCheck(unlinkCmd, projectName, cmdArgs) {
161					return nil
162				}
163				opts.Write = *write
164
165				err := opts.unlink(projectName)
166				opts.notice()
167				opts.bail(err)
168				return err
169			} else if cmd == "depends" {
170				err := opts.depends(projectName)
171				opts.bail(err)
172				return err
173			} else if cmd == "retain" {
174				retainCmd, write := flagSet("retain", sesh)
175				retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
176				if !flagCheck(retainCmd, projectName, cmdArgs) {
177					return nil
178				}
179				opts.Write = *write
180
181				err := opts.prune(projectName, *retainNum)
182				opts.notice()
183				opts.bail(err)
184				return err
185			} else if cmd == "prune" {
186				pruneCmd, write := flagSet("prune", sesh)
187				if !flagCheck(pruneCmd, projectName, cmdArgs) {
188					return nil
189				}
190				opts.Write = *write
191
192				err := opts.prune(projectName, 0)
193				opts.notice()
194				opts.bail(err)
195				return err
196			} else if cmd == "rm" {
197				rmCmd, write := flagSet("rm", sesh)
198				if !flagCheck(rmCmd, projectName, cmdArgs) {
199					return nil
200				}
201				opts.Write = *write
202
203				err := opts.rm(projectName)
204				opts.notice()
205				opts.bail(err)
206				return err
207			} else if cmd == "cache" {
208				cacheCmd, write := flagSet("cache", sesh)
209				if !flagCheck(cacheCmd, projectName, cmdArgs) {
210					return nil
211				}
212				opts.Write = *write
213
214				err := opts.cache(projectName)
215				opts.notice()
216				opts.bail(err)
217				return err
218			} else if cmd == "acl" {
219				aclCmd, write := flagSet("acl", sesh)
220				aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
221				var acls arrayFlags
222				aclCmd.Var(
223					&acls,
224					"acl",
225					"list of pico usernames or sha256 public keys, delimited by commas",
226				)
227				if !flagCheck(aclCmd, projectName, cmdArgs) {
228					return nil
229				}
230				opts.Write = *write
231
232				if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
233					err := fmt.Errorf(
234						"acl type must be one of the following: [public, pubkeys, pico], found %s",
235						*aclType,
236					)
237					opts.bail(err)
238					return err
239				}
240
241				err := opts.acl(projectName, *aclType, acls)
242				opts.notice()
243				opts.bail(err)
244				return err
245			} else {
246				return next(sesh)
247			}
248		}
249	}
250}