Antonio Mika
·
2025-03-12
post_handler.go
1package filehandlers
2
3import (
4 "encoding/binary"
5 "fmt"
6 "io"
7 "net/http"
8 "os"
9 "path/filepath"
10 "strings"
11 "time"
12
13 "github.com/picosh/pico/pkg/db"
14 "github.com/picosh/pico/pkg/pssh"
15 sendutils "github.com/picosh/pico/pkg/send/utils"
16 "github.com/picosh/pico/pkg/shared"
17 "github.com/picosh/utils"
18)
19
20type PostMetaData struct {
21 *db.Post
22 Cur *db.Post
23 Tags []string
24 User *db.User
25 FileEntry *sendutils.FileEntry
26 Aliases []string
27}
28
29type ScpFileHooks interface {
30 FileValidate(s *pssh.SSHServerConnSession, data *PostMetaData) (bool, error)
31 FileMeta(s *pssh.SSHServerConnSession, data *PostMetaData) error
32}
33
34type ScpUploadHandler struct {
35 DBPool db.DB
36 Cfg *shared.ConfigSite
37 Hooks ScpFileHooks
38}
39
40func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks) *ScpUploadHandler {
41 return &ScpUploadHandler{
42 DBPool: dbpool,
43 Cfg: cfg,
44 Hooks: hooks,
45 }
46}
47
48func (r *ScpUploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
49 return BaseList(s, fpath, isDir, recursive, []string{r.Cfg.Space}, r.DBPool)
50}
51
52func (h *ScpUploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
53 logger := pssh.GetLogger(s)
54 user := pssh.GetUser(s)
55
56 if user == nil {
57 err := fmt.Errorf("could not get user from ctx")
58 logger.Error("error getting user from ctx", "err", err)
59 return nil, nil, err
60 }
61
62 cleanFilename := filepath.Base(entry.Filepath)
63
64 if cleanFilename == "" || cleanFilename == "." {
65 return nil, nil, os.ErrNotExist
66 }
67
68 post, err := h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
69 if err != nil {
70 return nil, nil, err
71 }
72
73 fileInfo := &sendutils.VirtualFile{
74 FName: post.Filename,
75 FIsDir: false,
76 FSize: int64(post.FileSize),
77 FModTime: *post.UpdatedAt,
78 }
79
80 reader := sendutils.NopReadAndReaderAtCloser(strings.NewReader(post.Text))
81
82 return fileInfo, reader, nil
83}
84
85func (h *ScpUploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
86 logger := pssh.GetLogger(s)
87 user := pssh.GetUser(s)
88
89 if user == nil {
90 err := fmt.Errorf("could not get user from ctx")
91 logger.Error("error getting user from ctx", "err", err)
92 return "", err
93 }
94
95 userID := user.ID
96 filename := filepath.Base(entry.Filepath)
97
98 logger = logger.With(
99 "filename", filename,
100 )
101
102 if entry.Mode.IsDir() {
103 return "", fmt.Errorf("file entry is directory, but only files are supported: %s", filename)
104 }
105
106 var origText []byte
107 if b, err := io.ReadAll(entry.Reader); err == nil {
108 origText = b
109 }
110
111 mimeType := http.DetectContentType(origText)
112 ext := filepath.Ext(filename)
113 // DetectContentType does not detect markdown
114 if ext == ".md" {
115 mimeType = "text/markdown; charset=UTF-8"
116 }
117
118 now := time.Now()
119 slug := utils.SanitizeFileExt(filename)
120 fileSize := binary.Size(origText)
121 shasum := utils.Shasum(origText)
122
123 nextPost := db.Post{
124 Filename: filename,
125 Slug: slug,
126 PublishAt: &now,
127 Text: string(origText),
128 MimeType: mimeType,
129 FileSize: fileSize,
130 Shasum: shasum,
131 }
132
133 metadata := PostMetaData{
134 Post: &nextPost,
135 User: user,
136 FileEntry: entry,
137 }
138
139 valid, err := h.Hooks.FileValidate(s, &metadata)
140 if !valid {
141 logger.Error("file failed validation", "err", err.Error())
142 return "", err
143 }
144
145 post, err := h.DBPool.FindPostWithFilename(metadata.Filename, metadata.User.ID, h.Cfg.Space)
146 if err != nil {
147 logger.Error("unable to load post, continuing", "err", err.Error())
148 }
149
150 if post != nil {
151 metadata.Cur = post
152 metadata.Data = post.Data
153 metadata.Post.PublishAt = post.PublishAt
154 }
155
156 err = h.Hooks.FileMeta(s, &metadata)
157 if err != nil {
158 logger.Error("file could not load meta", "err", err.Error())
159 return "", err
160 }
161
162 modTime := time.Now()
163
164 if entry.Mtime > 0 {
165 modTime = time.Unix(entry.Mtime, 0)
166 }
167
168 if post == nil {
169 logger.Info("file not found, adding record")
170 insertPost := db.Post{
171 UserID: userID,
172 Space: h.Cfg.Space,
173
174 Data: metadata.Data,
175 Description: metadata.Description,
176 Filename: metadata.Filename,
177 FileSize: metadata.FileSize,
178 Hidden: metadata.Hidden,
179 MimeType: metadata.MimeType,
180 PublishAt: metadata.PublishAt,
181 Shasum: metadata.Shasum,
182 Slug: metadata.Slug,
183 Text: metadata.Text,
184 Title: metadata.Title,
185 ExpiresAt: metadata.ExpiresAt,
186 UpdatedAt: &modTime,
187 }
188 post, err = h.DBPool.InsertPost(&insertPost)
189 if err != nil {
190 logger.Error("post could not be created", "err", err.Error())
191 return "", fmt.Errorf("error for %s: %v", filename, err)
192 }
193
194 if len(metadata.Aliases) > 0 {
195 logger.Info(
196 "found post aliases, replacing with old aliases",
197 "aliases",
198 strings.Join(metadata.Aliases, ","),
199 )
200 err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
201 if err != nil {
202 logger.Error("post could not replace aliases", "err", err.Error())
203 return "", fmt.Errorf("error for %s: %v", filename, err)
204 }
205 }
206
207 if len(metadata.Tags) > 0 {
208 logger.Info(
209 "found post tags, replacing with old tags",
210 "tags", strings.Join(metadata.Tags, ","),
211 )
212 err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
213 if err != nil {
214 logger.Error("post could not replace tags", "err", err.Error())
215 return "", fmt.Errorf("error for %s: %v", filename, err)
216 }
217 }
218 } else {
219 if metadata.Text == post.Text && modTime.Equal(*post.UpdatedAt) {
220 logger.Info("file found, but text is identical, skipping")
221 curl := shared.NewCreateURL(h.Cfg)
222 return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
223 }
224
225 logger.Info("file found, updating record")
226
227 updatePost := db.Post{
228 ID: post.ID,
229
230 Data: metadata.Data,
231 FileSize: metadata.FileSize,
232 Description: metadata.Description,
233 PublishAt: metadata.PublishAt,
234 Slug: metadata.Slug,
235 Shasum: metadata.Shasum,
236 Text: metadata.Text,
237 Title: metadata.Title,
238 Hidden: metadata.Hidden,
239 ExpiresAt: metadata.ExpiresAt,
240 UpdatedAt: &modTime,
241 }
242 _, err = h.DBPool.UpdatePost(&updatePost)
243 if err != nil {
244 logger.Error("post could not be updated", "err", err.Error())
245 return "", fmt.Errorf("error for %s: %v", filename, err)
246 }
247
248 logger.Info(
249 "found post tags, replacing with old tags",
250 "tags", strings.Join(metadata.Tags, ","),
251 )
252 err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
253 if err != nil {
254 logger.Error("post could not replace tags", "err", err.Error())
255 return "", fmt.Errorf("error for %s: %v", filename, err)
256 }
257
258 logger.Info(
259 "found post aliases, replacing with old aliases",
260 "aliases", strings.Join(metadata.Aliases, ","),
261 )
262 err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
263 if err != nil {
264 logger.Error("post could not replace aliases", "err", err.Error())
265 return "", fmt.Errorf("error for %s: %v", filename, err)
266 }
267 }
268
269 curl := shared.NewCreateURL(h.Cfg)
270 return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
271}
272
273func (h *ScpUploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
274 logger := pssh.GetLogger(s)
275 user := pssh.GetUser(s)
276
277 if user == nil {
278 err := fmt.Errorf("could not get user from ctx")
279 logger.Error("error getting user from ctx", "err", err)
280 return err
281 }
282
283 userID := user.ID
284 filename := filepath.Base(entry.Filepath)
285 logger = logger.With(
286 "filename", filename,
287 )
288
289 post, err := h.DBPool.FindPostWithFilename(filename, userID, h.Cfg.Space)
290 if err != nil {
291 return err
292 }
293
294 if post == nil {
295 return os.ErrNotExist
296 }
297
298 err = h.DBPool.RemovePosts([]string{post.ID})
299 logger.Info("removing record")
300 if err != nil {
301 logger.Error("post could not remove", "err", err.Error())
302 return fmt.Errorf("error for %s: %v", filename, err)
303 }
304 return nil
305}