repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / send / protocols / sftp
Antonio Mika  ·  2025-03-12

handler.go

  1package sftp
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"io/fs"
  9	"os"
 10	"path/filepath"
 11
 12	"slices"
 13
 14	"github.com/picosh/pico/pkg/pssh"
 15	"github.com/picosh/pico/pkg/send/utils"
 16	"github.com/pkg/sftp"
 17)
 18
 19type listerat []os.FileInfo
 20
 21func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
 22	var n int
 23	if offset >= int64(len(f)) {
 24		return 0, io.EOF
 25	}
 26	n = copy(ls, f[offset:])
 27	if n < len(ls) {
 28		return n, io.EOF
 29	}
 30	return n, nil
 31}
 32
 33type handler struct {
 34	session      *pssh.SSHServerConnSession
 35	writeHandler utils.CopyFromClientHandler
 36}
 37
 38func (f *handler) Filecmd(r *sftp.Request) error {
 39	switch r.Method {
 40	case "Rmdir", "Remove":
 41		entry := toFileEntry(r)
 42
 43		if r.Method == "Rmdir" {
 44			entry.Mode = os.ModeDir
 45		}
 46
 47		return f.writeHandler.Delete(f.session, entry)
 48	case "Mkdir":
 49		entry := toFileEntry(r)
 50
 51		entry.Mode = os.ModeDir
 52
 53		_, err := f.writeHandler.Write(f.session, entry)
 54
 55		return err
 56	case "Setstat":
 57		return nil
 58	}
 59	return errors.New("unsupported")
 60}
 61
 62func (f *handler) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
 63	switch r.Method {
 64	case "List", "Stat":
 65		list := r.Method == "List"
 66
 67		listData, err := f.writeHandler.List(f.session, r.Filepath, list, false)
 68		if err != nil {
 69			return nil, err
 70		}
 71
 72		// an empty string from minio or exact match from filepath base name is what we want
 73
 74		if !list {
 75			listData = slices.DeleteFunc(listData, func(f os.FileInfo) bool {
 76				return !(f.Name() == "" || f.Name() == filepath.Base(r.Filepath))
 77			})
 78		}
 79
 80		if r.Filepath == "/" {
 81			listData = slices.DeleteFunc(listData, func(f os.FileInfo) bool {
 82				return f.Name() == "/"
 83			})
 84			listData = slices.Insert(listData, 0, os.FileInfo(&utils.VirtualFile{
 85				FName:  ".",
 86				FIsDir: true,
 87			}))
 88		}
 89
 90		return listerat(listData), nil
 91	}
 92
 93	return nil, errors.New("unsupported")
 94}
 95
 96func toFileEntry(r *sftp.Request) *utils.FileEntry {
 97	attrs := r.Attributes()
 98	var size int64 = 0
 99	var mtime int64 = 0
100	var atime int64 = 0
101	var mode fs.FileMode
102	if attrs != nil {
103		mode = attrs.FileMode()
104		size = int64(attrs.Size)
105		mtime = int64(attrs.Mtime)
106		atime = int64(attrs.Atime)
107	}
108
109	entry := &utils.FileEntry{
110		Filepath: r.Filepath,
111		Mode:     mode,
112		Size:     size,
113		Mtime:    mtime,
114		Atime:    atime,
115	}
116	return entry
117}
118
119func (f *handler) Filewrite(r *sftp.Request) (io.WriterAt, error) {
120	entry := toFileEntry(r)
121	entry.Reader = bytes.NewReader([]byte{})
122
123	_, err := f.writeHandler.Write(f.session, entry)
124	if err != nil {
125		return nil, err
126	}
127
128	buf := &buffer{}
129	entry.Reader = buf
130
131	return fakeWrite{fileEntry: entry, buf: buf, handler: f}, nil
132}
133
134func (f *handler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
135	if r.Filepath == "/" {
136		return nil, os.ErrInvalid
137	}
138
139	fileEntry := toFileEntry(r)
140	_, reader, err := f.writeHandler.Read(f.session, fileEntry)
141
142	return reader, err
143}
144
145type handlererr struct {
146	Handler *handler
147}
148
149func (f *handlererr) Filecmd(r *sftp.Request) error {
150	err := f.Handler.Filecmd(r)
151	if err != nil {
152		fmt.Fprintln(f.Handler.session.Stderr(), err)
153	}
154	return err
155}
156func (f *handlererr) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
157	result, err := f.Handler.Filelist(r)
158	if err != nil {
159		fmt.Fprintln(f.Handler.session.Stderr(), err)
160	}
161	return result, err
162}
163func (f *handlererr) Filewrite(r *sftp.Request) (io.WriterAt, error) {
164	result, err := f.Handler.Filewrite(r)
165	if err != nil {
166		fmt.Fprintln(f.Handler.session.Stderr(), err)
167	}
168	return result, err
169}
170func (f *handlererr) Fileread(r *sftp.Request) (io.ReaderAt, error) {
171	result, err := f.Handler.Fileread(r)
172	if err != nil {
173		fmt.Fprintln(f.Handler.session.Stderr(), err)
174	}
175	return result, err
176}