Eric Bower
·
2026-03-05
cli_middleware.go
1package pgs
2
3import (
4 "flag"
5 "fmt"
6 "slices"
7 "strings"
8 "time"
9
10 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
11 "github.com/picosh/pico/pkg/db"
12 "github.com/picosh/pico/pkg/pssh"
13 sendutils "github.com/picosh/pico/pkg/send/utils"
14)
15
16func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
17 userID, ok := s.Conn.Permissions.Extensions["user_id"]
18 if !ok {
19 return nil, fmt.Errorf("`user_id` extension not found")
20 }
21 return dbpool.FindUser(userID)
22}
23
// arrayFlags is a flag.Value that collects every occurrence of a repeatable
// flag, so `--acl alice --acl bob` yields ["alice", "bob"].
type arrayFlags []string

// String implements flag.Value by joining the collected values with commas.
//
// The flag package may call String on a zero-valued receiver, including a
// nil pointer, so that case returns "" instead of dereferencing nil.
func (i *arrayFlags) String() string {
	if i == nil {
		return ""
	}
	return strings.Join(*i, ",")
}

// Set implements flag.Value by appending value to the collection.
// It never fails.
func (i *arrayFlags) Set(value string) error {
	*i = append(*i, value)
	return nil
}
34
35func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
36 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
37 cmd.SetOutput(sesh)
38 write := cmd.Bool("write", false, "apply changes")
39 return cmd, write
40}
41
// flagCheck parses cmdArgs into cmd and reports whether the subcommand
// should run.
//
// It returns false when either of these holds:
//   - posArg, the positional argument preceding the flags, is a help request
//     ("-h", "--help" or "-help"). Usage is printed to cmd's output.
//   - cmdArgs fails to parse. This includes an explicit -h/--help among the
//     flags (flag.ErrHelp) and malformed values such as `-n abc`.
//
// Refusing to run on a parse error matters. Parsing stops at the first bad
// flag, so flags before it (e.g. --write) are already applied while the rest
// keep their defaults. That mix could trigger a destructive operation with
// settings the user never asked for.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	switch posArg {
	case "-h", "--help", "-help":
		cmd.Usage()
		return false
	}

	// For a FlagSet built with flag.ContinueOnError (as flagSet does), Parse
	// has already written the error and usage to cmd's output; we only need
	// to stop here.
	if err := cmd.Parse(cmdArgs); err != nil {
		return false
	}
	return true
}
51
52func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
53 dbpool := handler.Cfg.DB
54 log := handler.Cfg.Logger
55 cfg := handler.Cfg
56 store := handler.Cfg.Storage
57
58 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
59 return func(sesh *pssh.SSHServerConnSession) error {
60 args := sesh.Command()
61
62 // default width and height when no pty
63 width := 100
64 height := 24
65 pty, _, ok := sesh.Pty()
66 if ok {
67 width = pty.Window.Width
68 height = pty.Window.Height
69 }
70
71 opts := Cmd{
72 Session: sesh,
73 Store: store,
74 Log: log,
75 Dbpool: dbpool,
76 Write: false,
77 Width: width,
78 Height: height,
79 Cfg: handler.Cfg,
80 }
81
82 user, err := getUser(sesh, dbpool)
83 if err != nil {
84 sendutils.ErrorHandler(sesh, err)
85 return err
86 }
87
88 opts.User = user
89
90 if len(args) == 0 {
91 opts.help()
92 return nil
93 }
94
95 cmd := strings.TrimSpace(args[0])
96 if len(args) == 1 {
97 switch cmd {
98 case "help":
99 opts.help()
100 return nil
101 case "stats":
102 err := opts.stats(cfg.MaxSize)
103 opts.bail(err)
104 return err
105 case "ls":
106 err := opts.ls()
107 opts.bail(err)
108 return err
109 case "cache-all":
110 opts.Write = true
111 err := opts.cacheAll()
112 opts.notice()
113 opts.bail(err)
114 return err
115 default:
116 return next(sesh)
117 }
118 }
119
120 projectName := strings.TrimSpace(args[1])
121 cmdArgs := args[2:]
122 log.Info(
123 "pgs middleware detected command",
124 "args", args,
125 "cmd", cmd,
126 "projectName", projectName,
127 "cmdArgs", cmdArgs,
128 )
129
130 switch cmd {
131 case "fzf":
132 err := opts.fzf(projectName)
133 opts.bail(err)
134 return err
135 case "link":
136 linkCmd, write := flagSet("link", sesh)
137 linkTo := linkCmd.String("to", "", "symbolic link to this project")
138 if !flagCheck(linkCmd, projectName, cmdArgs) {
139 return nil
140 }
141 opts.Write = *write
142
143 if *linkTo == "" {
144 err := fmt.Errorf(
145 "must provide `--to` flag",
146 )
147 opts.bail(err)
148 return err
149 }
150
151 err := opts.link(projectName, *linkTo)
152 opts.notice()
153 if err != nil {
154 opts.bail(err)
155 }
156 return err
157 case "unlink":
158 unlinkCmd, write := flagSet("unlink", sesh)
159 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
160 return nil
161 }
162 opts.Write = *write
163
164 err := opts.unlink(projectName)
165 opts.notice()
166 opts.bail(err)
167 return err
168 case "depends":
169 err := opts.depends(projectName)
170 opts.bail(err)
171 return err
172 case "retain":
173 retainCmd, write := flagSet("retain", sesh)
174 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
175 if !flagCheck(retainCmd, projectName, cmdArgs) {
176 return nil
177 }
178 opts.Write = *write
179
180 err := opts.prune(projectName, *retainNum)
181 opts.notice()
182 opts.bail(err)
183 return err
184 case "prune":
185 pruneCmd, write := flagSet("prune", sesh)
186 if !flagCheck(pruneCmd, projectName, cmdArgs) {
187 return nil
188 }
189 opts.Write = *write
190
191 err := opts.prune(projectName, 0)
192 opts.notice()
193 opts.bail(err)
194 return err
195 case "rm":
196 rmCmd, write := flagSet("rm", sesh)
197 if !flagCheck(rmCmd, projectName, cmdArgs) {
198 return nil
199 }
200 opts.Write = *write
201
202 err := opts.rm(projectName)
203 opts.notice()
204 opts.bail(err)
205 return err
206 case "cache":
207 cacheCmd, write := flagSet("cache", sesh)
208 if !flagCheck(cacheCmd, projectName, cmdArgs) {
209 return nil
210 }
211 opts.Write = *write
212
213 err := opts.cache(projectName)
214 opts.notice()
215 opts.bail(err)
216 return err
217 case "forms":
218 formName := projectName
219 formsCmd, write := flagSet("forms", sesh)
220 rmForm := formsCmd.Bool("rm", false, "delete form data")
221 if !flagCheck(formsCmd, formName, cmdArgs) {
222 return nil
223 }
224 opts.Write = *write
225
226 var err error
227 if formName == "ls" {
228 err = opts.formsLs()
229 } else {
230 if *rmForm {
231 err = opts.formRm(formName)
232 opts.notice()
233 } else {
234 err = opts.formData(formName)
235 }
236 }
237
238 opts.bail(err)
239 return err
240 case "acl":
241 aclCmd, write := flagSet("acl", sesh)
242 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
243 var acls arrayFlags
244 aclCmd.Var(
245 &acls,
246 "acl",
247 "list of pico usernames or sha256 public keys, delimited by commas",
248 )
249 if !flagCheck(aclCmd, projectName, cmdArgs) {
250 return nil
251 }
252 opts.Write = *write
253
254 if !slices.Contains([]string{"public", "pubkeys", "pico", "http-pass"}, *aclType) {
255 err := fmt.Errorf(
256 "acl type must be one of the following: [public, pubkeys, pico], found %s",
257 *aclType,
258 )
259 opts.bail(err)
260 return err
261 }
262
263 hasPicoPlus := false
264 ff, _ := dbpool.FindFeature(user.ID, "plus")
265 if ff != nil {
266 if ff.ExpiresAt.After(time.Now()) {
267 hasPicoPlus = true
268 }
269 }
270
271 if !hasPicoPlus {
272 err = fmt.Errorf("setting acl on a project requires pico+")
273 opts.bail(err)
274 return err
275 }
276
277 if pgsdb.IsProjectPrivate(projectName) {
278 err = fmt.Errorf("projects prefixed with `private-` can *never* have their access changed; however you can symlink to it")
279 opts.bail(err)
280 return err
281 }
282
283 err := opts.acl(projectName, *aclType, acls)
284 opts.notice()
285 opts.bail(err)
286 return err
287 default:
288 return next(sesh)
289 }
290 }
291 }
292}