repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pico
Antonio Mika  ·  2025-03-12

file_handler.go

  1package pico
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/picosh/pico/pkg/db"
 15	"github.com/picosh/pico/pkg/pssh"
 16	sendutils "github.com/picosh/pico/pkg/send/utils"
 17	"github.com/picosh/pico/pkg/shared"
 18	"github.com/picosh/utils"
 19	"golang.org/x/crypto/ssh"
 20)
 21
 22type UploadHandler struct {
 23	DBPool db.DB
 24	Cfg    *shared.ConfigSite
 25}
 26
 27func NewUploadHandler(dbpool db.DB, cfg *shared.ConfigSite) *UploadHandler {
 28	return &UploadHandler{
 29		DBPool: dbpool,
 30		Cfg:    cfg,
 31	}
 32}
 33
 34func (h *UploadHandler) getAuthorizedKeyFile(user *db.User) (*sendutils.VirtualFile, string, error) {
 35	keys, err := h.DBPool.FindKeysForUser(user)
 36	text := ""
 37	var modTime time.Time
 38	for _, pk := range keys {
 39		text += fmt.Sprintf("%s %s\n", pk.Key, pk.Name)
 40		modTime = *pk.CreatedAt
 41	}
 42	if err != nil {
 43		return nil, "", err
 44	}
 45	fileInfo := &sendutils.VirtualFile{
 46		FName:    "authorized_keys",
 47		FIsDir:   false,
 48		FSize:    int64(len(text)),
 49		FModTime: modTime,
 50	}
 51	return fileInfo, text, nil
 52}
 53
 54func (h *UploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
 55	return errors.New("unsupported")
 56}
 57
 58func (h *UploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
 59	logger := pssh.GetLogger(s)
 60	user := pssh.GetUser(s)
 61
 62	if user == nil {
 63		err := fmt.Errorf("could not get user from ctx")
 64		logger.Error("error getting user from ctx", "err", err)
 65		return nil, nil, err
 66	}
 67
 68	cleanFilename := filepath.Base(entry.Filepath)
 69
 70	if cleanFilename == "" || cleanFilename == "." {
 71		return nil, nil, os.ErrNotExist
 72	}
 73
 74	if cleanFilename == "authorized_keys" {
 75		fileInfo, text, err := h.getAuthorizedKeyFile(user)
 76		if err != nil {
 77			return nil, nil, err
 78		}
 79		reader := sendutils.NopReadAndReaderAtCloser(strings.NewReader(text))
 80		return fileInfo, reader, nil
 81	}
 82
 83	return nil, nil, os.ErrNotExist
 84}
 85
 86func (h *UploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
 87	var fileList []os.FileInfo
 88
 89	logger := pssh.GetLogger(s)
 90	user := pssh.GetUser(s)
 91
 92	if user == nil {
 93		err := fmt.Errorf("could not get user from ctx")
 94		logger.Error("error getting user from ctx", "err", err)
 95		return fileList, err
 96	}
 97
 98	cleanFilename := filepath.Base(fpath)
 99
100	if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
101		name := cleanFilename
102		if name == "" {
103			name = "/"
104		}
105
106		fileList = append(fileList, &sendutils.VirtualFile{
107			FName:  name,
108			FIsDir: true,
109		})
110
111		flist, _, err := h.getAuthorizedKeyFile(user)
112		if err != nil {
113			return fileList, err
114		}
115		fileList = append(fileList, flist)
116	} else {
117		if cleanFilename == "authorized_keys" {
118			flist, _, err := h.getAuthorizedKeyFile(user)
119			if err != nil {
120				return fileList, err
121			}
122			fileList = append(fileList, flist)
123		}
124	}
125
126	return fileList, nil
127}
128
129func (h *UploadHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
130	return pssh.GetLogger(s)
131}
132
133func (h *UploadHandler) Validate(s *pssh.SSHServerConnSession) error {
134	var err error
135	key, err := sendutils.KeyText(s)
136	if err != nil {
137		return fmt.Errorf("key not found")
138	}
139
140	user, err := h.DBPool.FindUserForKey(s.User(), key)
141	if err != nil {
142		return err
143	}
144
145	if user.Name == "" {
146		return fmt.Errorf("must have username set")
147	}
148
149	s.Permissions().Extensions["user_id"] = user.ID
150	return nil
151}
152
153type KeyWithId struct {
154	Pk      ssh.PublicKey
155	ID      string
156	Comment string
157}
158
159type KeyDiffResult struct {
160	Add    []KeyWithId
161	Rm     []string
162	Update []KeyWithId
163}
164
165func authorizedKeysDiff(keyInUse ssh.PublicKey, curKeys []KeyWithId, nextKeys []KeyWithId) KeyDiffResult {
166	update := []KeyWithId{}
167	add := []KeyWithId{}
168	for _, nk := range nextKeys {
169		found := false
170		for _, ck := range curKeys {
171			if pssh.KeysEqual(nk.Pk, ck.Pk) {
172				found = true
173
174				// update the comment field
175				if nk.Comment != ck.Comment {
176					ck.Comment = nk.Comment
177					update = append(update, ck)
178				}
179				break
180			}
181		}
182		if !found {
183			add = append(add, nk)
184		}
185	}
186
187	rm := []string{}
188	for _, ck := range curKeys {
189		// we never want to remove the key that's in the current ssh session
190		// in an effort to avoid mistakenly removing their current key
191		if pssh.KeysEqual(ck.Pk, keyInUse) {
192			continue
193		}
194
195		found := false
196		for _, nk := range nextKeys {
197			if pssh.KeysEqual(ck.Pk, nk.Pk) {
198				found = true
199				break
200			}
201		}
202		if !found {
203			rm = append(rm, ck.ID)
204		}
205	}
206
207	return KeyDiffResult{
208		Add:    add,
209		Rm:     rm,
210		Update: update,
211	}
212}
213
214func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s *pssh.SSHServerConnSession) error {
215	logger.Info("processing new authorized_keys")
216	dbpool := h.DBPool
217
218	curKeysStr, err := dbpool.FindKeysForUser(user)
219	if err != nil {
220		return err
221	}
222
223	splitKeys := bytes.Split(text, []byte{'\n'})
224	nextKeys := []KeyWithId{}
225	for _, pk := range splitKeys {
226		key, comment, _, _, err := ssh.ParseAuthorizedKey(bytes.TrimSpace(pk))
227		if err != nil {
228			continue
229		}
230		nextKeys = append(nextKeys, KeyWithId{Pk: key, Comment: comment})
231	}
232
233	curKeys := []KeyWithId{}
234	for _, pk := range curKeysStr {
235		key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pk.Key))
236		if err != nil {
237			continue
238		}
239		curKeys = append(curKeys, KeyWithId{Pk: key, ID: pk.ID, Comment: pk.Name})
240	}
241
242	diff := authorizedKeysDiff(s.PublicKey(), curKeys, nextKeys)
243
244	for _, pk := range diff.Add {
245		key := utils.KeyForKeyText(pk.Pk)
246
247		fmt.Fprintf(s.Stderr(), "adding pubkey (%s)\n", key)
248		logger.Info("adding pubkey", "pubkey", key)
249
250		err = dbpool.InsertPublicKey(user.ID, key, pk.Comment, nil)
251		if err != nil {
252			fmt.Fprintf(s.Stderr(), "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
253			logger.Error("could not insert pubkey", "err", err.Error())
254		}
255	}
256
257	for _, pk := range diff.Update {
258		key := utils.KeyForKeyText(pk.Pk)
259
260		fmt.Fprintf(s.Stderr(), "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
261		logger.Info(
262			"updating pubkey with comment",
263			"pubkey", key,
264			"comment", pk.Comment,
265		)
266
267		_, err = dbpool.UpdatePublicKey(pk.ID, pk.Comment)
268		if err != nil {
269			fmt.Fprintf(s.Stderr(), "error: could not update pubkey: %s (%s)\n", err.Error(), key)
270			logger.Error("could not update pubkey", "err", err.Error(), "key", key)
271		}
272	}
273
274	if len(diff.Rm) > 0 {
275		fmt.Fprintf(s.Stderr(), "removing pubkeys: %s\n", diff.Rm)
276		logger.Info("removing pubkeys", "pubkeys", diff.Rm)
277
278		err = dbpool.RemoveKeys(diff.Rm)
279		if err != nil {
280			fmt.Fprintf(s.Stderr(), "error: could not rm pubkeys: %s\n", err.Error())
281			logger.Error("could not remove pubkey", "err", err.Error())
282		}
283	}
284
285	return nil
286}
287
288func (h *UploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
289	logger := pssh.GetLogger(s)
290	user := pssh.GetUser(s)
291
292	if user == nil {
293		err := fmt.Errorf("could not get user from ctx")
294		logger.Error("error getting user from ctx", "err", err)
295		return "", err
296	}
297
298	filename := filepath.Base(entry.Filepath)
299	logger = logger.With(
300		"user", user.Name,
301		"filename", filename,
302	)
303
304	var text []byte
305	if b, err := io.ReadAll(entry.Reader); err == nil {
306		text = b
307	}
308
309	if filename == "authorized_keys" {
310		err := h.ProcessAuthorizedKeys(text, logger, user, s)
311		if err != nil {
312			return "", err
313		}
314	} else {
315		return "", fmt.Errorf("validation error: invalid file, received %s", entry.Filepath)
316	}
317
318	return "", nil
319}