Eric Bower
·
2025-05-24
cli_middleware.go
1package pgs
2
3import (
4 "flag"
5 "fmt"
6 "slices"
7 "strings"
8
9 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
10 "github.com/picosh/pico/pkg/db"
11 "github.com/picosh/pico/pkg/pssh"
12 sendutils "github.com/picosh/pico/pkg/send/utils"
13)
14
15func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
16 userID, ok := s.Conn.Permissions.Extensions["user_id"]
17 if !ok {
18 return nil, fmt.Errorf("`user_id` extension not found")
19 }
20 return dbpool.FindUser(userID)
21}
22
// arrayFlags is a flag.Value that collects every occurrence of a repeatable
// flag (e.g. `--acl a --acl b`) into a slice, preserving the order given.
type arrayFlags []string

// String implements flag.Value by presenting the collected values joined
// with commas. The flag package may call String on a zero-valued receiver,
// including a nil pointer, so a nil receiver yields "".
func (i *arrayFlags) String() string {
	if i == nil {
		return ""
	}
	return strings.Join(*i, ",")
}

// Set implements flag.Value by appending value verbatim; it never fails.
//
// NOTE(review): the value is not split on commas, although the `--acl`
// usage text says "delimited by commas". Confirm that the consumer
// (Cmd.acl) splits entries, or that callers repeat the flag.
func (i *arrayFlags) Set(value string) error {
	*i = append(*i, value)
	return nil
}
33
34func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
35 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
36 cmd.SetOutput(sesh)
37 write := cmd.Bool("write", false, "apply changes")
38 return cmd, write
39}
40
// flagCheck parses cmdArgs into cmd and reports whether the caller should
// go on to run the subcommand.
//
// posArg is the first positional argument (the project name). If it is a
// help flag (-h, -help, --help), usage is printed and false is returned.
//
// False is also returned when cmdArgs fail to parse, including flag.ErrHelp
// for a trailing -h. With flag.ContinueOnError (as flagSet configures), the
// FlagSet has already written the error and/or usage to its output. Running
// the command anyway would act on partially parsed flags, for example
// `--write` accepted but a later flag rejected.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
		cmd.Usage()
		return false
	}
	if err := cmd.Parse(cmdArgs); err != nil {
		return false
	}
	return true
}
50
51func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
52 dbpool := handler.Cfg.DB
53 log := handler.Cfg.Logger
54 cfg := handler.Cfg
55 store := handler.Cfg.Storage
56
57 return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
58 return func(sesh *pssh.SSHServerConnSession) error {
59 args := sesh.Command()
60
61 // default width and height when no pty
62 width := 100
63 height := 24
64 pty, _, ok := sesh.Pty()
65 if ok {
66 width = pty.Window.Width
67 height = pty.Window.Height
68 }
69
70 user, err := getUser(sesh, dbpool)
71 if err != nil {
72 sendutils.ErrorHandler(sesh, err)
73 return err
74 }
75
76 // renderer := bm.MakeRenderer(sesh)
77 // renderer.SetColorProfile(termenv.TrueColor)
78 // styles := common.DefaultStyles(renderer)
79
80 opts := Cmd{
81 Session: sesh,
82 User: user,
83 Store: store,
84 Log: log,
85 Dbpool: dbpool,
86 Write: false,
87 // Styles: styles,
88 Width: width,
89 Height: height,
90 Cfg: handler.Cfg,
91 }
92
93 if len(args) == 0 {
94 opts.help()
95 return nil
96 }
97
98 cmd := strings.TrimSpace(args[0])
99 if len(args) == 1 {
100 switch cmd {
101 case "help":
102 opts.help()
103 return nil
104 case "stats":
105 err := opts.stats(cfg.MaxSize)
106 opts.bail(err)
107 return err
108 case "ls":
109 err := opts.ls()
110 opts.bail(err)
111 return err
112 case "cache-all":
113 opts.Write = true
114 err := opts.cacheAll()
115 opts.notice()
116 opts.bail(err)
117 return err
118 default:
119 return next(sesh)
120 }
121 }
122
123 projectName := strings.TrimSpace(args[1])
124 cmdArgs := args[2:]
125 log.Info(
126 "pgs middleware detected command",
127 "args", args,
128 "cmd", cmd,
129 "projectName", projectName,
130 "cmdArgs", cmdArgs,
131 )
132
133 switch cmd {
134 case "fzf":
135 err := opts.fzf(projectName)
136 opts.bail(err)
137 return err
138 case "link":
139 linkCmd, write := flagSet("link", sesh)
140 linkTo := linkCmd.String("to", "", "symbolic link to this project")
141 if !flagCheck(linkCmd, projectName, cmdArgs) {
142 return nil
143 }
144 opts.Write = *write
145
146 if *linkTo == "" {
147 err := fmt.Errorf(
148 "must provide `--to` flag",
149 )
150 opts.bail(err)
151 return err
152 }
153
154 err := opts.link(projectName, *linkTo)
155 opts.notice()
156 if err != nil {
157 opts.bail(err)
158 }
159 return err
160 case "unlink":
161 unlinkCmd, write := flagSet("unlink", sesh)
162 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
163 return nil
164 }
165 opts.Write = *write
166
167 err := opts.unlink(projectName)
168 opts.notice()
169 opts.bail(err)
170 return err
171 case "depends":
172 err := opts.depends(projectName)
173 opts.bail(err)
174 return err
175 case "retain":
176 retainCmd, write := flagSet("retain", sesh)
177 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
178 if !flagCheck(retainCmd, projectName, cmdArgs) {
179 return nil
180 }
181 opts.Write = *write
182
183 err := opts.prune(projectName, *retainNum)
184 opts.notice()
185 opts.bail(err)
186 return err
187 case "prune":
188 pruneCmd, write := flagSet("prune", sesh)
189 if !flagCheck(pruneCmd, projectName, cmdArgs) {
190 return nil
191 }
192 opts.Write = *write
193
194 err := opts.prune(projectName, 0)
195 opts.notice()
196 opts.bail(err)
197 return err
198 case "rm":
199 rmCmd, write := flagSet("rm", sesh)
200 if !flagCheck(rmCmd, projectName, cmdArgs) {
201 return nil
202 }
203 opts.Write = *write
204
205 err := opts.rm(projectName)
206 opts.notice()
207 opts.bail(err)
208 return err
209 case "cache":
210 cacheCmd, write := flagSet("cache", sesh)
211 if !flagCheck(cacheCmd, projectName, cmdArgs) {
212 return nil
213 }
214 opts.Write = *write
215
216 err := opts.cache(projectName)
217 opts.notice()
218 opts.bail(err)
219 return err
220 case "acl":
221 aclCmd, write := flagSet("acl", sesh)
222 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
223 var acls arrayFlags
224 aclCmd.Var(
225 &acls,
226 "acl",
227 "list of pico usernames or sha256 public keys, delimited by commas",
228 )
229 if !flagCheck(aclCmd, projectName, cmdArgs) {
230 return nil
231 }
232 opts.Write = *write
233
234 if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
235 err := fmt.Errorf(
236 "acl type must be one of the following: [public, pubkeys, pico], found %s",
237 *aclType,
238 )
239 opts.bail(err)
240 return err
241 }
242
243 err := opts.acl(projectName, *aclType, acls)
244 opts.notice()
245 opts.bail(err)
246 return err
247 default:
248 return next(sesh)
249 }
250 }
251 }
252}