repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / filehandlers
Antonio Mika  ·  2025-03-12

post_handler.go

  1package filehandlers
  2
  3import (
  4	"encoding/binary"
  5	"fmt"
  6	"io"
  7	"net/http"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"github.com/picosh/pico/pkg/db"
 14	"github.com/picosh/pico/pkg/pssh"
 15	sendutils "github.com/picosh/pico/pkg/send/utils"
 16	"github.com/picosh/pico/pkg/shared"
 17	"github.com/picosh/utils"
 18)
 19
 20type PostMetaData struct {
 21	*db.Post
 22	Cur       *db.Post
 23	Tags      []string
 24	User      *db.User
 25	FileEntry *sendutils.FileEntry
 26	Aliases   []string
 27}
 28
 29type ScpFileHooks interface {
 30	FileValidate(s *pssh.SSHServerConnSession, data *PostMetaData) (bool, error)
 31	FileMeta(s *pssh.SSHServerConnSession, data *PostMetaData) error
 32}
 33
 34type ScpUploadHandler struct {
 35	DBPool db.DB
 36	Cfg    *shared.ConfigSite
 37	Hooks  ScpFileHooks
 38}
 39
 40func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks) *ScpUploadHandler {
 41	return &ScpUploadHandler{
 42		DBPool: dbpool,
 43		Cfg:    cfg,
 44		Hooks:  hooks,
 45	}
 46}
 47
 48func (r *ScpUploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
 49	return BaseList(s, fpath, isDir, recursive, []string{r.Cfg.Space}, r.DBPool)
 50}
 51
 52func (h *ScpUploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
 53	logger := pssh.GetLogger(s)
 54	user := pssh.GetUser(s)
 55
 56	if user == nil {
 57		err := fmt.Errorf("could not get user from ctx")
 58		logger.Error("error getting user from ctx", "err", err)
 59		return nil, nil, err
 60	}
 61
 62	cleanFilename := filepath.Base(entry.Filepath)
 63
 64	if cleanFilename == "" || cleanFilename == "." {
 65		return nil, nil, os.ErrNotExist
 66	}
 67
 68	post, err := h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
 69	if err != nil {
 70		return nil, nil, err
 71	}
 72
 73	fileInfo := &sendutils.VirtualFile{
 74		FName:    post.Filename,
 75		FIsDir:   false,
 76		FSize:    int64(post.FileSize),
 77		FModTime: *post.UpdatedAt,
 78	}
 79
 80	reader := sendutils.NopReadAndReaderAtCloser(strings.NewReader(post.Text))
 81
 82	return fileInfo, reader, nil
 83}
 84
 85func (h *ScpUploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
 86	logger := pssh.GetLogger(s)
 87	user := pssh.GetUser(s)
 88
 89	if user == nil {
 90		err := fmt.Errorf("could not get user from ctx")
 91		logger.Error("error getting user from ctx", "err", err)
 92		return "", err
 93	}
 94
 95	userID := user.ID
 96	filename := filepath.Base(entry.Filepath)
 97
 98	logger = logger.With(
 99		"filename", filename,
100	)
101
102	if entry.Mode.IsDir() {
103		return "", fmt.Errorf("file entry is directory, but only files are supported: %s", filename)
104	}
105
106	var origText []byte
107	if b, err := io.ReadAll(entry.Reader); err == nil {
108		origText = b
109	}
110
111	mimeType := http.DetectContentType(origText)
112	ext := filepath.Ext(filename)
113	// DetectContentType does not detect markdown
114	if ext == ".md" {
115		mimeType = "text/markdown; charset=UTF-8"
116	}
117
118	now := time.Now()
119	slug := utils.SanitizeFileExt(filename)
120	fileSize := binary.Size(origText)
121	shasum := utils.Shasum(origText)
122
123	nextPost := db.Post{
124		Filename:  filename,
125		Slug:      slug,
126		PublishAt: &now,
127		Text:      string(origText),
128		MimeType:  mimeType,
129		FileSize:  fileSize,
130		Shasum:    shasum,
131	}
132
133	metadata := PostMetaData{
134		Post:      &nextPost,
135		User:      user,
136		FileEntry: entry,
137	}
138
139	valid, err := h.Hooks.FileValidate(s, &metadata)
140	if !valid {
141		logger.Error("file failed validation", "err", err.Error())
142		return "", err
143	}
144
145	post, err := h.DBPool.FindPostWithFilename(metadata.Filename, metadata.User.ID, h.Cfg.Space)
146	if err != nil {
147		logger.Error("unable to load post, continuing", "err", err.Error())
148	}
149
150	if post != nil {
151		metadata.Cur = post
152		metadata.Data = post.Data
153		metadata.Post.PublishAt = post.PublishAt
154	}
155
156	err = h.Hooks.FileMeta(s, &metadata)
157	if err != nil {
158		logger.Error("file could not load meta", "err", err.Error())
159		return "", err
160	}
161
162	modTime := time.Now()
163
164	if entry.Mtime > 0 {
165		modTime = time.Unix(entry.Mtime, 0)
166	}
167
168	if post == nil {
169		logger.Info("file not found, adding record")
170		insertPost := db.Post{
171			UserID: userID,
172			Space:  h.Cfg.Space,
173
174			Data:        metadata.Data,
175			Description: metadata.Description,
176			Filename:    metadata.Filename,
177			FileSize:    metadata.FileSize,
178			Hidden:      metadata.Hidden,
179			MimeType:    metadata.MimeType,
180			PublishAt:   metadata.PublishAt,
181			Shasum:      metadata.Shasum,
182			Slug:        metadata.Slug,
183			Text:        metadata.Text,
184			Title:       metadata.Title,
185			ExpiresAt:   metadata.ExpiresAt,
186			UpdatedAt:   &modTime,
187		}
188		post, err = h.DBPool.InsertPost(&insertPost)
189		if err != nil {
190			logger.Error("post could not be created", "err", err.Error())
191			return "", fmt.Errorf("error for %s: %v", filename, err)
192		}
193
194		if len(metadata.Aliases) > 0 {
195			logger.Info(
196				"found post aliases, replacing with old aliases",
197				"aliases",
198				strings.Join(metadata.Aliases, ","),
199			)
200			err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
201			if err != nil {
202				logger.Error("post could not replace aliases", "err", err.Error())
203				return "", fmt.Errorf("error for %s: %v", filename, err)
204			}
205		}
206
207		if len(metadata.Tags) > 0 {
208			logger.Info(
209				"found post tags, replacing with old tags",
210				"tags", strings.Join(metadata.Tags, ","),
211			)
212			err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
213			if err != nil {
214				logger.Error("post could not replace tags", "err", err.Error())
215				return "", fmt.Errorf("error for %s: %v", filename, err)
216			}
217		}
218	} else {
219		if metadata.Text == post.Text && modTime.Equal(*post.UpdatedAt) {
220			logger.Info("file found, but text is identical, skipping")
221			curl := shared.NewCreateURL(h.Cfg)
222			return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
223		}
224
225		logger.Info("file found, updating record")
226
227		updatePost := db.Post{
228			ID: post.ID,
229
230			Data:        metadata.Data,
231			FileSize:    metadata.FileSize,
232			Description: metadata.Description,
233			PublishAt:   metadata.PublishAt,
234			Slug:        metadata.Slug,
235			Shasum:      metadata.Shasum,
236			Text:        metadata.Text,
237			Title:       metadata.Title,
238			Hidden:      metadata.Hidden,
239			ExpiresAt:   metadata.ExpiresAt,
240			UpdatedAt:   &modTime,
241		}
242		_, err = h.DBPool.UpdatePost(&updatePost)
243		if err != nil {
244			logger.Error("post could not be updated", "err", err.Error())
245			return "", fmt.Errorf("error for %s: %v", filename, err)
246		}
247
248		logger.Info(
249			"found post tags, replacing with old tags",
250			"tags", strings.Join(metadata.Tags, ","),
251		)
252		err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
253		if err != nil {
254			logger.Error("post could not replace tags", "err", err.Error())
255			return "", fmt.Errorf("error for %s: %v", filename, err)
256		}
257
258		logger.Info(
259			"found post aliases, replacing with old aliases",
260			"aliases", strings.Join(metadata.Aliases, ","),
261		)
262		err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
263		if err != nil {
264			logger.Error("post could not replace aliases", "err", err.Error())
265			return "", fmt.Errorf("error for %s: %v", filename, err)
266		}
267	}
268
269	curl := shared.NewCreateURL(h.Cfg)
270	return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
271}
272
273func (h *ScpUploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
274	logger := pssh.GetLogger(s)
275	user := pssh.GetUser(s)
276
277	if user == nil {
278		err := fmt.Errorf("could not get user from ctx")
279		logger.Error("error getting user from ctx", "err", err)
280		return err
281	}
282
283	userID := user.ID
284	filename := filepath.Base(entry.Filepath)
285	logger = logger.With(
286		"filename", filename,
287	)
288
289	post, err := h.DBPool.FindPostWithFilename(filename, userID, h.Cfg.Space)
290	if err != nil {
291		return err
292	}
293
294	if post == nil {
295		return os.ErrNotExist
296	}
297
298	err = h.DBPool.RemovePosts([]string{post.ID})
299	logger.Info("removing record")
300	if err != nil {
301		logger.Error("post could not remove", "err", err.Error())
302		return fmt.Errorf("error for %s: %v", filename, err)
303	}
304	return nil
305}