repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pico
Eric Bower  ·  2026-01-25

file_handler.go

  1package pico
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/picosh/pico/pkg/db"
 15	"github.com/picosh/pico/pkg/pssh"
 16	sendutils "github.com/picosh/pico/pkg/send/utils"
 17	"github.com/picosh/pico/pkg/shared"
 18	"golang.org/x/crypto/ssh"
 19)
 20
 21type UploadHandler struct {
 22	DBPool db.DB
 23	Cfg    *shared.ConfigSite
 24}
 25
 26func NewUploadHandler(dbpool db.DB, cfg *shared.ConfigSite) *UploadHandler {
 27	return &UploadHandler{
 28		DBPool: dbpool,
 29		Cfg:    cfg,
 30	}
 31}
 32
 33func (h *UploadHandler) getAuthorizedKeyFile(user *db.User) (*sendutils.VirtualFile, string, error) {
 34	keys, err := h.DBPool.FindKeysByUser(user)
 35	text := ""
 36	var modTime time.Time
 37	for _, pk := range keys {
 38		text += fmt.Sprintf("%s %s\n", pk.Key, pk.Name)
 39		modTime = *pk.CreatedAt
 40	}
 41	if err != nil {
 42		return nil, "", err
 43	}
 44	fileInfo := &sendutils.VirtualFile{
 45		FName:    "authorized_keys",
 46		FIsDir:   false,
 47		FSize:    int64(len(text)),
 48		FModTime: modTime,
 49	}
 50	return fileInfo, text, nil
 51}
 52
 53func (h *UploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
 54	return errors.New("unsupported")
 55}
 56
 57func (h *UploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
 58	logger := pssh.GetLogger(s)
 59	user := pssh.GetUser(s)
 60
 61	if user == nil {
 62		err := fmt.Errorf("could not get user from ctx")
 63		logger.Error("error getting user from ctx", "err", err)
 64		return nil, nil, err
 65	}
 66
 67	cleanFilename := filepath.Base(entry.Filepath)
 68
 69	if cleanFilename == "" || cleanFilename == "." {
 70		return nil, nil, os.ErrNotExist
 71	}
 72
 73	if cleanFilename == "authorized_keys" {
 74		fileInfo, text, err := h.getAuthorizedKeyFile(user)
 75		if err != nil {
 76			return nil, nil, err
 77		}
 78		reader := sendutils.NopReadAndReaderAtCloser(strings.NewReader(text))
 79		return fileInfo, reader, nil
 80	}
 81
 82	return nil, nil, os.ErrNotExist
 83}
 84
 85func (h *UploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
 86	var fileList []os.FileInfo
 87
 88	logger := pssh.GetLogger(s)
 89	user := pssh.GetUser(s)
 90
 91	if user == nil {
 92		err := fmt.Errorf("could not get user from ctx")
 93		logger.Error("error getting user from ctx", "err", err)
 94		return fileList, err
 95	}
 96
 97	cleanFilename := filepath.Base(fpath)
 98
 99	if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
100		name := cleanFilename
101		if name == "" {
102			name = "/"
103		}
104
105		fileList = append(fileList, &sendutils.VirtualFile{
106			FName:  name,
107			FIsDir: true,
108		})
109
110		flist, _, err := h.getAuthorizedKeyFile(user)
111		if err != nil {
112			return fileList, err
113		}
114		fileList = append(fileList, flist)
115	} else {
116		if cleanFilename == "authorized_keys" {
117			flist, _, err := h.getAuthorizedKeyFile(user)
118			if err != nil {
119				return fileList, err
120			}
121			fileList = append(fileList, flist)
122		}
123	}
124
125	return fileList, nil
126}
127
128func (h *UploadHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
129	return pssh.GetLogger(s)
130}
131
132func (h *UploadHandler) Validate(s *pssh.SSHServerConnSession) error {
133	var err error
134	key, err := sendutils.KeyText(s)
135	if err != nil {
136		return fmt.Errorf("key not found")
137	}
138
139	user, err := h.DBPool.FindUserByKey(s.User(), key)
140	if err != nil {
141		return err
142	}
143
144	if user.Name == "" {
145		return fmt.Errorf("must have username set")
146	}
147
148	s.Permissions().Extensions["user_id"] = user.ID
149	return nil
150}
151
152type KeyWithId struct {
153	Pk      ssh.PublicKey
154	ID      string
155	Comment string
156}
157
158type KeyDiffResult struct {
159	Add    []KeyWithId
160	Rm     []string
161	Update []KeyWithId
162}
163
164func authorizedKeysDiff(keyInUse ssh.PublicKey, curKeys []KeyWithId, nextKeys []KeyWithId) KeyDiffResult {
165	update := []KeyWithId{}
166	add := []KeyWithId{}
167	for _, nk := range nextKeys {
168		found := false
169		for _, ck := range curKeys {
170			if pssh.KeysEqual(nk.Pk, ck.Pk) {
171				found = true
172
173				// update the comment field
174				if nk.Comment != ck.Comment {
175					ck.Comment = nk.Comment
176					update = append(update, ck)
177				}
178				break
179			}
180		}
181		if !found {
182			add = append(add, nk)
183		}
184	}
185
186	rm := []string{}
187	for _, ck := range curKeys {
188		// we never want to remove the key that's in the current ssh session
189		// in an effort to avoid mistakenly removing their current key
190		if pssh.KeysEqual(ck.Pk, keyInUse) {
191			continue
192		}
193
194		found := false
195		for _, nk := range nextKeys {
196			if pssh.KeysEqual(ck.Pk, nk.Pk) {
197				found = true
198				break
199			}
200		}
201		if !found {
202			rm = append(rm, ck.ID)
203		}
204	}
205
206	return KeyDiffResult{
207		Add:    add,
208		Rm:     rm,
209		Update: update,
210	}
211}
212
213func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s *pssh.SSHServerConnSession) error {
214	logger.Info("processing new authorized_keys")
215	dbpool := h.DBPool
216
217	curKeysStr, err := dbpool.FindKeysByUser(user)
218	if err != nil {
219		return err
220	}
221
222	splitKeys := bytes.Split(text, []byte{'\n'})
223	nextKeys := []KeyWithId{}
224	for _, pk := range splitKeys {
225		key, comment, _, _, err := ssh.ParseAuthorizedKey(bytes.TrimSpace(pk))
226		if err != nil {
227			continue
228		}
229		nextKeys = append(nextKeys, KeyWithId{Pk: key, Comment: comment})
230	}
231
232	curKeys := []KeyWithId{}
233	for _, pk := range curKeysStr {
234		key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pk.Key))
235		if err != nil {
236			continue
237		}
238		curKeys = append(curKeys, KeyWithId{Pk: key, ID: pk.ID, Comment: pk.Name})
239	}
240
241	diff := authorizedKeysDiff(s.PublicKey(), curKeys, nextKeys)
242
243	for _, pk := range diff.Add {
244		key := shared.KeyForKeyText(pk.Pk)
245
246		_, _ = fmt.Fprintf(s.Stderr(), "adding pubkey (%s)\n", key)
247		logger.Info("adding pubkey", "pubkey", key)
248
249		err = dbpool.InsertPublicKey(user.ID, key, pk.Comment)
250		if err != nil {
251			_, _ = fmt.Fprintf(s.Stderr(), "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
252			logger.Error("could not insert pubkey", "err", err.Error())
253		}
254	}
255
256	for _, pk := range diff.Update {
257		key := shared.KeyForKeyText(pk.Pk)
258
259		_, _ = fmt.Fprintf(s.Stderr(), "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
260		logger.Info(
261			"updating pubkey with comment",
262			"pubkey", key,
263			"comment", pk.Comment,
264		)
265
266		_, err = dbpool.UpdatePublicKey(pk.ID, pk.Comment)
267		if err != nil {
268			_, _ = fmt.Fprintf(s.Stderr(), "error: could not update pubkey: %s (%s)\n", err.Error(), key)
269			logger.Error("could not update pubkey", "err", err.Error(), "key", key)
270		}
271	}
272
273	if len(diff.Rm) > 0 {
274		_, _ = fmt.Fprintf(s.Stderr(), "removing pubkeys: %s\n", diff.Rm)
275		logger.Info("removing pubkeys", "pubkeys", diff.Rm)
276
277		err = dbpool.RemoveKeys(diff.Rm)
278		if err != nil {
279			_, _ = fmt.Fprintf(s.Stderr(), "error: could not rm pubkeys: %s\n", err.Error())
280			logger.Error("could not remove pubkey", "err", err.Error())
281		}
282	}
283
284	return nil
285}
286
287func (h *UploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
288	logger := pssh.GetLogger(s)
289	user := pssh.GetUser(s)
290
291	if entry == nil || entry.Reader == nil {
292		return "", fmt.Errorf("file entry is nil, no file to update")
293	}
294
295	if user == nil {
296		err := fmt.Errorf("could not get user from ctx")
297		logger.Error("error getting user from ctx", "err", err)
298		return "", err
299	}
300
301	filename := filepath.Base(entry.Filepath)
302	logger = logger.With(
303		"user", user.Name,
304		"filename", filename,
305	)
306
307	var text []byte
308	if b, err := io.ReadAll(entry.Reader); err == nil {
309		text = b
310	}
311
312	if filename == "authorized_keys" {
313		err := h.ProcessAuthorizedKeys(text, logger, user, s)
314		if err != nil {
315			return "", err
316		}
317	} else {
318		return "", fmt.Errorf("validation error: invalid file, received %s", entry.Filepath)
319	}
320
321	return "", nil
322}