- commit
- fa97ce3
- parent
- aa965ca
- author
- Antonio Mika
- date
- 2025-03-10 22:55:53 -0400 EDT
More refactoring
14 files changed,
+352,
-379
+0,
-92
1@@ -1,92 +0,0 @@
2-package main
3-
4-import (
5- "context"
6- "os"
7-
8- "github.com/picosh/pico/pgs"
9- pgsdb "github.com/picosh/pico/pgs/db"
10- "github.com/picosh/pico/shared"
11- "github.com/picosh/pico/shared/storage"
12- "github.com/picosh/send/auth"
13- "github.com/picosh/send/list"
14- "github.com/picosh/send/pipe"
15- "github.com/picosh/send/protocols/scp"
16- "github.com/picosh/utils"
17- "golang.org/x/crypto/ssh"
18-)
19-
20-func main() {
21- // Initialize the logger
22- logger := shared.CreateLogger("pgs-ssh")
23-
24- ctx, cancel := context.WithCancel(context.Background())
25- defer cancel()
26-
27- minioURL := utils.GetEnv("MINIO_URL", "")
28- minioUser := utils.GetEnv("MINIO_ROOT_USER", "")
29- minioPass := utils.GetEnv("MINIO_ROOT_PASSWORD", "")
30- dbURL := utils.GetEnv("DATABASE_URL", "")
31-
32- dbpool, err := pgsdb.NewDB(dbURL, logger)
33- if err != nil {
34- panic(err)
35- }
36-
37- st, err := storage.NewStorageMinio(logger, minioURL, minioUser, minioPass)
38- if err != nil {
39- panic(err)
40- }
41-
42- cfg := pgs.NewPgsConfig(logger, dbpool, st)
43-
44- sshAuth := shared.NewSshAuthHandler(cfg.DB, logger)
45-
46- cacheClearingQueue := make(chan string, 100)
47-
48- handler := pgs.NewUploadAssetHandler(
49- cfg,
50- cacheClearingQueue,
51- ctx,
52- )
53-
54- // Create a new SSH server
55- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{
56- ListenAddr: "localhost:2222",
57- ServerConfig: &ssh.ServerConfig{
58- PublicKeyCallback: sshAuth.PubkeyAuthHandler,
59- },
60- Middleware: []shared.SSHServerMiddleware{
61- pipe.Middleware(handler, ""),
62- list.Middleware(handler),
63- scp.Middleware(handler),
64- wishrsync.Middleware(handler),
65- auth.Middleware(handler),
66- wsh.PtyMdw(wsh.DeprecatedNotice()),
67- WishMiddleware(handler),
68- wsh.LogMiddleware(handler.GetLogger(s), handler.Cfg.DB),
69- },
70- })
71-
72- pemBytes, err := os.ReadFile("ssh_data/term_info_ed25519")
73- if err != nil {
74- logger.Error("failed to read private key file", "error", err)
75- return
76- }
77-
78- signer, err := ssh.ParsePrivateKey(pemBytes)
79- if err != nil {
80- logger.Error("failed to parse private key", "error", err)
81- return
82- }
83-
84- server.Config.AddHostKey(signer)
85-
86- err = server.ListenAndServe()
87- if err != nil {
88- logger.Error("failed to start SSH server", "error", err)
89- return
90- }
91-
92- logger.Info("SSH server started successfully")
93-}
+15,
-18
1@@ -5,34 +5,32 @@ import (
2 "text/tabwriter"
3 "time"
4
5- "github.com/charmbracelet/ssh"
6 "github.com/charmbracelet/wish"
7 "github.com/picosh/pico/db"
8+ "github.com/picosh/pico/pssh"
9 "github.com/picosh/pico/shared"
10-
11- wsh "github.com/picosh/pico/wish"
12 )
13
14-func WishMiddleware(dbpool db.DB, cfg *shared.ConfigSite) wish.Middleware {
15- return func(next ssh.Handler) ssh.Handler {
16- return func(sesh ssh.Session) {
17+func WishMiddleware(dbpool db.DB, cfg *shared.ConfigSite) pssh.SSHServerMiddleware {
18+ return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
19+ return func(sesh *pssh.SSHServerConnSession) error {
20 args := sesh.Command()
21 if len(args) == 0 {
22- next(sesh)
23- return
24+ return next(sesh)
25 }
26
27- logger := wsh.GetLogger(sesh)
28- user := wsh.GetUser(sesh)
29+ logger := pssh.GetLogger(sesh)
30+ user := pssh.GetUser(sesh)
31
32 if user == nil {
33- wish.Errorln(sesh, fmt.Errorf("user not found"))
34- return
35+ err := fmt.Errorf("user not found")
36+ fmt.Fprintln(sesh.Stderr(), err)
37+ return err
38 }
39
40 cmd := args[0]
41 if cmd == "help" {
42- wish.Printf(sesh, "Commands: [help, ls, rm, run]\n\n")
43+ fmt.Fprintf(sesh, "Commands: [help, ls, rm, run]\n\n")
44 writer := tabwriter.NewWriter(sesh, 0, 0, 1, ' ', tabwriter.TabIndent)
45 fmt.Fprintln(writer, "Cmd\tDesc")
46 fmt.Fprintf(
47@@ -55,17 +53,16 @@ func WishMiddleware(dbpool db.DB, cfg *shared.ConfigSite) wish.Middleware {
48 "%s\t%s\n",
49 "run {filename}", "runs the feed digest post immediately, ignoring last digest time validation",
50 )
51- writer.Flush()
52- return
53+ return writer.Flush()
54 } else if cmd == "ls" {
55 posts, err := dbpool.FindPostsForUser(&db.Pager{Page: 0, Num: 1000}, user.ID, "feeds")
56 if err != nil {
57- wish.Errorln(sesh, err)
58- return
59+ fmt.Fprintln(sesh.Stderr(), err)
60+ return err
61 }
62
63 if len(posts.Data) == 0 {
64- wish.Println(sesh, "no posts found")
65+ fmt.Fprintln(sesh, "no posts found")
66 }
67
68 writer := tabwriter.NewWriter(sesh, 0, 0, 1, ' ', tabwriter.TabIndent)
+15,
-16
1@@ -11,12 +11,11 @@ import (
2 "slices"
3 "strings"
4
5- "github.com/charmbracelet/ssh"
6 exifremove "github.com/neurosnap/go-exif-remove"
7 "github.com/picosh/pico/db"
8+ "github.com/picosh/pico/pssh"
9 "github.com/picosh/pico/shared"
10 "github.com/picosh/pico/shared/storage"
11- "github.com/picosh/pico/wish"
12 "github.com/picosh/pobj"
13 sst "github.com/picosh/pobj/storage"
14 sendutils "github.com/picosh/send/utils"
15@@ -53,11 +52,11 @@ func (h *UploadImgHandler) getObjectPath(fpath string) string {
16 return filepath.Join("prose", fpath)
17 }
18
19-func (h *UploadImgHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
20+func (h *UploadImgHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
21 var fileList []os.FileInfo
22
23- logger := wish.GetLogger(s)
24- user := wish.GetUser(s)
25+ logger := pssh.GetLogger(s)
26+ user := pssh.GetUser(s)
27
28 if user == nil {
29 err := fmt.Errorf("could not get user from ctx")
30@@ -102,9 +101,9 @@ func (h *UploadImgHandler) List(s ssh.Session, fpath string, isDir bool, recursi
31 return fileList, nil
32 }
33
34-func (h *UploadImgHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
35- logger := wish.GetLogger(s)
36- user := wish.GetUser(s)
37+func (h *UploadImgHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
38+ logger := pssh.GetLogger(s)
39+ user := pssh.GetUser(s)
40
41 if user == nil {
42 err := fmt.Errorf("could not get user from ctx")
43@@ -139,9 +138,9 @@ func (h *UploadImgHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.F
44 return fileInfo, reader, nil
45 }
46
47-func (h *UploadImgHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
48- logger := wish.GetLogger(s)
49- user := wish.GetUser(s)
50+func (h *UploadImgHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
51+ logger := pssh.GetLogger(s)
52+ user := pssh.GetUser(s)
53
54 if user == nil {
55 err := fmt.Errorf("could not get user from ctx")
56@@ -222,9 +221,9 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (str
57 return str, nil
58 }
59
60-func (h *UploadImgHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
61- logger := wish.GetLogger(s)
62- user := wish.GetUser(s)
63+func (h *UploadImgHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
64+ logger := pssh.GetLogger(s)
65+ user := pssh.GetUser(s)
66
67 if user == nil {
68 err := fmt.Errorf("could not get user from ctx")
69@@ -309,13 +308,13 @@ func (h *UploadImgHandler) metaImg(data *PostMetaData) error {
70 return nil
71 }
72
73-func (h *UploadImgHandler) writeImg(s ssh.Session, data *PostMetaData) error {
74+func (h *UploadImgHandler) writeImg(s *pssh.SSHServerConnSession, data *PostMetaData) error {
75 valid, err := h.validateImg(data)
76 if !valid {
77 return err
78 }
79
80- logger := wish.GetLogger(s)
81+ logger := pssh.GetLogger(s)
82 logger = logger.With(
83 "filename", data.Filename,
84 )
+13,
-14
1@@ -10,10 +10,9 @@ import (
2 "strings"
3 "time"
4
5- "github.com/charmbracelet/ssh"
6 "github.com/picosh/pico/db"
7+ "github.com/picosh/pico/pssh"
8 "github.com/picosh/pico/shared"
9- "github.com/picosh/pico/wish"
10 sendutils "github.com/picosh/send/utils"
11 "github.com/picosh/utils"
12 )
13@@ -28,8 +27,8 @@ type PostMetaData struct {
14 }
15
16 type ScpFileHooks interface {
17- FileValidate(s ssh.Session, data *PostMetaData) (bool, error)
18- FileMeta(s ssh.Session, data *PostMetaData) error
19+ FileValidate(s *pssh.SSHServerConnSession, data *PostMetaData) (bool, error)
20+ FileMeta(s *pssh.SSHServerConnSession, data *PostMetaData) error
21 }
22
23 type ScpUploadHandler struct {
24@@ -46,13 +45,13 @@ func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks)
25 }
26 }
27
28-func (r *ScpUploadHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
29+func (r *ScpUploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
30 return BaseList(s, fpath, isDir, recursive, []string{r.Cfg.Space}, r.DBPool)
31 }
32
33-func (h *ScpUploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
34- logger := wish.GetLogger(s)
35- user := wish.GetUser(s)
36+func (h *ScpUploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
37+ logger := pssh.GetLogger(s)
38+ user := pssh.GetUser(s)
39
40 if user == nil {
41 err := fmt.Errorf("could not get user from ctx")
42@@ -83,9 +82,9 @@ func (h *ScpUploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.F
43 return fileInfo, reader, nil
44 }
45
46-func (h *ScpUploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
47- logger := wish.GetLogger(s)
48- user := wish.GetUser(s)
49+func (h *ScpUploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
50+ logger := pssh.GetLogger(s)
51+ user := pssh.GetUser(s)
52
53 if user == nil {
54 err := fmt.Errorf("could not get user from ctx")
55@@ -271,9 +270,9 @@ func (h *ScpUploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (str
56 return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
57 }
58
59-func (h *ScpUploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
60- logger := wish.GetLogger(s)
61- user := wish.GetUser(s)
62+func (h *ScpUploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
63+ logger := pssh.GetLogger(s)
64+ user := pssh.GetUser(s)
65
66 if user == nil {
67 err := fmt.Errorf("could not get user from ctx")
+17,
-18
1@@ -8,18 +8,17 @@ import (
2 "os"
3 "path/filepath"
4
5- "github.com/charmbracelet/ssh"
6 "github.com/picosh/pico/db"
7+ "github.com/picosh/pico/pssh"
8 "github.com/picosh/pico/shared"
9- "github.com/picosh/pico/wish"
10 "github.com/picosh/send/utils"
11 )
12
13 type ReadWriteHandler interface {
14- List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error)
15- Write(ssh.Session, *utils.FileEntry) (string, error)
16- Read(ssh.Session, *utils.FileEntry) (os.FileInfo, utils.ReadAndReaderAtCloser, error)
17- Delete(ssh.Session, *utils.FileEntry) error
18+ List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error)
19+ Write(*pssh.SSHServerConnSession, *utils.FileEntry) (string, error)
20+ Read(*pssh.SSHServerConnSession, *utils.FileEntry) (os.FileInfo, utils.ReadAndReaderAtCloser, error)
21+ Delete(*pssh.SSHServerConnSession, *utils.FileEntry) error
22 }
23
24 type FileHandlerRouter struct {
25@@ -54,7 +53,7 @@ func (r *FileHandlerRouter) findHandler(fp string) (ReadWriteHandler, error) {
26 return handler, nil
27 }
28
29-func (r *FileHandlerRouter) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
30+func (r *FileHandlerRouter) Write(s *pssh.SSHServerConnSession, entry *utils.FileEntry) (string, error) {
31 if entry.Mode.IsDir() {
32 return "", os.ErrInvalid
33 }
34@@ -66,7 +65,7 @@ func (r *FileHandlerRouter) Write(s ssh.Session, entry *utils.FileEntry) (string
35 return handler.Write(s, entry)
36 }
37
38-func (r *FileHandlerRouter) Delete(s ssh.Session, entry *utils.FileEntry) error {
39+func (r *FileHandlerRouter) Delete(s *pssh.SSHServerConnSession, entry *utils.FileEntry) error {
40 handler, err := r.findHandler(entry.Filepath)
41 if err != nil {
42 return err
43@@ -74,7 +73,7 @@ func (r *FileHandlerRouter) Delete(s ssh.Session, entry *utils.FileEntry) error
44 return handler.Delete(s, entry)
45 }
46
47-func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReadAndReaderAtCloser, error) {
48+func (r *FileHandlerRouter) Read(s *pssh.SSHServerConnSession, entry *utils.FileEntry) (os.FileInfo, utils.ReadAndReaderAtCloser, error) {
49 handler, err := r.findHandler(entry.Filepath)
50 if err != nil {
51 return nil, nil, err
52@@ -82,7 +81,7 @@ func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.File
53 return handler.Read(s, entry)
54 }
55
56-func (r *FileHandlerRouter) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
57+func (r *FileHandlerRouter) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
58 files := []os.FileInfo{}
59 for key, handler := range r.FileMap {
60 // TODO: hack because we have duplicate keys for .md and .css
61@@ -100,13 +99,13 @@ func (r *FileHandlerRouter) List(s ssh.Session, fpath string, isDir bool, recurs
62 return files, nil
63 }
64
65-func (r *FileHandlerRouter) GetLogger(s ssh.Session) *slog.Logger {
66- return wish.GetLogger(s)
67+func (r *FileHandlerRouter) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
68+ return pssh.GetLogger(s)
69 }
70
71-func (r *FileHandlerRouter) Validate(s ssh.Session) error {
72- logger := wish.GetLogger(s)
73- user := wish.GetUser(s)
74+func (r *FileHandlerRouter) Validate(s *pssh.SSHServerConnSession) error {
75+ logger := pssh.GetLogger(s)
76+ user := pssh.GetUser(s)
77
78 if user == nil {
79 err := fmt.Errorf("could not get user from ctx")
80@@ -122,10 +121,10 @@ func (r *FileHandlerRouter) Validate(s ssh.Session) error {
81 return nil
82 }
83
84-func BaseList(s ssh.Session, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
85+func BaseList(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
86 var fileList []os.FileInfo
87- logger := wish.GetLogger(s)
88- user := wish.GetUser(s)
89+ logger := pssh.GetLogger(s)
90+ user := pssh.GetUser(s)
91
92 var err error
93
+39,
-42
1@@ -6,17 +6,14 @@ import (
2 "slices"
3 "strings"
4
5- "github.com/charmbracelet/ssh"
6- "github.com/charmbracelet/wish"
7- bm "github.com/charmbracelet/wish/bubbletea"
8- "github.com/muesli/termenv"
9 "github.com/picosh/pico/db"
10 pgsdb "github.com/picosh/pico/pgs/db"
11+ "github.com/picosh/pico/pssh"
12 sendutils "github.com/picosh/send/utils"
13 "github.com/picosh/utils"
14 )
15
16-func getUser(s ssh.Session, dbpool pgsdb.PgsDB) (*db.User, error) {
17+func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
18 if s.PublicKey() == nil {
19 return nil, fmt.Errorf("key not found")
20 }
21@@ -46,7 +43,7 @@ func (i *arrayFlags) Set(value string) error {
22 return nil
23 }
24
25-func flagSet(cmdName string, sesh ssh.Session) (*flag.FlagSet, *bool) {
26+func flagSet(cmdName string, sesh *pssh.SSHServerConnSession) (*flag.FlagSet, *bool) {
27 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
28 cmd.SetOutput(sesh)
29 write := cmd.Bool("write", false, "apply changes")
30@@ -63,18 +60,17 @@ func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
31 return true
32 }
33
34-func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
35+func Middleware(handler *UploadAssetHandler) pssh.SSHServerMiddleware {
36 dbpool := handler.Cfg.DB
37 log := handler.Cfg.Logger
38 cfg := handler.Cfg
39 store := handler.Cfg.Storage
40
41- return func(next ssh.Handler) ssh.Handler {
42- return func(sesh ssh.Session) {
43+ return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
44+ return func(sesh *pssh.SSHServerConnSession) error {
45 args := sesh.Command()
46 if len(args) == 0 {
47- next(sesh)
48- return
49+ return next(sesh)
50 }
51
52 // default width and height when no pty
53@@ -89,11 +85,12 @@ func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
54 user, err := getUser(sesh, dbpool)
55 if err != nil {
56 sendutils.ErrorHandler(sesh, err)
57- return
58+ return err
59 }
60
61- renderer := bm.MakeRenderer(sesh)
62- renderer.SetColorProfile(termenv.TrueColor)
63+ // renderer := bm.MakeRenderer(sesh)
64+ // renderer.SetColorProfile(termenv.TrueColor)
65+ // styles := common.DefaultStyles(renderer)
66
67 opts := Cmd{
68 Session: sesh,
69@@ -102,33 +99,33 @@ func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
70 Log: log,
71 Dbpool: dbpool,
72 Write: false,
73- Width: width,
74- Height: height,
75- Cfg: handler.Cfg,
76+ // Styles: styles,
77+ Width: width,
78+ Height: height,
79+ Cfg: handler.Cfg,
80 }
81
82 cmd := strings.TrimSpace(args[0])
83 if len(args) == 1 {
84 if cmd == "help" {
85 opts.help()
86- return
87+ return nil
88 } else if cmd == "stats" {
89 err := opts.stats(cfg.MaxSize)
90 opts.bail(err)
91- return
92+ return err
93 } else if cmd == "ls" {
94 err := opts.ls()
95 opts.bail(err)
96- return
97+ return err
98 } else if cmd == "cache-all" {
99 opts.Write = true
100 err := opts.cacheAll()
101 opts.notice()
102 opts.bail(err)
103- return
104+ return err
105 } else {
106- next(sesh)
107- return
108+ return next(sesh)
109 }
110 }
111
112@@ -145,12 +142,12 @@ func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
113 if cmd == "fzf" {
114 err := opts.fzf(projectName)
115 opts.bail(err)
116- return
117+ return err
118 } else if cmd == "link" {
119 linkCmd, write := flagSet("link", sesh)
120 linkTo := linkCmd.String("to", "", "symbolic link to this project")
121 if !flagCheck(linkCmd, projectName, cmdArgs) {
122- return
123+ return nil
124 }
125 opts.Write = *write
126
127@@ -159,7 +156,7 @@ func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
128 "must provide `--to` flag",
129 )
130 opts.bail(err)
131- return
132+ return err
133 }
134
135 err := opts.link(projectName, *linkTo)
136@@ -167,67 +164,67 @@ func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
137 if err != nil {
138 opts.bail(err)
139 }
140- return
141+ return err
142 } else if cmd == "unlink" {
143 unlinkCmd, write := flagSet("unlink", sesh)
144 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
145- return
146+ return nil
147 }
148 opts.Write = *write
149
150 err := opts.unlink(projectName)
151 opts.notice()
152 opts.bail(err)
153- return
154+ return err
155 } else if cmd == "depends" {
156 err := opts.depends(projectName)
157 opts.bail(err)
158- return
159+ return err
160 } else if cmd == "retain" {
161 retainCmd, write := flagSet("retain", sesh)
162 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
163 if !flagCheck(retainCmd, projectName, cmdArgs) {
164- return
165+ return nil
166 }
167 opts.Write = *write
168
169 err := opts.prune(projectName, *retainNum)
170 opts.notice()
171 opts.bail(err)
172- return
173+ return err
174 } else if cmd == "prune" {
175 pruneCmd, write := flagSet("prune", sesh)
176 if !flagCheck(pruneCmd, projectName, cmdArgs) {
177- return
178+ return nil
179 }
180 opts.Write = *write
181
182 err := opts.prune(projectName, 0)
183 opts.notice()
184 opts.bail(err)
185- return
186+ return err
187 } else if cmd == "rm" {
188 rmCmd, write := flagSet("rm", sesh)
189 if !flagCheck(rmCmd, projectName, cmdArgs) {
190- return
191+ return nil
192 }
193 opts.Write = *write
194
195 err := opts.rm(projectName)
196 opts.notice()
197 opts.bail(err)
198- return
199+ return err
200 } else if cmd == "cache" {
201 cacheCmd, write := flagSet("cache", sesh)
202 if !flagCheck(cacheCmd, projectName, cmdArgs) {
203- return
204+ return nil
205 }
206 opts.Write = *write
207
208 err := opts.cache(projectName)
209 opts.notice()
210 opts.bail(err)
211- return
212+ return err
213 } else if cmd == "acl" {
214 aclCmd, write := flagSet("acl", sesh)
215 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
216@@ -238,7 +235,7 @@ func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
217 "list of pico usernames or sha256 public keys, delimited by commas",
218 )
219 if !flagCheck(aclCmd, projectName, cmdArgs) {
220- return
221+ return nil
222 }
223 opts.Write = *write
224
225@@ -248,15 +245,15 @@ func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
226 *aclType,
227 )
228 opts.bail(err)
229- return
230+ return err
231 }
232
233 err := opts.acl(projectName, *aclType, acls)
234 opts.notice()
235 opts.bail(err)
236+ return err
237 } else {
238- next(sesh)
239- return
240+ return next(sesh)
241 }
242 }
243 }
+45,
-74
1@@ -2,71 +2,30 @@ package pgs
2
3 import (
4 "context"
5- "fmt"
6 "os"
7 "os/signal"
8 "syscall"
9- "time"
10
11- "github.com/charmbracelet/promwish"
12- "github.com/charmbracelet/ssh"
13- "github.com/charmbracelet/wish"
14- wsh "github.com/picosh/pico/wish"
15+ "github.com/picosh/pico/pssh"
16+ "github.com/picosh/pico/shared"
17 "github.com/picosh/send/auth"
18 "github.com/picosh/send/list"
19 "github.com/picosh/send/pipe"
20- wishrsync "github.com/picosh/send/protocols/rsync"
21+ "github.com/picosh/send/protocols/rsync"
22 "github.com/picosh/send/protocols/scp"
23 "github.com/picosh/send/protocols/sftp"
24- "github.com/picosh/send/proxy"
25- "github.com/picosh/tunkit"
26 "github.com/picosh/utils"
27+ "golang.org/x/crypto/ssh"
28 )
29
30-func createRouter(handler *UploadAssetHandler) proxy.Router {
31- return func(sh ssh.Handler, s ssh.Session) []wish.Middleware {
32- return []wish.Middleware{
33- pipe.Middleware(handler, ""),
34- list.Middleware(handler),
35- scp.Middleware(handler),
36- wishrsync.Middleware(handler),
37- auth.Middleware(handler),
38- wsh.PtyMdw(wsh.DeprecatedNotice()),
39- WishMiddleware(handler),
40- wsh.LogMiddleware(handler.GetLogger(s), handler.Cfg.DB),
41- }
42- }
43-}
44-
45-func withProxy(handler *UploadAssetHandler, otherMiddleware ...wish.Middleware) ssh.Option {
46- return func(server *ssh.Server) error {
47- err := sftp.SSHOption(handler)(server)
48- if err != nil {
49- return err
50- }
51-
52- newSubsystemHandlers := map[string]ssh.SubsystemHandler{}
53-
54- for name, subsystemHandler := range server.SubsystemHandlers {
55- newSubsystemHandlers[name] = func(s ssh.Session) {
56- wsh.LogMiddleware(handler.GetLogger(s), handler.Cfg.DB)(ssh.Handler(subsystemHandler))(s)
57- }
58- }
59-
60- server.SubsystemHandlers = newSubsystemHandlers
61-
62- return proxy.WithProxy(createRouter(handler), otherMiddleware...)(server)
63- }
64-}
65-
66 func StartSshServer(cfg *PgsConfig, killCh chan error) {
67 host := utils.GetEnv("PGS_HOST", "0.0.0.0")
68 port := utils.GetEnv("PGS_SSH_PORT", "2222")
69- promPort := utils.GetEnv("PGS_PROM_PORT", "9222")
70+ // promPort := utils.GetEnv("PGS_PROM_PORT", "9222")
71 logger := cfg.Logger
72
73- ctx := context.Background()
74- defer ctx.Done()
75+ ctx, cancel := context.WithCancel(context.Background())
76+ defer cancel()
77
78 cacheClearingQueue := make(chan string, 100)
79 handler := NewUploadAssetHandler(
80@@ -75,48 +34,60 @@ func StartSshServer(cfg *PgsConfig, killCh chan error) {
81 ctx,
82 )
83
84- webTunnel := &tunkit.WebTunnelHandler{
85- Logger: logger,
86- HttpHandler: createHttpHandler(cfg),
87+ sshAuth := shared.NewSshAuthHandler(cfg.DB, logger)
88+
89+ // webTunnel := &tunkit.WebTunnelHandler{
90+ // Logger: logger,
91+ // HttpHandler: createHttpHandler(cfg),
92+ // }
93+
94+ // Create a new SSH server
95+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{
96+ ListenAddr: "localhost:2222",
97+ ServerConfig: &ssh.ServerConfig{
98+ PublicKeyCallback: sshAuth.PubkeyAuthHandler,
99+ },
100+ Middleware: []pssh.SSHServerMiddleware{
101+ pipe.Middleware(handler, ""),
102+ list.Middleware(handler),
103+ sftp.Middleware(handler),
104+ scp.Middleware(handler),
105+ rsync.Middleware(handler),
106+ auth.Middleware(handler),
107+ pssh.PtyMdw(pssh.DeprecatedNotice()),
108+ Middleware(handler),
109+ pssh.LogMiddleware(handler, handler.Cfg.DB),
110+ },
111+ })
112+
113+ pemBytes, err := os.ReadFile("ssh_data/term_info_ed25519")
114+ if err != nil {
115+ logger.Error("failed to read private key file", "error", err)
116+ return
117 }
118
119- // sshAuth := shared.NewSshAuthHandler(cfg.DB, logger)
120- s, err := wish.NewServer(
121- // wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
122- // wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
123- // wish.WithPublicKeyAuth(sshAuth.PubkeyAuthHandler),
124- tunkit.WithWebTunnel(webTunnel),
125- withProxy(
126- handler,
127- promwish.Middleware(fmt.Sprintf("%s:%s", host, promPort), "pgs-ssh"),
128- ),
129- )
130+ signer, err := ssh.ParsePrivateKey(pemBytes)
131 if err != nil {
132- logger.Error(err.Error())
133+ logger.Error("failed to parse private key", "error", err)
134 return
135 }
136
137+ server.Config.AddHostKey(signer)
138+
139 done := make(chan os.Signal, 1)
140 signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
141 logger.Info("starting SSH server on", "host", host, "port", port)
142
143 go func() {
144- if err = s.ListenAndServe(); err != nil {
145- if err != ssh.ErrServerClosed {
146- logger.Error("serve", "err", err.Error())
147- os.Exit(1)
148- }
149+ if err = server.ListenAndServe(); err != nil {
150+ logger.Error("serve", "err", err.Error())
151+ os.Exit(1)
152 }
153 }()
154
155 exit := func() {
156 logger.Info("stopping ssh server")
157- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
158- defer func() { cancel() }()
159- if err := s.Shutdown(ctx); err != nil {
160- logger.Error("shutdown", "err", err.Error())
161- os.Exit(1)
162- }
163+ cancel()
164 }
165
166 select {
+35,
-34
1@@ -15,12 +15,10 @@ import (
2 "sync"
3 "time"
4
5- "github.com/charmbracelet/ssh"
6- "github.com/charmbracelet/wish"
7 "github.com/picosh/pico/db"
8 pgsdb "github.com/picosh/pico/pgs/db"
9+ "github.com/picosh/pico/pssh"
10 "github.com/picosh/pico/shared"
11- wsh "github.com/picosh/pico/wish"
12 "github.com/picosh/pobj"
13 sst "github.com/picosh/pobj/storage"
14 sendutils "github.com/picosh/send/utils"
15@@ -37,7 +35,7 @@ type DenyList struct {
16 Denylist string
17 }
18
19-func getDenylist(s ssh.Session) *DenyList {
20+func getDenylist(s *pssh.SSHServerConnSession) *DenyList {
21 v := s.Context().Value(ctxDenylistKey{})
22 if v == nil {
23 return nil
24@@ -46,11 +44,11 @@ func getDenylist(s ssh.Session) *DenyList {
25 return denylist
26 }
27
28-func setDenylist(s ssh.Session, denylist string) {
29- s.Context().SetValue(ctxDenylistKey{}, &DenyList{Denylist: denylist})
30+func setDenylist(s *pssh.SSHServerConnSession, denylist string) {
31+ s.SetValue(ctxDenylistKey{}, &DenyList{Denylist: denylist})
32 }
33
34-func getProject(s ssh.Session) *db.Project {
35+func getProject(s *pssh.SSHServerConnSession) *db.Project {
36 v := s.Context().Value(ctxProjectKey{})
37 if v == nil {
38 return nil
39@@ -59,11 +57,11 @@ func getProject(s ssh.Session) *db.Project {
40 return project
41 }
42
43-func setProject(s ssh.Session, project *db.Project) {
44- s.Context().SetValue(ctxProjectKey{}, project)
45+func setProject(s *pssh.SSHServerConnSession, project *db.Project) {
46+ s.SetValue(ctxProjectKey{}, project)
47 }
48
49-func getBucket(s ssh.Session) (sst.Bucket, error) {
50+func getBucket(s *pssh.SSHServerConnSession) (sst.Bucket, error) {
51 bucket := s.Context().Value(ctxBucketKey{}).(sst.Bucket)
52 if bucket.Name == "" {
53 return bucket, fmt.Errorf("bucket not set on `ssh.Context()` for connection")
54@@ -71,11 +69,11 @@ func getBucket(s ssh.Session) (sst.Bucket, error) {
55 return bucket, nil
56 }
57
58-func getStorageSize(s ssh.Session) uint64 {
59+func getStorageSize(s *pssh.SSHServerConnSession) uint64 {
60 return s.Context().Value(ctxStorageSizeKey{}).(uint64)
61 }
62
63-func incrementStorageSize(s ssh.Session, fileSize int64) uint64 {
64+func incrementStorageSize(s *pssh.SSHServerConnSession, fileSize int64) uint64 {
65 curSize := getStorageSize(s)
66 var nextStorageSize uint64
67 if fileSize < 0 {
68@@ -83,7 +81,7 @@ func incrementStorageSize(s ssh.Session, fileSize int64) uint64 {
69 } else {
70 nextStorageSize = curSize + uint64(fileSize)
71 }
72- s.Context().SetValue(ctxStorageSizeKey{}, nextStorageSize)
73+ s.SetValue(ctxStorageSizeKey{}, nextStorageSize)
74 return nextStorageSize
75 }
76
77@@ -113,13 +111,13 @@ func NewUploadAssetHandler(cfg *PgsConfig, ch chan string, ctx context.Context)
78 }
79 }
80
81-func (h *UploadAssetHandler) GetLogger(s ssh.Session) *slog.Logger {
82- return wsh.GetLogger(s)
83+func (h *UploadAssetHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
84+ return pssh.GetLogger(s)
85 }
86
87-func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
88- logger := wsh.GetLogger(s)
89- user := wsh.GetUser(s)
90+func (h *UploadAssetHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
91+ logger := pssh.GetLogger(s)
92+ user := pssh.GetUser(s)
93
94 if user == nil {
95 err := fmt.Errorf("could not get user from ctx")
96@@ -153,11 +151,11 @@ func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os
97 return fileInfo, reader, nil
98 }
99
100-func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
101+func (h *UploadAssetHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
102 var fileList []os.FileInfo
103
104- logger := wsh.GetLogger(s)
105- user := wsh.GetUser(s)
106+ logger := pssh.GetLogger(s)
107+ user := pssh.GetUser(s)
108
109 if user == nil {
110 err := fmt.Errorf("could not get user from ctx")
111@@ -201,9 +199,9 @@ func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recur
112 return fileList, nil
113 }
114
115-func (h *UploadAssetHandler) Validate(s ssh.Session) error {
116- logger := wsh.GetLogger(s)
117- user := wsh.GetUser(s)
118+func (h *UploadAssetHandler) Validate(s *pssh.SSHServerConnSession) error {
119+ logger := pssh.GetLogger(s)
120+ user := pssh.GetUser(s)
121
122 if user == nil {
123 err := fmt.Errorf("could not get user from ctx")
124@@ -217,14 +215,14 @@ func (h *UploadAssetHandler) Validate(s ssh.Session) error {
125 return err
126 }
127
128- s.Context().SetValue(ctxBucketKey{}, bucket)
129+ s.SetValue(ctxBucketKey{}, bucket)
130
131 totalStorageSize, err := h.Cfg.Storage.GetBucketQuota(bucket)
132 if err != nil {
133 return err
134 }
135
136- s.Context().SetValue(ctxStorageSizeKey{}, totalStorageSize)
137+ s.SetValue(ctxStorageSizeKey{}, totalStorageSize)
138
139 logger.Info(
140 "bucket size",
141@@ -279,9 +277,9 @@ func findPlusFF(dbpool pgsdb.PgsDB, cfg *PgsConfig, userID string) *db.FeatureFl
142 return ff
143 }
144
145-func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
146- logger := wsh.GetLogger(s)
147- user := wsh.GetUser(s)
148+func (h *UploadAssetHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
149+ logger := pssh.GetLogger(s)
150+ user := pssh.GetUser(s)
151
152 if user == nil {
153 err := fmt.Errorf("could not get user from ctx")
154@@ -380,7 +378,10 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (s
155 remaining := int64(storageMax) - int64(curStorageSize)
156 sizeRemaining := min(remaining+curFileSize, fileMax)
157 if sizeRemaining <= 0 {
158- wish.Fatalln(s, "storage quota reached")
159+ fmt.Fprintln(s.Stderr(), "storage quota reached")
160+ fmt.Fprintf(s.Stderr(), "\r")
161+ _ = s.Exit(1)
162+ _ = s.Close()
163 return "", fmt.Errorf("storage quota reached")
164 }
165 logger = logger.With(
166@@ -442,9 +443,9 @@ func isSpecialFile(entry *sendutils.FileEntry) bool {
167 return fname == "_headers" || fname == "_redirects"
168 }
169
170-func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
171- logger := wsh.GetLogger(s)
172- user := wsh.GetUser(s)
173+func (h *UploadAssetHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
174+ logger := pssh.GetLogger(s)
175+ user := pssh.GetUser(s)
176
177 if user == nil {
178 err := fmt.Errorf("could not get user from ctx")
179@@ -537,7 +538,7 @@ func (h *UploadAssetHandler) validateAsset(data *FileData) (bool, error) {
180 return true, nil
181 }
182
183-func (h *UploadAssetHandler) writeAsset(s ssh.Session, reader io.Reader, data *FileData) (int64, error) {
184+func (h *UploadAssetHandler) writeAsset(s *pssh.SSHServerConnSession, reader io.Reader, data *FileData) (int64, error) {
185 assetFilepath := shared.GetAssetFileName(data.FileEntry)
186
187 logger := h.GetLogger(s)
+21,
-22
1@@ -12,10 +12,9 @@ import (
2 "time"
3
4 "github.com/charmbracelet/ssh"
5- "github.com/charmbracelet/wish"
6 "github.com/picosh/pico/db"
7+ "github.com/picosh/pico/pssh"
8 "github.com/picosh/pico/shared"
9- wsh "github.com/picosh/pico/wish"
10 sendutils "github.com/picosh/send/utils"
11 "github.com/picosh/utils"
12 )
13@@ -52,13 +51,13 @@ func (h *UploadHandler) getAuthorizedKeyFile(user *db.User) (*sendutils.VirtualF
14 return fileInfo, text, nil
15 }
16
17-func (h *UploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
18+func (h *UploadHandler) Delete(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) error {
19 return errors.New("unsupported")
20 }
21
22-func (h *UploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
23- logger := wsh.GetLogger(s)
24- user := wsh.GetUser(s)
25+func (h *UploadHandler) Read(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReadAndReaderAtCloser, error) {
26+ logger := pssh.GetLogger(s)
27+ user := pssh.GetUser(s)
28
29 if user == nil {
30 err := fmt.Errorf("could not get user from ctx")
31@@ -84,11 +83,11 @@ func (h *UploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.File
32 return nil, nil, os.ErrNotExist
33 }
34
35-func (h *UploadHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
36+func (h *UploadHandler) List(s *pssh.SSHServerConnSession, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
37 var fileList []os.FileInfo
38
39- logger := wsh.GetLogger(s)
40- user := wsh.GetUser(s)
41+ logger := pssh.GetLogger(s)
42+ user := pssh.GetUser(s)
43
44 if user == nil {
45 err := fmt.Errorf("could not get user from ctx")
46@@ -127,11 +126,11 @@ func (h *UploadHandler) List(s ssh.Session, fpath string, isDir bool, recursive
47 return fileList, nil
48 }
49
50-func (h *UploadHandler) GetLogger(s ssh.Session) *slog.Logger {
51- return wsh.GetLogger(s)
52+func (h *UploadHandler) GetLogger(s *pssh.SSHServerConnSession) *slog.Logger {
53+ return pssh.GetLogger(s)
54 }
55
56-func (h *UploadHandler) Validate(s ssh.Session) error {
57+func (h *UploadHandler) Validate(s *pssh.SSHServerConnSession) error {
58 var err error
59 key, err := sendutils.KeyText(s)
60 if err != nil {
61@@ -212,7 +211,7 @@ func authorizedKeysDiff(keyInUse ssh.PublicKey, curKeys []KeyWithId, nextKeys []
62 }
63 }
64
65-func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s ssh.Session) error {
66+func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s *pssh.SSHServerConnSession) error {
67 logger.Info("processing new authorized_keys")
68 dbpool := h.DBPool
69
70@@ -245,12 +244,12 @@ func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger,
71 for _, pk := range diff.Add {
72 key := utils.KeyForKeyText(pk.Pk)
73
74- wish.Errorf(s, "adding pubkey (%s)\n", key)
75+ fmt.Fprintf(s.Stderr(), "adding pubkey (%s)\n", key)
76 logger.Info("adding pubkey", "pubkey", key)
77
78 err = dbpool.InsertPublicKey(user.ID, key, pk.Comment, nil)
79 if err != nil {
80- wish.Errorf(s, "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
81+ fmt.Fprintf(s.Stderr(), "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
82 logger.Error("could not insert pubkey", "err", err.Error())
83 }
84 }
85@@ -258,7 +257,7 @@ func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger,
86 for _, pk := range diff.Update {
87 key := utils.KeyForKeyText(pk.Pk)
88
89- wish.Errorf(s, "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
90+ fmt.Fprintf(s.Stderr(), "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
91 logger.Info(
92 "updating pubkey with comment",
93 "pubkey", key,
94@@ -267,18 +266,18 @@ func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger,
95
96 _, err = dbpool.UpdatePublicKey(pk.ID, pk.Comment)
97 if err != nil {
98- wish.Errorf(s, "error: could not update pubkey: %s (%s)\n", err.Error(), key)
99+ fmt.Fprintf(s.Stderr(), "error: could not update pubkey: %s (%s)\n", err.Error(), key)
100 logger.Error("could not update pubkey", "err", err.Error(), "key", key)
101 }
102 }
103
104 if len(diff.Rm) > 0 {
105- wish.Errorf(s, "removing pubkeys: %s\n", diff.Rm)
106+ fmt.Fprintf(s.Stderr(), "removing pubkeys: %s\n", diff.Rm)
107 logger.Info("removing pubkeys", "pubkeys", diff.Rm)
108
109 err = dbpool.RemoveKeys(diff.Rm)
110 if err != nil {
111- wish.Errorf(s, "error: could not rm pubkeys: %s\n", err.Error())
112+ fmt.Fprintf(s.Stderr(), "error: could not rm pubkeys: %s\n", err.Error())
113 logger.Error("could not remove pubkey", "err", err.Error())
114 }
115 }
116@@ -286,9 +285,9 @@ func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger,
117 return nil
118 }
119
120-func (h *UploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
121- logger := wsh.GetLogger(s)
122- user := wsh.GetUser(s)
123+func (h *UploadHandler) Write(s *pssh.SSHServerConnSession, entry *sendutils.FileEntry) (string, error) {
124+ logger := pssh.GetLogger(s)
125+ user := pssh.GetUser(s)
126
127 if user == nil {
128 err := fmt.Errorf("could not get user from ctx")
+20,
-18
1@@ -17,8 +17,8 @@ import (
2 "github.com/charmbracelet/wish"
3 "github.com/google/uuid"
4 "github.com/picosh/pico/db"
5+ "github.com/picosh/pico/pssh"
6 "github.com/picosh/pico/shared"
7- wsh "github.com/picosh/pico/wish"
8 psub "github.com/picosh/pubsub"
9 gossh "golang.org/x/crypto/ssh"
10 )
11@@ -134,21 +134,20 @@ func checkAccess(accessList []string, userName string, sesh ssh.Session) bool {
12 return false
13 }
14
15-func WishMiddleware(handler *CliHandler) wish.Middleware {
16+func WishMiddleware(handler *CliHandler) pssh.SSHServerMiddleware {
17 pubsub := handler.PubSub
18
19- return func(next ssh.Handler) ssh.Handler {
20- return func(sesh ssh.Session) {
21+ return func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
22+ return func(sesh *pssh.SSHServerConnSession) error {
23 ctx := sesh.Context()
24- logger := wsh.GetLogger(sesh)
25- user := wsh.GetUser(sesh)
26+ logger := pssh.GetLogger(sesh)
27+ user := pssh.GetUser(sesh)
28
29 args := sesh.Command()
30
31 if len(args) == 0 {
32- wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
33- next(sesh)
34- return
35+ fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
36+ return next(sesh)
37 }
38
39 userName := "public"
40@@ -188,13 +187,16 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
41
42 cmd := strings.TrimSpace(args[0])
43 if cmd == "help" {
44- wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
45- next(sesh)
46- return
47+ fmt.Fprintln(sesh, helpStr(toSshCmd(handler.Cfg)))
48+ return next(sesh)
49 } else if cmd == "ls" {
50 if userName == "public" {
51- wish.Fatalln(sesh, "access denied")
52- return
53+ err := fmt.Errorf("access denied")
54+ fmt.Fprintln(sesh.Stderr(), err)
55+ fmt.Fprintf(sesh.Stderr(), "\r")
56+ _ = sesh.Exit(1)
57+ _ = sesh.Close()
58+ return err
59 }
60
61 topicFilter := fmt.Sprintf("%s/", userName)
62@@ -221,7 +223,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
63 }
64
65 if len(channels) == 0 && len(waitingChannels) == 0 {
66- wish.Println(sesh, "no pubsub channels found")
67+ fmt.Fprintln(sesh, "no pubsub channels found")
68 } else {
69 var outputData string
70 if len(channels) > 0 || len(waitingChannels) > 0 {
71@@ -415,7 +417,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
72 }
73
74 if !*clean {
75- wish.Println(sesh, termMsg)
76+ fmt.Fprintln(sesh, termMsg)
77 }
78
79 ready := make(chan struct{})
80@@ -486,7 +488,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
81 }
82
83 if !*clean {
84- wish.Println(sesh, "sending msg ...")
85+ fmt.Fprintln(sesh, "sending msg ...")
86 }
87
88 err = pubsub.Pub(
89@@ -500,7 +502,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
90 )
91
92 if !*clean {
93- wish.Println(sesh, "msg sent!")
94+ fmt.Fprintln(sesh, "msg sent!")
95 }
96
97 if err != nil && !*clean {
R wish/logger.go =>
pssh/logger.go
+19,
-14
1@@ -1,13 +1,10 @@
2-package wish
3+package pssh
4
5 import (
6 "log/slog"
7 "time"
8
9- "github.com/charmbracelet/ssh"
10- "github.com/charmbracelet/wish"
11 "github.com/picosh/pico/db"
12- "github.com/picosh/pico/shared"
13 )
14
15 type ctxLoggerKey struct{}
16@@ -17,27 +14,33 @@ type FindUserInterface interface {
17 FindUserByPubkey(string) (*db.User, error)
18 }
19
20-func LogMiddleware(defaultLogger *slog.Logger, db FindUserInterface) wish.Middleware {
21- return func(sh ssh.Handler) ssh.Handler {
22- return func(s ssh.Session) {
23+type GetLoggerInterface interface {
24+ GetLogger(s *SSHServerConnSession) *slog.Logger
25+}
26+
27+func LogMiddleware(getLogger GetLoggerInterface, db FindUserInterface) SSHServerMiddleware {
28+ return func(sshHandler SSHServerHandler) SSHServerHandler {
29+ return func(s *SSHServerConnSession) error {
30 ct := time.Now()
31
32 logger := GetLogger(s)
33 if logger == slog.Default() {
34- logger = defaultLogger
35+ logger = getLogger.GetLogger(s)
36
37 user := GetUser(s)
38 if user == nil {
39 user, err := db.FindUserByPubkey(s.Permissions().Extensions["pubkey"])
40 if err == nil && user != nil {
41- logger = shared.LoggerWithUser(logger, user).With(
42+ logger = logger.With(
43+ "user", user.Name,
44+ "userId", user.ID,
45 "ip", s.RemoteAddr().String(),
46 )
47- s.Context().SetValue(ctxUserKey{}, user)
48+ s.SetValue(ctxUserKey{}, user)
49 }
50 }
51
52- s.Context().SetValue(ctxLoggerKey{}, logger)
53+ s.SetValue(ctxLoggerKey{}, logger)
54 }
55
56 pty, _, ok := s.Pty()
57@@ -51,7 +54,7 @@ func LogMiddleware(defaultLogger *slog.Logger, db FindUserInterface) wish.Middle
58 "windowHeight", pty.Window.Height,
59 )
60
61- sh(s)
62+ sshHandler(s)
63
64 logger.Info(
65 "disconnect",
66@@ -62,11 +65,13 @@ func LogMiddleware(defaultLogger *slog.Logger, db FindUserInterface) wish.Middle
67 "windowHeight", pty.Window.Height,
68 "duration", time.Since(ct),
69 )
70+
71+ return nil
72 }
73 }
74 }
75
76-func GetLogger(s ssh.Session) *slog.Logger {
77+func GetLogger(s *SSHServerConnSession) *slog.Logger {
78 logger := slog.Default()
79 if s == nil {
80 return logger
81@@ -79,7 +84,7 @@ func GetLogger(s ssh.Session) *slog.Logger {
82 return logger
83 }
84
85-func GetUser(s ssh.Session) *db.User {
86+func GetUser(s *SSHServerConnSession) *db.User {
87 if v, ok := s.Context().Value(ctxUserKey{}).(*db.User); ok {
88 return v
89 }
+35,
-0
1@@ -0,0 +1,35 @@
2+package pssh
3+
4+import (
5+ "fmt"
6+)
7+
8+func SessionMessage(sesh *SSHServerConnSession, msg string) {
9+ _, _ = sesh.Write([]byte(msg + "\r\n"))
10+}
11+
12+func DeprecatedNotice() SSHServerMiddleware {
13+ return func(next SSHServerHandler) SSHServerHandler {
14+ return func(sesh *SSHServerConnSession) error {
15+ msg := fmt.Sprintf(
16+ "%s\n\nRun %s to access pico's TUI",
17+ "DEPRECATED",
18+ "ssh pico.sh",
19+ )
20+ SessionMessage(sesh, msg)
21+ return next(sesh)
22+ }
23+ }
24+}
25+
26+func PtyMdw(mdw SSHServerMiddleware) SSHServerMiddleware {
27+ return func(next SSHServerHandler) SSHServerHandler {
28+ return func(sesh *SSHServerConnSession) error {
29+ _, _, ok := sesh.Pty()
30+ if !ok {
31+ return next(sesh)
32+ }
33+ return mdw(next)(sesh)
34+ }
35+ }
36+}
R shared/sshServer.go =>
pssh/sshServer.go
+64,
-3
1@@ -1,4 +1,4 @@
2-package shared
3+package pssh
4
5 import (
6 "context"
7@@ -72,7 +72,67 @@ func (s *SSHServerConn) SetValue(key any, data any) {
8 s.Ctx = context.WithValue(s.Ctx, key, data)
9 }
10
11-var _ context.Context = &SSHServerConn{}
12+func (s *SSHServerConn) Context() context.Context {
13+ s.mu.Lock()
14+ defer s.mu.Unlock()
15+
16+ return s.Ctx
17+}
18+
19+func (s *SSHServerConnSession) Permissions() *ssh.Permissions {
20+ return s.Conn.Permissions
21+}
22+
23+func (s *SSHServerConnSession) User() string {
24+ return s.Conn.User()
25+}
26+
27+func (s *SSHServerConnSession) PublicKey() ssh.PublicKey {
28+ key, ok := s.Conn.Permissions.Extensions["pubkey"]
29+ if !ok {
30+ return nil
31+ }
32+
33+ pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
34+ if err != nil {
35+ return nil
36+ }
37+ return pk
38+}
39+
40+func (s *SSHServerConnSession) RemoteAddr() net.Addr {
41+ return s.Conn.RemoteAddr()
42+}
43+
44+func (s *SSHServerConnSession) Command() []string {
45+ cmd, _ := s.Value("command").([]string)
46+ return cmd
47+}
48+
49+func (s *SSHServerConnSession) Close() error {
50+ return s.Channel.Close()
51+}
52+
53+func (s *SSHServerConnSession) Exit(code int) error {
54+ _, err := s.Channel.SendRequest("exit-status", false, ssh.Marshal(struct{ C int }{code}))
55+ return err
56+}
57+
58+type Window struct {
59+ Width int
60+ Height int
61+}
62+
63+type Pty struct {
64+ Term string
65+ Window Window
66+}
67+
68+func (s *SSHServerConnSession) Pty() (Pty, <-chan Window, bool) {
69+ return Pty{}, nil, false
70+}
71+
72+var _ context.Context = &SSHServerConnSession{}
73
74 func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
75 defer sc.Close()
76@@ -146,7 +206,8 @@ func NewSSHServerConn(
77 }
78 }
79
80-type SSHServerMiddleware func(func(*SSHServerConnSession) error) func(*SSHServerConnSession) error
81+type SSHServerHandler func(*SSHServerConnSession) error
82+type SSHServerMiddleware func(SSHServerHandler) SSHServerHandler
83
84 type SSHServerConfig struct {
85 *ssh.ServerConfig
R shared/sshServer_test.go =>
pssh/sshServer_test.go
+14,
-14
1@@ -1,4 +1,4 @@
2-package shared_test
3+package pssh_test
4
5 import (
6 "context"
7@@ -8,14 +8,14 @@ import (
8 "testing"
9 "time"
10
11- "github.com/picosh/pico/shared"
12+ "github.com/picosh/pico/pssh"
13 "golang.org/x/crypto/ssh"
14 )
15
16 func TestNewSSHServer(t *testing.T) {
17 ctx := context.Background()
18 logger := slog.Default()
19- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
20+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
21
22 if server == nil {
23 t.Fatal("expected non-nil server")
24@@ -41,10 +41,10 @@ func TestNewSSHServer(t *testing.T) {
25 func TestNewSSHServerConn(t *testing.T) {
26 ctx := context.Background()
27 logger := slog.Default()
28- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
29+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
30 conn := &ssh.ServerConn{}
31
32- serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
33+ serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
34
35 if serverConn == nil {
36 t.Fatal("expected non-nil server connection")
37@@ -70,10 +70,10 @@ func TestNewSSHServerConn(t *testing.T) {
38 func TestSSHServerConnClose(t *testing.T) {
39 ctx := context.Background()
40 logger := slog.Default()
41- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
42+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
43 conn := &ssh.ServerConn{}
44
45- serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
46+ serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
47 err := serverConn.Close()
48
49 if err != nil {
50@@ -92,7 +92,7 @@ func TestSSHServerConnClose(t *testing.T) {
51 func TestSSHServerClose(t *testing.T) {
52 ctx := context.Background()
53 logger := slog.Default()
54- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
55+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
56
57 // Create a mock listener to test Close()
58 listener, err := net.Listen("tcp", "127.0.0.1:0")
59@@ -120,7 +120,7 @@ func TestSSHServerNilParams(t *testing.T) {
60 // Test with nil context and logger
61 //nolint:staticcheck // SA1012 ignores nil check
62 //lint:ignore SA1012 ignores nil check
63- server := shared.NewSSHServer(nil, nil, nil)
64+ server := pssh.NewSSHServer(nil, nil, nil)
65
66 if server == nil {
67 t.Fatal("expected non-nil server")
68@@ -137,7 +137,7 @@ func TestSSHServerNilParams(t *testing.T) {
69 // Test with nil context and logger for connection
70 //nolint:staticcheck // SA1012 ignores nil check
71 //lint:ignore SA1012 ignores nil check
72- conn := shared.NewSSHServerConn(nil, nil, &ssh.ServerConn{}, server)
73+ conn := pssh.NewSSHServerConn(nil, nil, &ssh.ServerConn{}, server)
74
75 if conn == nil {
76 t.Fatal("expected non-nil server connection")
77@@ -156,7 +156,7 @@ func TestSSHServerHandleConn(t *testing.T) {
78 ctx, cancel := context.WithCancel(context.Background())
79 defer cancel()
80 logger := slog.Default()
81- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
82+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
83
84 // Setup a basic SSH server config
85 config := &ssh.ServerConfig{
86@@ -206,7 +206,7 @@ func TestSSHServerListenAndServe(t *testing.T) {
87 ctx, cancel := context.WithCancel(context.Background())
88 defer cancel()
89 logger := slog.Default()
90- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
91+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
92
93 config := &ssh.ServerConfig{
94 NoClientAuth: true,
95@@ -245,10 +245,10 @@ func TestSSHServerConnHandle(t *testing.T) {
96 ctx, cancel := context.WithCancel(context.Background())
97 defer cancel()
98 logger := slog.Default()
99- server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
100+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
101 conn := &ssh.ServerConn{}
102
103- serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
104+ serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
105
106 // Create channels for testing
107 chans := make(chan ssh.NewChannel)