Eric Bower
·
2025-03-28
cli_middleware.go
1package pgs
2
3import (
4 "flag"
5 "fmt"
6 "slices"
7 "strings"
8
9 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
10 "github.com/picosh/pico/pkg/db"
11 "github.com/picosh/pico/pkg/pssh"
12 sendutils "github.com/picosh/pico/pkg/send/utils"
13)
14
15func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
16 userID, ok := s.Conn.Permissions.Extensions["user_id"]
17 if !ok {
18 return nil, fmt.Errorf("`user_id` extension not found")
19 }
20 return dbpool.FindUser(userID)
21}
22
// arrayFlags is a flag.Value that collects every occurrence of a repeatable
// command-line flag, in the order the values were supplied.
type arrayFlags []string

// String returns the values collected so far, joined by commas. An empty or
// nil receiver yields the empty string, so the flag package sees the same
// text for the zero value and for a freshly registered, still-empty flag.
func (a *arrayFlags) String() string {
	if a == nil {
		return ""
	}
	return strings.Join(*a, ",")
}

// Set appends value to the collection. It never fails; the error result only
// exists to satisfy flag.Value.
func (a *arrayFlags) Set(value string) error {
	*a = append(*a, value)
	return nil
}
33
34func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
35 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
36 cmd.SetOutput(sesh)
37 write := cmd.Bool("write", false, "apply changes")
38 return cmd, write
39}
40
// flagCheck parses cmdArgs with cmd and reports whether the caller should go
// on to execute the command.
//
// It returns false in two situations:
//   - posArg is one of the help spellings (-h, -help, --help); the usage text
//     is printed and cmdArgs is not parsed at all.
//   - cmdArgs cannot be parsed (unknown flag, invalid value, or a help flag
//     among the arguments). Letting the command run after a failed parse
//     would execute it with default or partially applied flag values.
//
// It returns true when parsing succeeded and no help was requested.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	switch posArg {
	case "-h", "-help", "--help":
		cmd.Usage()
		return false
	}

	// On failure Parse has already written the error message and usage text
	// to the flag set's output, so there is nothing left to report here.
	return cmd.Parse(cmdArgs) == nil
}
50
51func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
52 dbpool := handler.Cfg.DB
53 log := handler.Cfg.Logger
54 cfg := handler.Cfg
55 store := handler.Cfg.Storage
56
57 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
58 return func(sesh *pssh.SSHServerConnSession) error {
59 args := sesh.Command()
60
61 // default width and height when no pty
62 width := 100
63 height := 24
64 pty, _, ok := sesh.Pty()
65 if ok {
66 width = pty.Window.Width
67 height = pty.Window.Height
68 }
69
70 user, err := getUser(sesh, dbpool)
71 if err != nil {
72 sendutils.ErrorHandler(sesh, err)
73 return err
74 }
75
76 // renderer := bm.MakeRenderer(sesh)
77 // renderer.SetColorProfile(termenv.TrueColor)
78 // styles := common.DefaultStyles(renderer)
79
80 opts := Cmd{
81 Session: sesh,
82 User: user,
83 Store: store,
84 Log: log,
85 Dbpool: dbpool,
86 Write: false,
87 // Styles: styles,
88 Width: width,
89 Height: height,
90 Cfg: handler.Cfg,
91 }
92
93 if len(args) == 0 {
94 opts.help()
95 return nil
96 }
97
98 cmd := strings.TrimSpace(args[0])
99 if len(args) == 1 {
100 if cmd == "help" {
101 opts.help()
102 return nil
103 } else if cmd == "stats" {
104 err := opts.stats(cfg.MaxSize)
105 opts.bail(err)
106 return err
107 } else if cmd == "ls" {
108 err := opts.ls()
109 opts.bail(err)
110 return err
111 } else if cmd == "cache-all" {
112 opts.Write = true
113 err := opts.cacheAll()
114 opts.notice()
115 opts.bail(err)
116 return err
117 } else {
118 return next(sesh)
119 }
120 }
121
122 projectName := strings.TrimSpace(args[1])
123 cmdArgs := args[2:]
124 log.Info(
125 "pgs middleware detected command",
126 "args", args,
127 "cmd", cmd,
128 "projectName", projectName,
129 "cmdArgs", cmdArgs,
130 )
131
132 if cmd == "fzf" {
133 err := opts.fzf(projectName)
134 opts.bail(err)
135 return err
136 } else if cmd == "link" {
137 linkCmd, write := flagSet("link", sesh)
138 linkTo := linkCmd.String("to", "", "symbolic link to this project")
139 if !flagCheck(linkCmd, projectName, cmdArgs) {
140 return nil
141 }
142 opts.Write = *write
143
144 if *linkTo == "" {
145 err := fmt.Errorf(
146 "must provide `--to` flag",
147 )
148 opts.bail(err)
149 return err
150 }
151
152 err := opts.link(projectName, *linkTo)
153 opts.notice()
154 if err != nil {
155 opts.bail(err)
156 }
157 return err
158 } else if cmd == "unlink" {
159 unlinkCmd, write := flagSet("unlink", sesh)
160 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
161 return nil
162 }
163 opts.Write = *write
164
165 err := opts.unlink(projectName)
166 opts.notice()
167 opts.bail(err)
168 return err
169 } else if cmd == "depends" {
170 err := opts.depends(projectName)
171 opts.bail(err)
172 return err
173 } else if cmd == "retain" {
174 retainCmd, write := flagSet("retain", sesh)
175 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
176 if !flagCheck(retainCmd, projectName, cmdArgs) {
177 return nil
178 }
179 opts.Write = *write
180
181 err := opts.prune(projectName, *retainNum)
182 opts.notice()
183 opts.bail(err)
184 return err
185 } else if cmd == "prune" {
186 pruneCmd, write := flagSet("prune", sesh)
187 if !flagCheck(pruneCmd, projectName, cmdArgs) {
188 return nil
189 }
190 opts.Write = *write
191
192 err := opts.prune(projectName, 0)
193 opts.notice()
194 opts.bail(err)
195 return err
196 } else if cmd == "rm" {
197 rmCmd, write := flagSet("rm", sesh)
198 if !flagCheck(rmCmd, projectName, cmdArgs) {
199 return nil
200 }
201 opts.Write = *write
202
203 err := opts.rm(projectName)
204 opts.notice()
205 opts.bail(err)
206 return err
207 } else if cmd == "cache" {
208 cacheCmd, write := flagSet("cache", sesh)
209 if !flagCheck(cacheCmd, projectName, cmdArgs) {
210 return nil
211 }
212 opts.Write = *write
213
214 err := opts.cache(projectName)
215 opts.notice()
216 opts.bail(err)
217 return err
218 } else if cmd == "acl" {
219 aclCmd, write := flagSet("acl", sesh)
220 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
221 var acls arrayFlags
222 aclCmd.Var(
223 &acls,
224 "acl",
225 "list of pico usernames or sha256 public keys, delimited by commas",
226 )
227 if !flagCheck(aclCmd, projectName, cmdArgs) {
228 return nil
229 }
230 opts.Write = *write
231
232 if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
233 err := fmt.Errorf(
234 "acl type must be one of the following: [public, pubkeys, pico], found %s",
235 *aclType,
236 )
237 opts.bail(err)
238 return err
239 }
240
241 err := opts.acl(projectName, *aclType, acls)
242 opts.notice()
243 opts.bail(err)
244 return err
245 } else {
246 return next(sesh)
247 }
248 }
249 }
250}