Antonio Mika
·
2025-03-12
file_handler.go
1package pico
2
3import (
4 "bytes"
5 "errors"
6 "fmt"
7 "io"
8 "log/slog"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/picosh/pico/pkg/db"
15 "github.com/picosh/pico/pkg/pssh"
16 sendutils "github.com/picosh/pico/pkg/send/utils"
17 "github.com/picosh/pico/pkg/shared"
18 "github.com/picosh/utils"
19 "golang.org/x/crypto/ssh"
20)
21
22type UploadHandler struct {
23 DBPool db.DB
24 Cfg *shared.ConfigSite
25}
26
27func NewUploadHandler(dbpool db.DB, cfg *shared.ConfigSite) *UploadHandler {
28 return &UploadHandler{
29 DBPool: dbpool,
30 Cfg: cfg,
31 }
32}
33
34func (h *UploadHandler) getAuthorizedKeyFile(user *db.User) (*sendutils.VirtualFile, string, error) {
35 keys, err := h.DBPool.FindKeysForUser(user)
36 text := ""
37 var modTime time.Time
38 for _, pk := range keys {
39 text += fmt.Sprintf("%s %s\n", pk.Key, pk.Name)
40 modTime = *pk.CreatedAt
41 }
42 if err != nil {
43 return nil, "", err
44 }
45 fileInfo := &sendutils.VirtualFile{
46 FName: "authorized_keys",
47 FIsDir: false,
48 FSize: int64(len(text)),
49 FModTime: modTime,
50 }
51 return fileInfo, text, nil
52}
53
54func (h *UploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
55 return errors.New("unsupported")
56}
57
58func (h *UploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
59 logger := pssh.GetLogger(s)
60 user := pssh.GetUser(s)
61
62 if user == nil {
63 err := fmt.Errorf("could not get user from ctx")
64 logger.Error("error getting user from ctx", "err", err)
65 return nil, nil, err
66 }
67
68 cleanFilename := filepath.Base(entry.Filepath)
69
70 if cleanFilename == "" || cleanFilename == "." {
71 return nil, nil, os.ErrNotExist
72 }
73
74 if cleanFilename == "authorized_keys" {
75 fileInfo, text, err := h.getAuthorizedKeyFile(user)
76 if err != nil {
77 return nil, nil, err
78 }
79 reader := sendutils.NopReadAndReaderAtCloser(strings.NewReader(text))
80 return fileInfo, reader, nil
81 }
82
83 return nil, nil, os.ErrNotExist
84}
85
86func (h *UploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
87 var fileList []os.FileInfo
88
89 logger := pssh.GetLogger(s)
90 user := pssh.GetUser(s)
91
92 if user == nil {
93 err := fmt.Errorf("could not get user from ctx")
94 logger.Error("error getting user from ctx", "err", err)
95 return fileList, err
96 }
97
98 cleanFilename := filepath.Base(fpath)
99
100 if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
101 name := cleanFilename
102 if name == "" {
103 name = "/"
104 }
105
106 fileList = append(fileList, &sendutils.VirtualFile{
107 FName: name,
108 FIsDir: true,
109 })
110
111 flist, _, err := h.getAuthorizedKeyFile(user)
112 if err != nil {
113 return fileList, err
114 }
115 fileList = append(fileList, flist)
116 } else {
117 if cleanFilename == "authorized_keys" {
118 flist, _, err := h.getAuthorizedKeyFile(user)
119 if err != nil {
120 return fileList, err
121 }
122 fileList = append(fileList, flist)
123 }
124 }
125
126 return fileList, nil
127}
128
129func (h *UploadHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
130 return pssh.GetLogger(s)
131}
132
133func (h *UploadHandler) Validate(s *pssh.SSHServerConnSession) error {
134 var err error
135 key, err := sendutils.KeyText(s)
136 if err != nil {
137 return fmt.Errorf("key not found")
138 }
139
140 user, err := h.DBPool.FindUserForKey(s.User(), key)
141 if err != nil {
142 return err
143 }
144
145 if user.Name == "" {
146 return fmt.Errorf("must have username set")
147 }
148
149 s.Permissions().Extensions["user_id"] = user.ID
150 return nil
151}
152
153type KeyWithId struct {
154 Pk ssh.PublicKey
155 ID string
156 Comment string
157}
158
159type KeyDiffResult struct {
160 Add []KeyWithId
161 Rm []string
162 Update []KeyWithId
163}
164
165func authorizedKeysDiff(keyInUse ssh.PublicKey, curKeys []KeyWithId, nextKeys []KeyWithId) KeyDiffResult {
166 update := []KeyWithId{}
167 add := []KeyWithId{}
168 for _, nk := range nextKeys {
169 found := false
170 for _, ck := range curKeys {
171 if pssh.KeysEqual(nk.Pk, ck.Pk) {
172 found = true
173
174 // update the comment field
175 if nk.Comment != ck.Comment {
176 ck.Comment = nk.Comment
177 update = append(update, ck)
178 }
179 break
180 }
181 }
182 if !found {
183 add = append(add, nk)
184 }
185 }
186
187 rm := []string{}
188 for _, ck := range curKeys {
189 // we never want to remove the key that's in the current ssh session
190 // in an effort to avoid mistakenly removing their current key
191 if pssh.KeysEqual(ck.Pk, keyInUse) {
192 continue
193 }
194
195 found := false
196 for _, nk := range nextKeys {
197 if pssh.KeysEqual(ck.Pk, nk.Pk) {
198 found = true
199 break
200 }
201 }
202 if !found {
203 rm = append(rm, ck.ID)
204 }
205 }
206
207 return KeyDiffResult{
208 Add: add,
209 Rm: rm,
210 Update: update,
211 }
212}
213
214func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s *pssh.SSHServerConnSession) error {
215 logger.Info("processing new authorized_keys")
216 dbpool := h.DBPool
217
218 curKeysStr, err := dbpool.FindKeysForUser(user)
219 if err != nil {
220 return err
221 }
222
223 splitKeys := bytes.Split(text, []byte{'\n'})
224 nextKeys := []KeyWithId{}
225 for _, pk := range splitKeys {
226 key, comment, _, _, err := ssh.ParseAuthorizedKey(bytes.TrimSpace(pk))
227 if err != nil {
228 continue
229 }
230 nextKeys = append(nextKeys, KeyWithId{Pk: key, Comment: comment})
231 }
232
233 curKeys := []KeyWithId{}
234 for _, pk := range curKeysStr {
235 key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pk.Key))
236 if err != nil {
237 continue
238 }
239 curKeys = append(curKeys, KeyWithId{Pk: key, ID: pk.ID, Comment: pk.Name})
240 }
241
242 diff := authorizedKeysDiff(s.PublicKey(), curKeys, nextKeys)
243
244 for _, pk := range diff.Add {
245 key := utils.KeyForKeyText(pk.Pk)
246
247 fmt.Fprintf(s.Stderr(), "adding pubkey (%s)\n", key)
248 logger.Info("adding pubkey", "pubkey", key)
249
250 err = dbpool.InsertPublicKey(user.ID, key, pk.Comment, nil)
251 if err != nil {
252 fmt.Fprintf(s.Stderr(), "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
253 logger.Error("could not insert pubkey", "err", err.Error())
254 }
255 }
256
257 for _, pk := range diff.Update {
258 key := utils.KeyForKeyText(pk.Pk)
259
260 fmt.Fprintf(s.Stderr(), "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
261 logger.Info(
262 "updating pubkey with comment",
263 "pubkey", key,
264 "comment", pk.Comment,
265 )
266
267 _, err = dbpool.UpdatePublicKey(pk.ID, pk.Comment)
268 if err != nil {
269 fmt.Fprintf(s.Stderr(), "error: could not update pubkey: %s (%s)\n", err.Error(), key)
270 logger.Error("could not update pubkey", "err", err.Error(), "key", key)
271 }
272 }
273
274 if len(diff.Rm) > 0 {
275 fmt.Fprintf(s.Stderr(), "removing pubkeys: %s\n", diff.Rm)
276 logger.Info("removing pubkeys", "pubkeys", diff.Rm)
277
278 err = dbpool.RemoveKeys(diff.Rm)
279 if err != nil {
280 fmt.Fprintf(s.Stderr(), "error: could not rm pubkeys: %s\n", err.Error())
281 logger.Error("could not remove pubkey", "err", err.Error())
282 }
283 }
284
285 return nil
286}
287
288func (h *UploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
289 logger := pssh.GetLogger(s)
290 user := pssh.GetUser(s)
291
292 if user == nil {
293 err := fmt.Errorf("could not get user from ctx")
294 logger.Error("error getting user from ctx", "err", err)
295 return "", err
296 }
297
298 filename := filepath.Base(entry.Filepath)
299 logger = logger.With(
300 "user", user.Name,
301 "filename", filename,
302 )
303
304 var text []byte
305 if b, err := io.ReadAll(entry.Reader); err == nil {
306 text = b
307 }
308
309 if filename == "authorized_keys" {
310 err := h.ProcessAuthorizedKeys(text, logger, user, s)
311 if err != nil {
312 return "", err
313 }
314 } else {
315 return "", fmt.Errorf("validation error: invalid file, received %s", entry.Filepath)
316 }
317
318 return "", nil
319}