repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / send / protocols / sftp
Eric Bower  ·  2026-01-16

handler.go

  1package sftp
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"io/fs"
  9	"os"
 10	"path/filepath"
 11
 12	"slices"
 13
 14	"github.com/picosh/pico/pkg/pssh"
 15	"github.com/picosh/pico/pkg/send/utils"
 16	"github.com/pkg/sftp"
 17)
 18
 19type listerat []os.FileInfo
 20
 21func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
 22	var n int
 23	if offset >= int64(len(f)) {
 24		return 0, io.EOF
 25	}
 26	n = copy(ls, f[offset:])
 27	if n < len(ls) {
 28		return n, io.EOF
 29	}
 30	return n, nil
 31}
 32
// handler implements the pkg/sftp request-server callbacks (Filecmd,
// Filelist, Filewrite, Fileread) by delegating each operation to a
// utils.CopyFromClientHandler on behalf of one SSH session.
type handler struct {
	session      *pssh.SSHServerConnSession  // passed to every writeHandler call; handlererr also prints errors to its Stderr
	writeHandler utils.CopyFromClientHandler // performs the actual list/read/write/delete work
}
 37
 38func (f *handler) Filecmd(r *sftp.Request) error {
 39	switch r.Method {
 40	case "Rmdir", "Remove":
 41		entry := toFileEntry(r)
 42
 43		if r.Method == "Rmdir" {
 44			entry.Mode = os.ModeDir
 45		}
 46
 47		return f.writeHandler.Delete(f.session, entry)
 48	case "Mkdir":
 49		entry := toFileEntry(r)
 50
 51		entry.Mode = os.ModeDir
 52
 53		_, err := f.writeHandler.Write(f.session, entry)
 54
 55		return err
 56	case "Setstat":
 57		return nil
 58	}
 59	return errors.New("unsupported")
 60}
 61
 62func (f *handler) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
 63	switch r.Method {
 64	case "List", "Stat":
 65		list := r.Method == "List"
 66
 67		listData, err := f.writeHandler.List(f.session, r.Filepath, list, false)
 68		if err != nil {
 69			return nil, err
 70		}
 71
 72		// an empty string or exact match from filepath base name is what we want
 73		if !list {
 74			listData = slices.DeleteFunc(listData, func(f os.FileInfo) bool {
 75				return f.Name() != "" && f.Name() != filepath.Base(r.Filepath)
 76			})
 77		}
 78
 79		if r.Filepath == "/" {
 80			listData = slices.DeleteFunc(listData, func(f os.FileInfo) bool {
 81				return f.Name() == "/"
 82			})
 83			listData = slices.Insert(listData, 0, os.FileInfo(&utils.VirtualFile{
 84				FName:  ".",
 85				FIsDir: true,
 86			}))
 87		}
 88
 89		return listerat(listData), nil
 90	}
 91
 92	return nil, errors.New("unsupported")
 93}
 94
 95func toFileEntry(r *sftp.Request) *utils.FileEntry {
 96	attrs := r.Attributes()
 97	var size int64 = 0
 98	var mtime int64 = 0
 99	var atime int64 = 0
100	var mode fs.FileMode
101	if attrs != nil {
102		mode = attrs.FileMode()
103		size = int64(attrs.Size)
104		mtime = int64(attrs.Mtime)
105		atime = int64(attrs.Atime)
106	}
107
108	entry := &utils.FileEntry{
109		Filepath: r.Filepath,
110		Mode:     mode,
111		Size:     size,
112		Mtime:    mtime,
113		Atime:    atime,
114	}
115	return entry
116}
117
118func (f *handler) Filewrite(r *sftp.Request) (io.WriterAt, error) {
119	entry := toFileEntry(r)
120	entry.Reader = bytes.NewReader([]byte{})
121
122	_, err := f.writeHandler.Write(f.session, entry)
123	if err != nil {
124		return nil, err
125	}
126
127	buf := &buffer{}
128	entry.Reader = buf
129
130	return fakeWrite{fileEntry: entry, buf: buf, handler: f}, nil
131}
132
133func (f *handler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
134	if r.Filepath == "/" {
135		return nil, os.ErrInvalid
136	}
137
138	fileEntry := toFileEntry(r)
139	_, reader, err := f.writeHandler.Read(f.session, fileEntry)
140
141	return reader, err
142}
143
144type handlererr struct {
145	Handler *handler
146}
147
148func (f *handlererr) Filecmd(r *sftp.Request) error {
149	err := f.Handler.Filecmd(r)
150	if err != nil {
151		_, _ = fmt.Fprintln(f.Handler.session.Stderr(), err)
152	}
153	return err
154}
155func (f *handlererr) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
156	result, err := f.Handler.Filelist(r)
157	if err != nil {
158		_, _ = fmt.Fprintln(f.Handler.session.Stderr(), err)
159	}
160	return result, err
161}
162func (f *handlererr) Filewrite(r *sftp.Request) (io.WriterAt, error) {
163	result, err := f.Handler.Filewrite(r)
164	if err != nil {
165		_, _ = fmt.Fprintln(f.Handler.session.Stderr(), err)
166	}
167	return result, err
168}
169func (f *handlererr) Fileread(r *sftp.Request) (io.ReaderAt, error) {
170	result, err := f.Handler.Fileread(r)
171	if err != nil {
172		_, _ = fmt.Fprintln(f.Handler.session.Stderr(), err)
173	}
174	return result, err
175}