Eric Bower
·
2025-05-25
rsync.go
1package rsync
2
3import (
4 "errors"
5 "fmt"
6 "io/fs"
7 "os"
8 "path"
9 "slices"
10 "strings"
11
12 "github.com/picosh/go-rsync-receiver/rsyncopts"
13 "github.com/picosh/go-rsync-receiver/rsyncreceiver"
14 "github.com/picosh/go-rsync-receiver/rsyncsender"
15 rsyncutils "github.com/picosh/go-rsync-receiver/utils"
16 "github.com/picosh/pico/pkg/pssh"
17 "github.com/picosh/pico/pkg/send/utils"
18)
19
20type handler struct {
21 session *pssh.SSHServerConnSession
22 writeHandler utils.CopyFromClientHandler
23 root string
24 recursive bool
25 ignoreTimes bool
26}
27
28func (h *handler) List(rPath string) ([]fs.FileInfo, error) {
29 isDir := false
30 if rPath == "." {
31 rPath = "/"
32 isDir = true
33 }
34
35 list, err := h.writeHandler.List(h.session, rPath, isDir, h.recursive)
36 if err != nil {
37 return nil, err
38 }
39
40 var dirs []string
41
42 var newList []fs.FileInfo
43
44 for _, f := range list {
45 if !f.IsDir() && f.Size() == 0 {
46 continue
47 }
48
49 fname := f.Name()
50 if strings.HasPrefix(f.Name(), "/") {
51 fname = path.Join(rPath, f.Name())
52 }
53
54 if fname == "" && !f.IsDir() {
55 fname = path.Base(rPath)
56 }
57
58 newFile := &utils.VirtualFile{
59 FName: fname,
60 FIsDir: f.IsDir(),
61 FSize: f.Size(),
62 FModTime: f.ModTime(),
63 FSys: f.Sys(),
64 }
65
66 newList = append(newList, newFile)
67
68 parts := strings.Split(newFile.Name(), string(os.PathSeparator))
69 lastDir := newFile.Name()
70 for i := 0; i < len(parts); i++ {
71 lastDir, _ = path.Split(lastDir)
72 if lastDir == "" {
73 continue
74 }
75
76 lastDir = lastDir[:len(lastDir)-1]
77 dirs = append(dirs, lastDir)
78 }
79 }
80
81 for _, dir := range dirs {
82 newList = append(newList, &utils.VirtualFile{
83 FName: dir,
84 FIsDir: true,
85 })
86 }
87
88 slices.Reverse(newList)
89
90 onlyEmpty := true
91 for _, f := range newList {
92 if f.Name() != "" {
93 onlyEmpty = false
94 }
95 }
96
97 if len(newList) == 0 || onlyEmpty {
98 return nil, errors.New("no files to send, the directory may not exist or could be empty")
99 }
100
101 return newList, nil
102}
103
104func (h *handler) Read(file *rsyncutils.SenderFile) (os.FileInfo, rsyncutils.ReaderAtCloser, error) {
105 filePath := file.WPath
106
107 if strings.HasSuffix(h.root, file.WPath) {
108 filePath = h.root
109 } else if !strings.HasPrefix(filePath, h.root) {
110 filePath = path.Join(h.root, file.Path, file.WPath)
111 }
112
113 return h.writeHandler.Read(h.session, &utils.FileEntry{Filepath: filePath})
114}
115
116func (h *handler) Put(file *rsyncutils.ReceiverFile) (int64, error) {
117 fileEntry := &utils.FileEntry{
118 Filepath: path.Join("/", h.root, file.Name),
119 Mode: fs.FileMode(0600),
120 Size: file.Length,
121 Mtime: file.ModTime.Unix(),
122 Atime: file.ModTime.Unix(),
123 }
124 fileEntry.Reader = file.Reader
125
126 msg, err := h.writeHandler.Write(h.session, fileEntry)
127 if err != nil {
128 errMsg := fmt.Sprintf("%s\r\n", err.Error())
129 _, err = h.session.Stderr().Write([]byte(errMsg))
130 }
131 if msg != "" {
132 nMsg := fmt.Sprintf("%s\r\n", msg)
133 _, err = h.session.Stderr().Write([]byte(nMsg))
134 }
135 return 0, err
136}
137
138func (h *handler) Remove(willReceive []*rsyncutils.ReceiverFile) error {
139 entries, err := h.writeHandler.List(h.session, path.Join("/", h.root), true, true)
140 if err != nil {
141 return err
142 }
143
144 var toDelete []string
145
146 for _, entry := range entries {
147 exists := slices.ContainsFunc(willReceive, func(rf *rsyncutils.ReceiverFile) bool {
148 return rf.Name == entry.Name()
149 })
150
151 if !exists && entry.Name() != "._pico_keep_dir" {
152 toDelete = append(toDelete, entry.Name())
153 }
154 }
155
156 var errs []error
157
158 for _, file := range toDelete {
159 errs = append(errs, h.writeHandler.Delete(h.session, &utils.FileEntry{Filepath: path.Join("/", h.root, file)}))
160 _, err = fmt.Fprintf(h.session.Stderr(), "deleting %s\r\n", file)
161 errs = append(errs, err)
162 }
163
164 return errors.Join(errs...)
165}
166
167func Middleware(writeHandler utils.CopyFromClientHandler) pssh.SSHServerMiddleware {
168 return func(sshHandler pssh.SSHServerHandler) pssh.SSHServerHandler {
169 return func(session *pssh.SSHServerConnSession) error {
170 cmd := session.Command()
171 if len(cmd) == 0 || cmd[0] != "rsync" {
172 return sshHandler(session)
173 }
174
175 logger := writeHandler.GetLogger(session).With(
176 "rsync", true,
177 "cmd", cmd,
178 )
179
180 defer func() {
181 if r := recover(); r != nil {
182 logger.Error("error running rsync middleware", "err", r)
183 _, _ = session.Stderr().Write([]byte("error running rsync middleware, check the flags you are using\r\n"))
184 }
185 }()
186
187 cmdFlags := session.Command()
188 flgs := cmdFlags[1:]
189 for idx, f := range flgs {
190 // openrsync sends "delete-before" when the client provided "delete"
191 flgs[idx] = strings.ReplaceAll(f, "delete-before", "delete")
192 }
193
194 optsCtx, err := rsyncopts.ParseArguments(cmdFlags[1:], true)
195 if err != nil {
196 _, _ = fmt.Fprintf(session.Stderr(), "error parsing rsync arguments: %s\r\n", err.Error())
197 return err
198 }
199
200 if optsCtx.Options.Compress() {
201 err := fmt.Errorf("compression is currently unsupported")
202 _, _ = fmt.Fprintf(session.Stderr(), "error: %s\r\n", err.Error())
203 return err
204 }
205
206 if optsCtx.Options.AlwaysChecksum() {
207 err := fmt.Errorf("checksum is currently unsupported")
208 _, _ = fmt.Fprintf(session.Stderr(), "error: %s\r\n", err.Error())
209 return err
210 }
211
212 if len(optsCtx.RemainingArgs) != 2 {
213 err := fmt.Errorf("missing source and destination arguments")
214 _, _ = fmt.Fprintf(session.Stderr(), "error: %s\r\n", err.Error())
215 return err
216 }
217
218 root := strings.TrimPrefix(optsCtx.RemainingArgs[len(optsCtx.RemainingArgs)-1], "/")
219 if root == "" {
220 root = "/"
221 }
222
223 fileHandler := &handler{
224 session: session,
225 writeHandler: writeHandler,
226 root: root,
227 recursive: optsCtx.Options.Recurse(),
228 ignoreTimes: !optsCtx.Options.PreserveMTimes(),
229 }
230
231 for _, arg := range cmd {
232 if arg == "--sender" {
233 err := rsyncsender.ClientRun(logger, optsCtx.Options, session, fileHandler, []string{fileHandler.root}, true)
234 if err != nil {
235 logger.Error("error running rsync sender", "err", err)
236 }
237 return err
238 }
239 }
240
241 err = rsyncreceiver.ClientRun(logger, optsCtx.Options, session, fileHandler, []string{fileHandler.root}, true)
242 if err != nil {
243 logger.Error("error running rsync receiver", "err", err)
244 }
245
246 return err
247 }
248 }
249}