Eric Bower
·
2025-07-13
cli_middleware.go
1package pgs
2
3import (
4 "flag"
5 "fmt"
6 "slices"
7 "strings"
8
9 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
10 "github.com/picosh/pico/pkg/db"
11 "github.com/picosh/pico/pkg/pssh"
12 sendutils "github.com/picosh/pico/pkg/send/utils"
13)
14
15func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
16 userID, ok := s.Conn.Permissions.Extensions["user_id"]
17 if !ok {
18 return nil, fmt.Errorf("`user_id` extension not found")
19 }
20 return dbpool.FindUser(userID)
21}
22
// arrayFlags is a flag.Value that accumulates every occurrence of a
// repeatable flag. Each occurrence may itself carry several comma-delimited
// entries, so `--acl a,b --acl c` collects [a b c]. This matches the acl
// flag's documented "delimited by commas" usage.
type arrayFlags []string

// String renders the collected entries as a comma-joined list. The flag
// package may call String on a zero-valued (nil) receiver, for example when
// deciding whether to print a default, so nil yields "".
func (i *arrayFlags) String() string {
	if i == nil {
		return ""
	}
	return strings.Join(*i, ",")
}

// Set splits value on commas and appends each entry after trimming
// surrounding whitespace. Empty entries (e.g. from "a,,b" or "") are
// skipped, because they can never name a user or key. Set never fails.
func (i *arrayFlags) Set(value string) error {
	for _, entry := range strings.Split(value, ",") {
		if entry = strings.TrimSpace(entry); entry != "" {
			*i = append(*i, entry)
		}
	}
	return nil
}
33
34func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
35 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
36 cmd.SetOutput(sesh)
37 write := cmd.Bool("write", false, "apply changes")
38 return cmd, write
39}
40
// flagCheck parses cmdArgs into cmd and reports whether the subcommand should
// run. It returns false, meaning the caller should stop without doing any
// work, when:
//   - posArg (the token where a project name is expected) is a help flag;
//     usage is printed in that case;
//   - cmdArgs asks for help (-h/--help) or fails to parse (unknown flag,
//     invalid value). cmd is expected to use flag.ContinueOnError (as flagSet
//     configures), so Parse has already written usage and/or the error to
//     the FlagSet's output.
//
// Stopping on parse errors matters: flags parsed before the failing token
// (such as --write) would otherwise stay in effect. A destructive command
// could then run with default values instead of the ones the user tried
// to pass.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	switch posArg {
	case "-h", "--help", "-help":
		cmd.Usage()
		return false
	}
	return cmd.Parse(cmdArgs) == nil
}
50
51func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
52 dbpool := handler.Cfg.DB
53 log := handler.Cfg.Logger
54 cfg := handler.Cfg
55 store := handler.Cfg.Storage
56
57 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
58 return func(sesh *pssh.SSHServerConnSession) error {
59 args := sesh.Command()
60
61 // default width and height when no pty
62 width := 100
63 height := 24
64 pty, _, ok := sesh.Pty()
65 if ok {
66 width = pty.Window.Width
67 height = pty.Window.Height
68 }
69
70 opts := Cmd{
71 Session: sesh,
72 Store: store,
73 Log: log,
74 Dbpool: dbpool,
75 Write: false,
76 Width: width,
77 Height: height,
78 Cfg: handler.Cfg,
79 }
80
81 user, err := getUser(sesh, dbpool)
82 if err != nil {
83 sendutils.ErrorHandler(sesh, err)
84 return err
85 }
86
87 opts.User = user
88
89 if len(args) == 0 {
90 opts.help()
91 return nil
92 }
93
94 cmd := strings.TrimSpace(args[0])
95 if len(args) == 1 {
96 switch cmd {
97 case "help":
98 opts.help()
99 return nil
100 case "stats":
101 err := opts.stats(cfg.MaxSize)
102 opts.bail(err)
103 return err
104 case "ls":
105 err := opts.ls()
106 opts.bail(err)
107 return err
108 case "cache-all":
109 opts.Write = true
110 err := opts.cacheAll()
111 opts.notice()
112 opts.bail(err)
113 return err
114 default:
115 return next(sesh)
116 }
117 }
118
119 projectName := strings.TrimSpace(args[1])
120 cmdArgs := args[2:]
121 log.Info(
122 "pgs middleware detected command",
123 "args", args,
124 "cmd", cmd,
125 "projectName", projectName,
126 "cmdArgs", cmdArgs,
127 )
128
129 switch cmd {
130 case "fzf":
131 err := opts.fzf(projectName)
132 opts.bail(err)
133 return err
134 case "link":
135 linkCmd, write := flagSet("link", sesh)
136 linkTo := linkCmd.String("to", "", "symbolic link to this project")
137 if !flagCheck(linkCmd, projectName, cmdArgs) {
138 return nil
139 }
140 opts.Write = *write
141
142 if *linkTo == "" {
143 err := fmt.Errorf(
144 "must provide `--to` flag",
145 )
146 opts.bail(err)
147 return err
148 }
149
150 err := opts.link(projectName, *linkTo)
151 opts.notice()
152 if err != nil {
153 opts.bail(err)
154 }
155 return err
156 case "unlink":
157 unlinkCmd, write := flagSet("unlink", sesh)
158 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
159 return nil
160 }
161 opts.Write = *write
162
163 err := opts.unlink(projectName)
164 opts.notice()
165 opts.bail(err)
166 return err
167 case "depends":
168 err := opts.depends(projectName)
169 opts.bail(err)
170 return err
171 case "retain":
172 retainCmd, write := flagSet("retain", sesh)
173 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
174 if !flagCheck(retainCmd, projectName, cmdArgs) {
175 return nil
176 }
177 opts.Write = *write
178
179 err := opts.prune(projectName, *retainNum)
180 opts.notice()
181 opts.bail(err)
182 return err
183 case "prune":
184 pruneCmd, write := flagSet("prune", sesh)
185 if !flagCheck(pruneCmd, projectName, cmdArgs) {
186 return nil
187 }
188 opts.Write = *write
189
190 err := opts.prune(projectName, 0)
191 opts.notice()
192 opts.bail(err)
193 return err
194 case "rm":
195 rmCmd, write := flagSet("rm", sesh)
196 if !flagCheck(rmCmd, projectName, cmdArgs) {
197 return nil
198 }
199 opts.Write = *write
200
201 err := opts.rm(projectName)
202 opts.notice()
203 opts.bail(err)
204 return err
205 case "cache":
206 cacheCmd, write := flagSet("cache", sesh)
207 if !flagCheck(cacheCmd, projectName, cmdArgs) {
208 return nil
209 }
210 opts.Write = *write
211
212 err := opts.cache(projectName)
213 opts.notice()
214 opts.bail(err)
215 return err
216 case "acl":
217 aclCmd, write := flagSet("acl", sesh)
218 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
219 var acls arrayFlags
220 aclCmd.Var(
221 &acls,
222 "acl",
223 "list of pico usernames or sha256 public keys, delimited by commas",
224 )
225 if !flagCheck(aclCmd, projectName, cmdArgs) {
226 return nil
227 }
228 opts.Write = *write
229
230 if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
231 err := fmt.Errorf(
232 "acl type must be one of the following: [public, pubkeys, pico], found %s",
233 *aclType,
234 )
235 opts.bail(err)
236 return err
237 }
238
239 err := opts.acl(projectName, *aclType, acls)
240 opts.notice()
241 opts.bail(err)
242 return err
243 default:
244 return next(sesh)
245 }
246 }
247 }
248}