
pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2025-04-05

uploader.go

package pgs

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"io/fs"
	"log/slog"
	"os"
	"path"
	"path/filepath"
	"slices"
	"strings"
	"sync"
	"time"

	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
	"github.com/picosh/pico/pkg/db"
	"github.com/picosh/pico/pkg/pobj"
	sst "github.com/picosh/pico/pkg/pobj/storage"
	"github.com/picosh/pico/pkg/pssh"
	sendutils "github.com/picosh/pico/pkg/send/utils"
	"github.com/picosh/pico/pkg/shared"
	"github.com/picosh/utils"
	ignore "github.com/sabhiram/go-gitignore"
)

type ctxBucketKey struct{}
type ctxStorageSizeKey struct{}
type ctxProjectKey struct{}
type ctxDenylistKey struct{}

type DenyList struct {
	Denylist string
}

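// getDenylist returns the DenyList stored on the session context, or nil
// if one has not been set for this connection yet.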
func getDenylist(s *pssh.SSHServerConnSession) *DenyList {
	v := s.Context().Value(ctxDenylistKey{})
	if v == nil {
		return nil
	}
	denylist := s.Context().Value(ctxDenylistKey{}).(*DenyList)
	return denylist
}

func setDenylist(s *pssh.SSHServerConnSession, denylist string) {
	s.SetValue(ctxDenylistKey{}, &DenyList{Denylist: denylist})
}

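// getProject returns the project stored on the session context, or nil if
// none has been set yet.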
func getProject(s *pssh.SSHServerConnSession) *db.Project {
	v := s.Context().Value(ctxProjectKey{})
	if v == nil {
		return nil
	}
	project := s.Context().Value(ctxProjectKey{}).(*db.Project)
	return project
}

func setProject(s *pssh.SSHServerConnSession, project *db.Project) {
	s.SetValue(ctxProjectKey{}, project)
}

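// getBucket returns the asset bucket that Validate stored on the session context.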
func getBucket(s *pssh.SSHServerConnSession) (sst.Bucket, error) {
	bucket := s.Context().Value(ctxBucketKey{}).(sst.Bucket)
	if bucket.Name == "" {
		return bucket, fmt.Errorf("bucket not set on `ssh.Context()` for connection")
	}
	return bucket, nil
}

func getStorageSize(s *pssh.SSHServerConnSession) uint64 {
	return s.Context().Value(ctxStorageSizeKey{}).(uint64)
}

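// incrementStorageSize updates the running storage-size total cached on the
// session context after a file write and returns the new value.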
func incrementStorageSize(s *pssh.SSHServerConnSession, fileSize int64) uint64 {
	curSize := getStorageSize(s)
	var nextStorageSize uint64
	if fileSize < 0 {
		nextStorageSize = curSize - uint64(fileSize)
	} else {
		nextStorageSize = curSize + uint64(fileSize)
	}
	s.SetValue(ctxStorageSizeKey{}, nextStorageSize)
	return nextStorageSize
}

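// shouldIgnoreFile reports whether fp matches any of the gitignore-style
// patterns in ignoreStr (typically the contents of a project's _pgs_ignore file).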
func shouldIgnoreFile(fp, ignoreStr string) bool {
	object := ignore.CompileIgnoreLines(strings.Split(ignoreStr, "\n")...)
	return object.MatchesPath(fp)
}

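// FileData bundles everything needed to validate and store a single uploaded file.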
type FileData struct {
	*sendutils.FileEntry
	User     *db.User
	Bucket   sst.Bucket
	Project  *db.Project
	DenyList string
}

type UploadAssetHandler struct {
	Cfg                *PgsConfig
	CacheClearingQueue chan string
}

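// NewUploadAssetHandler constructs the upload handler and starts the
// background cache-clearing worker.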
func NewUploadAssetHandler(cfg *PgsConfig, ch chan string, ctx context.Context) *UploadAssetHandler {
	go runCacheQueue(cfg, ctx)
	return &UploadAssetHandler{
		Cfg:                cfg,
		CacheClearingQueue: ch,
	}
}

func (h *UploadAssetHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
	return pssh.GetLogger(s)
}

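// Read fetches a single object from the user's asset bucket and returns its
// file info along with a reader for the contents.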
func (h *UploadAssetHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
	logger := pssh.GetLogger(s)
	user := pssh.GetUser(s)

	if user == nil {
		err := fmt.Errorf("could not get user from ctx")
		logger.Error("error getting user from ctx", "err", err)
		return nil, nil, err
	}

	fileInfo := &sendutils.VirtualFile{
		FName:    filepath.Base(entry.Filepath),
		FIsDir:   false,
		FSize:    entry.Size,
		FModTime: time.Unix(entry.Mtime, 0),
	}

	bucket, err := h.Cfg.Storage.GetBucket(shared.GetAssetBucketName(user.ID))
	if err != nil {
		return nil, nil, err
	}

	fname := shared.GetAssetFileName(entry)
	contents, info, err := h.Cfg.Storage.GetObject(bucket, fname)
	if err != nil {
		return nil, nil, err
	}

	fileInfo.FSize = info.Size
	fileInfo.FModTime = info.LastModified

	reader := pobj.NewAllReaderAt(contents)

	return fileInfo, reader, nil
}

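// List returns the objects under fpath in the user's asset bucket. The root
// path is reported as a single virtual directory.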
func (h *UploadAssetHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
	var fileList []os.FileInfo

	logger := pssh.GetLogger(s)
	user := pssh.GetUser(s)

	if user == nil {
		err := fmt.Errorf("could not get user from ctx")
		logger.Error("error getting user from ctx", "err", err)
		return fileList, err
	}

	cleanFilename := fpath

	bucketName := shared.GetAssetBucketName(user.ID)
	bucket, err := h.Cfg.Storage.GetBucket(bucketName)
	if err != nil {
		return fileList, err
	}

	if cleanFilename == "" || cleanFilename == "." {
		name := cleanFilename
		if name == "" {
			name = "/"
		}

		info := &sendutils.VirtualFile{
			FName:  name,
			FIsDir: true,
		}

		fileList = append(fileList, info)
	} else {
		if cleanFilename != "/" && isDir {
			cleanFilename += "/"
		}

		foundList, err := h.Cfg.Storage.ListObjects(bucket, cleanFilename, recursive)
		if err != nil {
			return fileList, err
		}

		fileList = append(fileList, foundList...)
	}

	return fileList, nil
}

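// Validate prepares the session for uploads: it upserts the user's asset
// bucket, stores the bucket and its current size on the session context,
// and logs the upload attempt.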
func (h *UploadAssetHandler) Validate(s *pssh.SSHServerConnSession) error {
	logger := pssh.GetLogger(s)
	user := pssh.GetUser(s)

	if user == nil {
		err := fmt.Errorf("could not get user from ctx")
		logger.Error("error getting user from ctx", "err", err)
		return err
	}

	assetBucket := shared.GetAssetBucketName(user.ID)
	bucket, err := h.Cfg.Storage.UpsertBucket(assetBucket)
	if err != nil {
		return err
	}

	s.SetValue(ctxBucketKey{}, bucket)

	totalStorageSize, err := h.Cfg.Storage.GetBucketQuota(bucket)
	if err != nil {
		return err
	}

	s.SetValue(ctxStorageSizeKey{}, totalStorageSize)

	logger.Info(
		"bucket size",
		"user", user.Name,
		"bytes", totalStorageSize,
	)

	logger.Info(
		"attempting to upload files",
		"user", user.Name,
		"txtPrefix", h.Cfg.TxtPrefix,
	)

	return nil
}

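// findDenylist reads the project's _pgs_ignore file from the asset bucket
// and returns its contents.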
func (h *UploadAssetHandler) findDenylist(bucket sst.Bucket, project *db.Project, logger *slog.Logger) (string, error) {
	fp, _, err := h.Cfg.Storage.GetObject(bucket, filepath.Join(project.ProjectDir, "_pgs_ignore"))
	if err != nil {
		return "", fmt.Errorf("_pgs_ignore not found")
	}
	defer fp.Close()

	buf := new(strings.Builder)
	_, err = io.Copy(buf, fp)
	if err != nil {
		logger.Error("io copy", "err", err.Error())
		return "", err
	}

	str := buf.String()
	return str, nil
}

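// findPlusFF resolves the user's "plus" feature flag, falling back to the
// configured free-tier limits when the user has no flag, and normalizes the
// storage, file, and special-file maximums.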
func findPlusFF(dbpool pgsdb.PgsDB, cfg *PgsConfig, userID string) *db.FeatureFlag {
	ff, _ := dbpool.FindFeature(userID, "plus")
	// we have free tiers so users might not have a feature flag
	// in which case we set sane defaults
	if ff == nil {
		ff = db.NewFeatureFlag(
			userID,
			"plus",
			cfg.MaxSize,
			cfg.MaxAssetSize,
			cfg.MaxSpecialFileSize,
		)
	}
	// this is jank
	ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize)
	ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize)
	ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize)
	return ff
}

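// Write stores an uploaded file in the user's asset bucket. It resolves the
// target project, enforces the denylist and storage quotas, uploads the
// object, and queues a cache purge for the project's site.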
func (h *UploadAssetHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
	logger := pssh.GetLogger(s)
	user := pssh.GetUser(s)

	if user == nil {
		err := fmt.Errorf("could not get user from ctx")
		logger.Error("error getting user from ctx", "err", err)
		return "", err
	}

	if entry.Mode.IsDir() && strings.Count(entry.Filepath, "/") == 1 {
		entry.Filepath = strings.TrimPrefix(entry.Filepath, "/")
	}

	logger = logger.With(
		"file", entry.Filepath,
		"size", entry.Size,
	)

	bucket, err := getBucket(s)
	if err != nil {
		logger.Error("could not find bucket in ctx", "err", err.Error())
		return "", err
	}

	project := getProject(s)
	projectName := shared.GetProjectName(entry)
	logger = logger.With("project", projectName)

	// find, create, or update the project if we haven't already done it.
	// we also need to check that the project stored in ctx is the same project
	// being uploaded, since users can keep an ssh connection alive via sftp
	// and create many projects in a single session
	if project == nil || project.Name != projectName {
		project, err = h.Cfg.DB.UpsertProject(user.ID, projectName, projectName)
		if err != nil {
			logger.Error("upsert project", "err", err.Error())
			return "", err
		}
		setProject(s, project)
	}

	if project.Blocked != "" {
		msg := "project has been blocked and cannot upload files: %s"
		return "", fmt.Errorf(msg, project.Blocked)
	}

	if entry.Mode.IsDir() {
		_, _, err := h.Cfg.Storage.PutObject(
			bucket,
			path.Join(shared.GetAssetFileName(entry), "._pico_keep_dir"),
			bytes.NewReader([]byte{}),
			entry,
		)
		return "", err
	}

	featureFlag := findPlusFF(h.Cfg.DB, h.Cfg, user.ID)
	// calculate the filesize difference between the same file already
	// stored and the updated file being uploaded
	assetFilename := shared.GetAssetFileName(entry)
	obj, info, _ := h.Cfg.Storage.GetObject(bucket, assetFilename)
	var curFileSize int64
	if info != nil {
		curFileSize = info.Size
	}
	if obj != nil {
		defer obj.Close()
	}

	denylist := getDenylist(s)
	if denylist == nil {
		dlist, err := h.findDenylist(bucket, project, logger)
		if err != nil {
			logger.Info("failed to get denylist, setting default (.*)", "err", err.Error())
			dlist = ".*"
		}
		setDenylist(s, dlist)
		denylist = &DenyList{Denylist: dlist}
	}

	data := &FileData{
		FileEntry: entry,
		User:      user,
		Bucket:    bucket,
		DenyList:  denylist.Denylist,
		Project:   project,
	}

	valid, err := h.validateAsset(data)
	if !valid {
		return "", err
	}

	// SFTP does not report file size, so the more performant way to
	// check filesize constraints is to try to upload the file to s3
	// with a specialized reader that raises an error if the filesize limit
	// has been reached
	storageMax := featureFlag.Data.StorageMax
	fileMax := featureFlag.Data.FileMax
	curStorageSize := getStorageSize(s)
	remaining := int64(storageMax) - int64(curStorageSize)
	sizeRemaining := min(remaining+curFileSize, fileMax)
	if sizeRemaining <= 0 {
		fmt.Fprintln(s.Stderr(), "storage quota reached")
		fmt.Fprintf(s.Stderr(), "\r")
		_ = s.Exit(1)
		_ = s.Close()
		return "", fmt.Errorf("storage quota reached")
	}
	logger = logger.With(
		"storageMax", storageMax,
		"currentStorageMax", curStorageSize,
		"fileMax", fileMax,
		"sizeRemaining", sizeRemaining,
	)

	specialFileMax := featureFlag.Data.SpecialFileMax
	if isSpecialFile(entry.Filepath) {
		sizeRemaining = min(sizeRemaining, specialFileMax)
	}

	fsize, err := h.writeAsset(
		s,
		utils.NewMaxBytesReader(data.Reader, int64(sizeRemaining)),
		data,
	)
	if err != nil {
		logger.Error("could not write asset", "err", err.Error())
		cerr := fmt.Errorf(
			"%s: storage size %.2fmb, storage max %.2fmb, file max %.2fmb, special file max %.4fmb",
			err,
			utils.BytesToMB(int(curStorageSize)),
			utils.BytesToMB(int(storageMax)),
			utils.BytesToMB(int(fileMax)),
			utils.BytesToMB(int(specialFileMax)),
		)
		return "", cerr
	}

	deltaFileSize := curFileSize - fsize
	nextStorageSize := incrementStorageSize(s, deltaFileSize)

	url := h.Cfg.AssetURL(
		user.Name,
		projectName,
		strings.Replace(data.Filepath, "/"+projectName+"/", "", 1),
	)

	maxSize := int(featureFlag.Data.StorageMax)
	str := fmt.Sprintf(
		"%s (space: %.2f/%.2fGB, %.2f%%)",
		url,
		utils.BytesToGB(int(nextStorageSize)),
		utils.BytesToGB(maxSize),
		(float32(nextStorageSize)/float32(maxSize))*100,
	)

	surrogate := getSurrogateKey(user.Name, projectName)
	h.Cfg.CacheClearingQueue <- surrogate

	return str, err
}

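// isSpecialFile reports whether the file is one of the special files
// (_headers, _redirects, _pgs_ignore) that receive dedicated handling.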
func isSpecialFile(entry string) bool {
	fname := filepath.Base(entry)
	return fname == "_headers" || fname == "_redirects" || fname == "_pgs_ignore"
}

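// Delete removes a file from the user's asset bucket. If it was the last
// object in its directory, a ._pico_keep_dir placeholder is written so the
// directory survives, and a cache purge is queued for the project's site.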
func (h *UploadAssetHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
	logger := pssh.GetLogger(s)
	user := pssh.GetUser(s)

	if user == nil {
		err := fmt.Errorf("could not get user from ctx")
		logger.Error("error getting user from ctx", "err", err)
		return err
	}

	if entry.Mode.IsDir() && strings.Count(entry.Filepath, "/") == 1 {
		entry.Filepath = strings.TrimPrefix(entry.Filepath, "/")
	}

	assetFilepath := shared.GetAssetFileName(entry)

	logger = logger.With(
		"file", assetFilepath,
	)

	bucket, err := getBucket(s)
	if err != nil {
		logger.Error("could not find bucket in ctx", "err", err.Error())
		return err
	}

	projectName := shared.GetProjectName(entry)
	logger = logger.With("project", projectName)

	if assetFilepath == filepath.Join("/", projectName, "._pico_keep_dir") {
		return os.ErrPermission
	}

	logger.Info("deleting file")

	pathDir := filepath.Dir(assetFilepath)
	fileName := filepath.Base(assetFilepath)

	sibs, err := h.Cfg.Storage.ListObjects(bucket, pathDir+"/", false)
	if err != nil {
		return err
	}

	sibs = slices.DeleteFunc(sibs, func(sib fs.FileInfo) bool {
		return sib.Name() == fileName
	})

	if len(sibs) == 0 {
		_, _, err := h.Cfg.Storage.PutObject(
			bucket,
			filepath.Join(pathDir, "._pico_keep_dir"),
			bytes.NewReader([]byte{}),
			entry,
		)
		if err != nil {
			return err
		}
	}
	err = h.Cfg.Storage.DeleteObject(bucket, assetFilepath)

	surrogate := getSurrogateKey(user.Name, projectName)
	h.Cfg.CacheClearingQueue <- surrogate

	if err != nil {
		return err
	}

	return err
}

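// validateAsset rejects uploads outside a named project as well as files
// matched by the project's denylist; special routing files are always allowed.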
func (h *UploadAssetHandler) validateAsset(data *FileData) (bool, error) {
	fname := filepath.Base(data.Filepath)

	projectName := shared.GetProjectName(data.FileEntry)
	if projectName == "" || projectName == "/" || projectName == "." {
		return false, fmt.Errorf("ERROR: invalid project name, you must copy files to a non-root folder (e.g. pgs.sh:/project-name)")
	}

	// special files we use for custom routing
	if isSpecialFile(fname) {
		return true, nil
	}

	fpath := strings.Replace(data.Filepath, "/"+projectName, "", 1)
	if shouldIgnoreFile(fpath, data.DenyList) {
		err := fmt.Errorf(
			"ERROR: (%s) file rejected, https://pico.sh/pgs#-pgs-ignore",
			data.Filepath,
		)
		return false, err
	}

	return true, nil
}

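// writeAsset streams the file contents into object storage and returns the
// number of bytes written.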
func (h *UploadAssetHandler) writeAsset(s *pssh.SSHServerConnSession, reader io.Reader, data *FileData) (int64, error) {
	assetFilepath := shared.GetAssetFileName(data.FileEntry)

	logger := h.GetLogger(s)
	logger.Info(
		"uploading file to bucket",
		"bucket", data.Bucket.Name,
		"filename", assetFilepath,
	)

	_, fsize, err := h.Cfg.Storage.PutObject(
		data.Bucket,
		assetFilepath,
		reader,
		data.FileEntry,
	)
	return fsize, err
}

// runCacheQueue processes requests to purge the cache for a single site.
// One message arrives per file that is written/deleted during uploads.
// Repeated messages for the same site are grouped so that we only flush once
// per site per 5 seconds.
func runCacheQueue(cfg *PgsConfig, ctx context.Context) {
	send := createPubCacheDrain(ctx, cfg.Logger)
	var pendingFlushes sync.Map
	tick := time.Tick(5 * time.Second)
	for {
		select {
		case host := <-cfg.CacheClearingQueue:
			pendingFlushes.Store(host, host)
		case <-tick:
			go func() {
				pendingFlushes.Range(func(key, value any) bool {
					pendingFlushes.Delete(key)
					err := purgeCache(cfg, send, key.(string))
					if err != nil {
						cfg.Logger.Error("failed to clear cache", "err", err.Error())
					}
					return true
				})
			}()
		}
	}
}