Eric Bower
·
2026-01-25
file_handler.go
1package pico
2
3import (
4 "bytes"
5 "errors"
6 "fmt"
7 "io"
8 "log/slog"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/picosh/pico/pkg/db"
15 "github.com/picosh/pico/pkg/pssh"
16 sendutils "github.com/picosh/pico/pkg/send/utils"
17 "github.com/picosh/pico/pkg/shared"
18 "golang.org/x/crypto/ssh"
19)
20
21type UploadHandler struct {
22 DBPool db.DB
23 Cfg *shared.ConfigSite
24}
25
26func NewUploadHandler(dbpool db.DB, cfg *shared.ConfigSite) *UploadHandler {
27 return &UploadHandler{
28 DBPool: dbpool,
29 Cfg: cfg,
30 }
31}
32
33func (h *UploadHandler) getAuthorizedKeyFile(user *db.User) (*sendutils.VirtualFile, string, error) {
34 keys, err := h.DBPool.FindKeysByUser(user)
35 text := ""
36 var modTime time.Time
37 for _, pk := range keys {
38 text += fmt.Sprintf("%s %s\n", pk.Key, pk.Name)
39 modTime = *pk.CreatedAt
40 }
41 if err != nil {
42 return nil, "", err
43 }
44 fileInfo := &sendutils.VirtualFile{
45 FName: "authorized_keys",
46 FIsDir: false,
47 FSize: int64(len(text)),
48 FModTime: modTime,
49 }
50 return fileInfo, text, nil
51}
52
53func (h *UploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
54 return errors.New("unsupported")
55}
56
57func (h *UploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
58 logger := pssh.GetLogger(s)
59 user := pssh.GetUser(s)
60
61 if user == nil {
62 err := fmt.Errorf("could not get user from ctx")
63 logger.Error("error getting user from ctx", "err", err)
64 return nil, nil, err
65 }
66
67 cleanFilename := filepath.Base(entry.Filepath)
68
69 if cleanFilename == "" || cleanFilename == "." {
70 return nil, nil, os.ErrNotExist
71 }
72
73 if cleanFilename == "authorized_keys" {
74 fileInfo, text, err := h.getAuthorizedKeyFile(user)
75 if err != nil {
76 return nil, nil, err
77 }
78 reader := sendutils.NopReadAndReaderAtCloser(strings.NewReader(text))
79 return fileInfo, reader, nil
80 }
81
82 return nil, nil, os.ErrNotExist
83}
84
85func (h *UploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
86 var fileList []os.FileInfo
87
88 logger := pssh.GetLogger(s)
89 user := pssh.GetUser(s)
90
91 if user == nil {
92 err := fmt.Errorf("could not get user from ctx")
93 logger.Error("error getting user from ctx", "err", err)
94 return fileList, err
95 }
96
97 cleanFilename := filepath.Base(fpath)
98
99 if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
100 name := cleanFilename
101 if name == "" {
102 name = "/"
103 }
104
105 fileList = append(fileList, &sendutils.VirtualFile{
106 FName: name,
107 FIsDir: true,
108 })
109
110 flist, _, err := h.getAuthorizedKeyFile(user)
111 if err != nil {
112 return fileList, err
113 }
114 fileList = append(fileList, flist)
115 } else {
116 if cleanFilename == "authorized_keys" {
117 flist, _, err := h.getAuthorizedKeyFile(user)
118 if err != nil {
119 return fileList, err
120 }
121 fileList = append(fileList, flist)
122 }
123 }
124
125 return fileList, nil
126}
127
128func (h *UploadHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
129 return pssh.GetLogger(s)
130}
131
132func (h *UploadHandler) Validate(s *pssh.SSHServerConnSession) error {
133 var err error
134 key, err := sendutils.KeyText(s)
135 if err != nil {
136 return fmt.Errorf("key not found")
137 }
138
139 user, err := h.DBPool.FindUserByKey(s.User(), key)
140 if err != nil {
141 return err
142 }
143
144 if user.Name == "" {
145 return fmt.Errorf("must have username set")
146 }
147
148 s.Permissions().Extensions["user_id"] = user.ID
149 return nil
150}
151
152type KeyWithId struct {
153 Pk ssh.PublicKey
154 ID string
155 Comment string
156}
157
158type KeyDiffResult struct {
159 Add []KeyWithId
160 Rm []string
161 Update []KeyWithId
162}
163
164func authorizedKeysDiff(keyInUse ssh.PublicKey, curKeys []KeyWithId, nextKeys []KeyWithId) KeyDiffResult {
165 update := []KeyWithId{}
166 add := []KeyWithId{}
167 for _, nk := range nextKeys {
168 found := false
169 for _, ck := range curKeys {
170 if pssh.KeysEqual(nk.Pk, ck.Pk) {
171 found = true
172
173 // update the comment field
174 if nk.Comment != ck.Comment {
175 ck.Comment = nk.Comment
176 update = append(update, ck)
177 }
178 break
179 }
180 }
181 if !found {
182 add = append(add, nk)
183 }
184 }
185
186 rm := []string{}
187 for _, ck := range curKeys {
188 // we never want to remove the key that's in the current ssh session
189 // in an effort to avoid mistakenly removing their current key
190 if pssh.KeysEqual(ck.Pk, keyInUse) {
191 continue
192 }
193
194 found := false
195 for _, nk := range nextKeys {
196 if pssh.KeysEqual(ck.Pk, nk.Pk) {
197 found = true
198 break
199 }
200 }
201 if !found {
202 rm = append(rm, ck.ID)
203 }
204 }
205
206 return KeyDiffResult{
207 Add: add,
208 Rm: rm,
209 Update: update,
210 }
211}
212
213func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s *pssh.SSHServerConnSession) error {
214 logger.Info("processing new authorized_keys")
215 dbpool := h.DBPool
216
217 curKeysStr, err := dbpool.FindKeysByUser(user)
218 if err != nil {
219 return err
220 }
221
222 splitKeys := bytes.Split(text, []byte{'\n'})
223 nextKeys := []KeyWithId{}
224 for _, pk := range splitKeys {
225 key, comment, _, _, err := ssh.ParseAuthorizedKey(bytes.TrimSpace(pk))
226 if err != nil {
227 continue
228 }
229 nextKeys = append(nextKeys, KeyWithId{Pk: key, Comment: comment})
230 }
231
232 curKeys := []KeyWithId{}
233 for _, pk := range curKeysStr {
234 key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pk.Key))
235 if err != nil {
236 continue
237 }
238 curKeys = append(curKeys, KeyWithId{Pk: key, ID: pk.ID, Comment: pk.Name})
239 }
240
241 diff := authorizedKeysDiff(s.PublicKey(), curKeys, nextKeys)
242
243 for _, pk := range diff.Add {
244 key := shared.KeyForKeyText(pk.Pk)
245
246 _, _ = fmt.Fprintf(s.Stderr(), "adding pubkey (%s)\n", key)
247 logger.Info("adding pubkey", "pubkey", key)
248
249 err = dbpool.InsertPublicKey(user.ID, key, pk.Comment)
250 if err != nil {
251 _, _ = fmt.Fprintf(s.Stderr(), "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
252 logger.Error("could not insert pubkey", "err", err.Error())
253 }
254 }
255
256 for _, pk := range diff.Update {
257 key := shared.KeyForKeyText(pk.Pk)
258
259 _, _ = fmt.Fprintf(s.Stderr(), "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
260 logger.Info(
261 "updating pubkey with comment",
262 "pubkey", key,
263 "comment", pk.Comment,
264 )
265
266 _, err = dbpool.UpdatePublicKey(pk.ID, pk.Comment)
267 if err != nil {
268 _, _ = fmt.Fprintf(s.Stderr(), "error: could not update pubkey: %s (%s)\n", err.Error(), key)
269 logger.Error("could not update pubkey", "err", err.Error(), "key", key)
270 }
271 }
272
273 if len(diff.Rm) > 0 {
274 _, _ = fmt.Fprintf(s.Stderr(), "removing pubkeys: %s\n", diff.Rm)
275 logger.Info("removing pubkeys", "pubkeys", diff.Rm)
276
277 err = dbpool.RemoveKeys(diff.Rm)
278 if err != nil {
279 _, _ = fmt.Fprintf(s.Stderr(), "error: could not rm pubkeys: %s\n", err.Error())
280 logger.Error("could not remove pubkey", "err", err.Error())
281 }
282 }
283
284 return nil
285}
286
287func (h *UploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
288 logger := pssh.GetLogger(s)
289 user := pssh.GetUser(s)
290
291 if entry == nil || entry.Reader == nil {
292 return "", fmt.Errorf("file entry is nil, no file to update")
293 }
294
295 if user == nil {
296 err := fmt.Errorf("could not get user from ctx")
297 logger.Error("error getting user from ctx", "err", err)
298 return "", err
299 }
300
301 filename := filepath.Base(entry.Filepath)
302 logger = logger.With(
303 "user", user.Name,
304 "filename", filename,
305 )
306
307 var text []byte
308 if b, err := io.ReadAll(entry.Reader); err == nil {
309 text = b
310 }
311
312 if filename == "authorized_keys" {
313 err := h.ProcessAuthorizedKeys(text, logger, user, s)
314 if err != nil {
315 return "", err
316 }
317 } else {
318 return "", fmt.Errorf("validation error: invalid file, received %s", entry.Filepath)
319 }
320
321 return "", nil
322}