repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / filehandlers
Eric Bower  ·  2026-01-25

post_handler.go

  1package filehandlers
  2
  3import (
  4	"encoding/binary"
  5	"fmt"
  6	"io"
  7	"net/http"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"github.com/picosh/pico/pkg/db"
 14	"github.com/picosh/pico/pkg/pssh"
 15	sendutils "github.com/picosh/pico/pkg/send/utils"
 16	"github.com/picosh/pico/pkg/shared"
 17)
 18
// PostMetaData is the state passed to ScpFileHooks while an upload is being
// written. The embedded *db.Post holds the values ScpUploadHandler.Write will
// persist for the uploaded file.
type PostMetaData struct {
	*db.Post
	Cur       *db.Post             // stored post with the same filename; nil when the upload is new
	Tags      []string             // tags passed to ReplaceTagsByPost after the post is saved
	User      *db.User             // user performing the upload
	FileEntry *sendutils.FileEntry // file entry being written
	Aliases   []string             // aliases passed to ReplaceAliasesByPost after the post is saved
}
 27
// ScpFileHooks are the per-upload callbacks ScpUploadHandler.Write invokes to
// validate an uploaded file and populate the metadata that gets persisted.
type ScpFileHooks interface {
	// FileValidate is called before any database lookup; a false result
	// aborts the upload and the returned error is reported to the caller.
	FileValidate(s *pssh.SSHServerConnSession, data *PostMetaData) (bool, error)
	// FileMeta is called after any existing post has been loaded into
	// data.Cur. Fields it sets on data (e.g. Title, Tags, Aliases) are what
	// Write stores; a non-nil error aborts the upload.
	FileMeta(s *pssh.SSHServerConnSession, data *PostMetaData) error
}
 32
// ScpUploadHandler lists, reads, writes and deletes posts for the session's
// user within a single space (Cfg.Space).
type ScpUploadHandler struct {
	DBPool db.DB              // post storage
	Cfg    *shared.ConfigSite // site config; Cfg.Space scopes every lookup
	Hooks  ScpFileHooks       // validation/metadata callbacks used by Write
}
 38
 39func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks) *ScpUploadHandler {
 40	return &ScpUploadHandler{
 41		DBPool: dbpool,
 42		Cfg:    cfg,
 43		Hooks:  hooks,
 44	}
 45}
 46
 47func (r *ScpUploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
 48	return BaseList(s, fpath, isDir, recursive, []string{r.Cfg.Space}, r.DBPool)
 49}
 50
 51func (h *ScpUploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
 52	logger := pssh.GetLogger(s)
 53	user := pssh.GetUser(s)
 54
 55	if user == nil {
 56		err := fmt.Errorf("could not get user from ctx")
 57		logger.Error("error getting user from ctx", "err", err)
 58		return nil, nil, err
 59	}
 60
 61	cleanFilename := filepath.Base(entry.Filepath)
 62
 63	if cleanFilename == "" || cleanFilename == "." {
 64		return nil, nil, os.ErrNotExist
 65	}
 66
 67	post, err := h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
 68	if err != nil {
 69		return nil, nil, err
 70	}
 71
 72	fileInfo := &sendutils.VirtualFile{
 73		FName:    post.Filename,
 74		FIsDir:   false,
 75		FSize:    int64(post.FileSize),
 76		FModTime: *post.UpdatedAt,
 77	}
 78
 79	reader := sendutils.NopReadAndReaderAtCloser(strings.NewReader(post.Text))
 80
 81	return fileInfo, reader, nil
 82}
 83
 84func (h *ScpUploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
 85	logger := pssh.GetLogger(s)
 86	user := pssh.GetUser(s)
 87
 88	if user == nil {
 89		err := fmt.Errorf("could not get user from ctx")
 90		logger.Error("error getting user from ctx", "err", err)
 91		return "", err
 92	}
 93
 94	userID := user.ID
 95	filename := filepath.Base(entry.Filepath)
 96
 97	logger = logger.With(
 98		"filename", filename,
 99	)
100
101	if entry.Mode.IsDir() {
102		return "", fmt.Errorf("file entry is directory, but only files are supported: %s", filename)
103	}
104
105	var origText []byte
106	if b, err := io.ReadAll(entry.Reader); err == nil {
107		origText = b
108	}
109
110	mimeType := http.DetectContentType(origText)
111	ext := filepath.Ext(filename)
112	// DetectContentType does not detect markdown
113	if ext == ".md" {
114		mimeType = "text/markdown; charset=UTF-8"
115	}
116
117	now := time.Now()
118	slug := shared.SanitizeFileExt(filename)
119	fileSize := binary.Size(origText)
120	shasum := shared.Shasum(origText)
121
122	nextPost := db.Post{
123		Filename:  filename,
124		Slug:      slug,
125		PublishAt: &now,
126		Text:      string(origText),
127		MimeType:  mimeType,
128		FileSize:  fileSize,
129		Shasum:    shasum,
130	}
131
132	metadata := PostMetaData{
133		Post:      &nextPost,
134		User:      user,
135		FileEntry: entry,
136	}
137
138	valid, err := h.Hooks.FileValidate(s, &metadata)
139	if !valid {
140		logger.Error("file failed validation", "err", err.Error())
141		return "", err
142	}
143
144	post, err := h.DBPool.FindPostWithFilename(metadata.Filename, metadata.User.ID, h.Cfg.Space)
145	if err != nil {
146		logger.Error("unable to load post, continuing", "err", err.Error())
147	}
148
149	if post != nil {
150		metadata.Cur = post
151		metadata.Data = post.Data
152		metadata.PublishAt = post.PublishAt
153	}
154
155	err = h.Hooks.FileMeta(s, &metadata)
156	if err != nil {
157		logger.Error("file could not load meta", "err", err.Error())
158		return "", err
159	}
160
161	modTime := time.Now()
162
163	if entry.Mtime > 0 {
164		modTime = time.Unix(entry.Mtime, 0)
165	}
166
167	if post == nil {
168		logger.Info("file not found, adding record")
169		insertPost := db.Post{
170			UserID: userID,
171			Space:  h.Cfg.Space,
172
173			Data:        metadata.Data,
174			Description: metadata.Description,
175			Filename:    metadata.Filename,
176			FileSize:    metadata.FileSize,
177			Hidden:      metadata.Hidden,
178			MimeType:    metadata.MimeType,
179			PublishAt:   metadata.PublishAt,
180			Shasum:      metadata.Shasum,
181			Slug:        metadata.Slug,
182			Text:        metadata.Text,
183			Title:       metadata.Title,
184			ExpiresAt:   metadata.ExpiresAt,
185			UpdatedAt:   &modTime,
186		}
187		post, err = h.DBPool.InsertPost(&insertPost)
188		if err != nil {
189			logger.Error("post could not be created", "err", err.Error())
190			return "", fmt.Errorf("error for %s: %v", filename, err)
191		}
192
193		if len(metadata.Aliases) > 0 {
194			logger.Info(
195				"found post aliases, replacing with old aliases",
196				"aliases",
197				strings.Join(metadata.Aliases, ","),
198			)
199			err = h.DBPool.ReplaceAliasesByPost(metadata.Aliases, post.ID)
200			if err != nil {
201				logger.Error("post could not replace aliases", "err", err.Error())
202				return "", fmt.Errorf("error for %s: %v", filename, err)
203			}
204		}
205
206		if len(metadata.Tags) > 0 {
207			logger.Info(
208				"found post tags, replacing with old tags",
209				"tags", strings.Join(metadata.Tags, ","),
210			)
211			err = h.DBPool.ReplaceTagsByPost(metadata.Tags, post.ID)
212			if err != nil {
213				logger.Error("post could not replace tags", "err", err.Error())
214				return "", fmt.Errorf("error for %s: %v", filename, err)
215			}
216		}
217	} else {
218		if metadata.Text == post.Text && modTime.Equal(*post.UpdatedAt) {
219			logger.Info("file found, but text is identical, skipping")
220			curl := shared.NewCreateURL(h.Cfg)
221			return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
222		}
223
224		logger.Info("file found, updating record")
225
226		updatePost := db.Post{
227			ID: post.ID,
228
229			Data:        metadata.Data,
230			FileSize:    metadata.FileSize,
231			Description: metadata.Description,
232			PublishAt:   metadata.PublishAt,
233			Slug:        metadata.Slug,
234			Shasum:      metadata.Shasum,
235			Text:        metadata.Text,
236			Title:       metadata.Title,
237			Hidden:      metadata.Hidden,
238			ExpiresAt:   metadata.ExpiresAt,
239			UpdatedAt:   &modTime,
240		}
241		_, err = h.DBPool.UpdatePost(&updatePost)
242		if err != nil {
243			logger.Error("post could not be updated", "err", err.Error())
244			return "", fmt.Errorf("error for %s: %v", filename, err)
245		}
246
247		logger.Info(
248			"found post tags, replacing with old tags",
249			"tags", strings.Join(metadata.Tags, ","),
250		)
251		err = h.DBPool.ReplaceTagsByPost(metadata.Tags, post.ID)
252		if err != nil {
253			logger.Error("post could not replace tags", "err", err.Error())
254			return "", fmt.Errorf("error for %s: %v", filename, err)
255		}
256
257		logger.Info(
258			"found post aliases, replacing with old aliases",
259			"aliases", strings.Join(metadata.Aliases, ","),
260		)
261		err = h.DBPool.ReplaceAliasesByPost(metadata.Aliases, post.ID)
262		if err != nil {
263			logger.Error("post could not replace aliases", "err", err.Error())
264			return "", fmt.Errorf("error for %s: %v", filename, err)
265		}
266	}
267
268	curl := shared.NewCreateURL(h.Cfg)
269	return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
270}
271
272func (h *ScpUploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
273	logger := pssh.GetLogger(s)
274	user := pssh.GetUser(s)
275
276	if user == nil {
277		err := fmt.Errorf("could not get user from ctx")
278		logger.Error("error getting user from ctx", "err", err)
279		return err
280	}
281
282	userID := user.ID
283	filename := filepath.Base(entry.Filepath)
284	logger = logger.With(
285		"filename", filename,
286	)
287
288	post, err := h.DBPool.FindPostWithFilename(filename, userID, h.Cfg.Space)
289	if err != nil {
290		return err
291	}
292
293	if post == nil {
294		return os.ErrNotExist
295	}
296
297	err = h.DBPool.RemovePosts([]string{post.ID})
298	logger.Info("removing record")
299	if err != nil {
300		logger.Error("post could not remove", "err", err.Error())
301		return fmt.Errorf("error for %s: %v", filename, err)
302	}
303	return nil
304}