Antonio Mika
·
2025-03-12
router_handler.go
1package filehandlers
2
3import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10
11 "github.com/picosh/pico/pkg/db"
12 "github.com/picosh/pico/pkg/pssh"
13 "github.com/picosh/pico/pkg/send/utils"
14 "github.com/picosh/pico/pkg/shared"
15)
16
17type ReadWriteHandler interface {
18 List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error)
19 Write(*pssh.SSHServerConnSession, *utils.FileEntry) (string, error)
20 Read(*pssh.SSHServerConnSession, *utils.FileEntry) (os.FileInfo, utils.ReadAndReaderAtCloser, error)
21 Delete(*pssh.SSHServerConnSession, *utils.FileEntry) error
22}
23
24type FileHandlerRouter struct {
25 FileMap map[string]ReadWriteHandler
26 Cfg *shared.ConfigSite
27 DBPool db.DB
28 Spaces []string
29}
30
31var _ utils.CopyFromClientHandler = &FileHandlerRouter{} // Verify implementation
32var _ utils.CopyFromClientHandler = (*FileHandlerRouter)(nil) // Verify implementation
33
34func NewFileHandlerRouter(cfg *shared.ConfigSite, dbpool db.DB, mapper map[string]ReadWriteHandler) *FileHandlerRouter {
35 return &FileHandlerRouter{
36 Cfg: cfg,
37 DBPool: dbpool,
38 FileMap: mapper,
39 Spaces: []string{cfg.Space},
40 }
41}
42
43func (r *FileHandlerRouter) findHandler(fp string) (ReadWriteHandler, error) {
44 fext := filepath.Ext(fp)
45 handler, ok := r.FileMap[fext]
46 if !ok {
47 hand, hasFallback := r.FileMap["fallback"]
48 if !hasFallback {
49 return nil, fmt.Errorf("no corresponding handler for file extension: %s", fext)
50 }
51 handler = hand
52 }
53 return handler, nil
54}
55
56func (r *FileHandlerRouter) Write(s *pssh.SSHServerConnSession, entry *utils.FileEntry) (string, error) {
57 if entry.Mode.IsDir() {
58 return "", os.ErrInvalid
59 }
60
61 handler, err := r.findHandler(entry.Filepath)
62 if err != nil {
63 return "", err
64 }
65 return handler.Write(s, entry)
66}
67
68func (r *FileHandlerRouter) Delete(s *pssh.SSHServerConnSession, entry *utils.FileEntry) error {
69 handler, err := r.findHandler(entry.Filepath)
70 if err != nil {
71 return err
72 }
73 return handler.Delete(s, entry)
74}
75
76func (r *FileHandlerRouter) Read(s *pssh.SSHServerConnSession, entry *utils.FileEntry) (os.FileInfo, utils.ReadAndReaderAtCloser, error) {
77 handler, err := r.findHandler(entry.Filepath)
78 if err != nil {
79 return nil, nil, err
80 }
81 return handler.Read(s, entry)
82}
83
84func (r *FileHandlerRouter) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
85 files := []os.FileInfo{}
86 for key, handler := range r.FileMap {
87 // TODO: hack because we have duplicate keys for .md and .css
88 if key == ".css" {
89 continue
90 }
91
92 ff, err := handler.List(s, fpath, isDir, recursive)
93 if err != nil {
94 r.GetLogger(s).Error("handler list", "err", err)
95 continue
96 }
97 files = append(files, ff...)
98 }
99 return files, nil
100}
101
102func (r *FileHandlerRouter) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
103 return pssh.GetLogger(s)
104}
105
106func (r *FileHandlerRouter) Validate(s *pssh.SSHServerConnSession) error {
107 logger := pssh.GetLogger(s)
108 user := pssh.GetUser(s)
109
110 if user == nil {
111 err := fmt.Errorf("could not get user from ctx")
112 logger.Error("error getting user from ctx", "err", err)
113 return err
114 }
115
116 logger.Info(
117 "attempting to upload files",
118 "user", user.Name,
119 "space", r.Cfg.Space,
120 )
121 return nil
122}
123
124func BaseList(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
125 var fileList []os.FileInfo
126 logger := pssh.GetLogger(s)
127 user := pssh.GetUser(s)
128
129 var err error
130
131 if user == nil {
132 err = fmt.Errorf("could not get user from ctx")
133 logger.Error("error getting user from ctx", "err", err)
134 return fileList, err
135 }
136
137 cleanFilename := filepath.Base(fpath)
138
139 var post *db.Post
140 var posts []*db.Post
141
142 if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
143 name := cleanFilename
144 if name == "" {
145 name = "/"
146 }
147
148 fileList = append(fileList, &utils.VirtualFile{
149 FName: name,
150 FIsDir: true,
151 })
152
153 for _, space := range spaces {
154 curPosts, e := dbpool.FindAllPostsForUser(user.ID, space)
155 if e != nil {
156 err = e
157 break
158 }
159 posts = append(posts, curPosts...)
160 }
161 } else {
162 for _, space := range spaces {
163 p, e := dbpool.FindPostWithFilename(cleanFilename, user.ID, space)
164 if e != nil {
165 err = e
166 continue
167 }
168 post = p
169 }
170
171 posts = append(posts, post)
172 }
173
174 if err != nil && !errors.Is(err, sql.ErrNoRows) {
175 return nil, err
176 }
177
178 for _, post := range posts {
179 if post == nil {
180 continue
181 }
182
183 fileList = append(fileList, &utils.VirtualFile{
184 FName: post.Filename,
185 FIsDir: false,
186 FSize: int64(post.FileSize),
187 FModTime: *post.UpdatedAt,
188 })
189 }
190
191 return fileList, nil
192}