Antonio Mika
·
2025-03-12
rsync.go
1package rsync
2
3import (
4 "errors"
5 "fmt"
6 "io/fs"
7 "os"
8 "path"
9 "slices"
10 "strings"
11
12 "github.com/picosh/go-rsync-receiver/rsyncopts"
13 "github.com/picosh/go-rsync-receiver/rsyncreceiver"
14 "github.com/picosh/go-rsync-receiver/rsyncsender"
15 rsyncutils "github.com/picosh/go-rsync-receiver/utils"
16 "github.com/picosh/pico/pkg/pssh"
17 "github.com/picosh/pico/pkg/send/utils"
18)
19
20type handler struct {
21 session *pssh.SSHServerConnSession
22 writeHandler utils.CopyFromClientHandler
23 root string
24 recursive bool
25 ignoreTimes bool
26}
27
28func (h *handler) List(rPath string) ([]fs.FileInfo, error) {
29 isDir := false
30 if rPath == "." {
31 rPath = "/"
32 isDir = true
33 }
34
35 list, err := h.writeHandler.List(h.session, rPath, isDir, h.recursive)
36 if err != nil {
37 return nil, err
38 }
39
40 var dirs []string
41
42 var newList []fs.FileInfo
43
44 for _, f := range list {
45 if !f.IsDir() && f.Size() == 0 {
46 continue
47 }
48
49 fname := f.Name()
50 if strings.HasPrefix(f.Name(), "/") {
51 fname = path.Join(rPath, f.Name())
52 }
53
54 if fname == "" && !f.IsDir() {
55 fname = path.Base(rPath)
56 }
57
58 newFile := &utils.VirtualFile{
59 FName: fname,
60 FIsDir: f.IsDir(),
61 FSize: f.Size(),
62 FModTime: f.ModTime(),
63 FSys: f.Sys(),
64 }
65
66 newList = append(newList, newFile)
67
68 parts := strings.Split(newFile.Name(), string(os.PathSeparator))
69 lastDir := newFile.Name()
70 for i := 0; i < len(parts); i++ {
71 lastDir, _ = path.Split(lastDir)
72 if lastDir == "" {
73 continue
74 }
75
76 lastDir = lastDir[:len(lastDir)-1]
77 dirs = append(dirs, lastDir)
78 }
79 }
80
81 for _, dir := range dirs {
82 newList = append(newList, &utils.VirtualFile{
83 FName: dir,
84 FIsDir: true,
85 })
86 }
87
88 slices.Reverse(newList)
89
90 onlyEmpty := true
91 for _, f := range newList {
92 if f.Name() != "" {
93 onlyEmpty = false
94 }
95 }
96
97 if len(newList) == 0 || onlyEmpty {
98 return nil, errors.New("no files to send, the directory may not exist or could be empty")
99 }
100
101 return newList, nil
102}
103
104func (h *handler) Read(file *rsyncutils.SenderFile) (os.FileInfo, rsyncutils.ReaderAtCloser, error) {
105 filePath := file.WPath
106
107 if strings.HasSuffix(h.root, file.WPath) {
108 filePath = h.root
109 } else if !strings.HasPrefix(filePath, h.root) {
110 filePath = path.Join(h.root, file.Path, file.WPath)
111 }
112
113 return h.writeHandler.Read(h.session, &utils.FileEntry{Filepath: filePath})
114}
115
116func (h *handler) Put(file *rsyncutils.ReceiverFile) (int64, error) {
117 fileEntry := &utils.FileEntry{
118 Filepath: path.Join("/", h.root, file.Name),
119 Mode: fs.FileMode(0600),
120 Size: file.Length,
121 Mtime: file.ModTime.Unix(),
122 Atime: file.ModTime.Unix(),
123 }
124 fileEntry.Reader = file.Reader
125
126 msg, err := h.writeHandler.Write(h.session, fileEntry)
127 if err != nil {
128 errMsg := fmt.Sprintf("%s\r\n", err.Error())
129 _, err = h.session.Stderr().Write([]byte(errMsg))
130 }
131 if msg != "" {
132 nMsg := fmt.Sprintf("%s\r\n", msg)
133 _, err = h.session.Stderr().Write([]byte(nMsg))
134 }
135 return 0, err
136}
137
138func (h *handler) Remove(willReceive []*rsyncutils.ReceiverFile) error {
139 entries, err := h.writeHandler.List(h.session, path.Join("/", h.root), true, true)
140 if err != nil {
141 return err
142 }
143
144 var toDelete []string
145
146 for _, entry := range entries {
147 exists := slices.ContainsFunc(willReceive, func(rf *rsyncutils.ReceiverFile) bool {
148 return rf.Name == entry.Name()
149 })
150
151 if !exists && entry.Name() != "._pico_keep_dir" {
152 toDelete = append(toDelete, entry.Name())
153 }
154 }
155
156 var errs []error
157
158 for _, file := range toDelete {
159 errs = append(errs, h.writeHandler.Delete(h.session, &utils.FileEntry{Filepath: path.Join("/", h.root, file)}))
160 _, err = h.session.Stderr().Write([]byte(fmt.Sprintf("deleting %s\r\n", file)))
161 errs = append(errs, err)
162 }
163
164 return errors.Join(errs...)
165}
166
167func Middleware(writeHandler utils.CopyFromClientHandler) pssh.SSHServerMiddleware {
168 return func(sshHandler pssh.SSHServerHandler) pssh.SSHServerHandler {
169 return func(session *pssh.SSHServerConnSession) error {
170 cmd := session.Command()
171 if len(cmd) == 0 || cmd[0] != "rsync" {
172 return sshHandler(session)
173 }
174
175 logger := writeHandler.GetLogger(session).With(
176 "rsync", true,
177 "cmd", cmd,
178 )
179
180 defer func() {
181 if r := recover(); r != nil {
182 logger.Error("error running rsync middleware", "err", r)
183 _, _ = session.Stderr().Write([]byte("error running rsync middleware, check the flags you are using\r\n"))
184 }
185 }()
186
187 cmdFlags := session.Command()
188
189 optsCtx, err := rsyncopts.ParseArguments(cmdFlags[1:], true)
190 if err != nil {
191 fmt.Fprintf(session.Stderr(), "error parsing rsync arguments: %s\r\n", err.Error())
192 return err
193 }
194
195 if optsCtx.Options.Compress() {
196 err := fmt.Errorf("compression is currently unsupported")
197 fmt.Fprintf(session.Stderr(), "error: %s\r\n", err.Error())
198 return err
199 }
200
201 if optsCtx.Options.AlwaysChecksum() {
202 err := fmt.Errorf("checksum is currently unsupported")
203 fmt.Fprintf(session.Stderr(), "error: %s\r\n", err.Error())
204 return err
205 }
206
207 if len(optsCtx.RemainingArgs) != 2 {
208 err := fmt.Errorf("missing source and destination arguments")
209 fmt.Fprintf(session.Stderr(), "error: %s\r\n", err.Error())
210 return err
211 }
212
213 root := strings.TrimPrefix(optsCtx.RemainingArgs[len(optsCtx.RemainingArgs)-1], "/")
214 if root == "" {
215 root = "/"
216 }
217
218 fileHandler := &handler{
219 session: session,
220 writeHandler: writeHandler,
221 root: root,
222 recursive: optsCtx.Options.Recurse(),
223 ignoreTimes: !optsCtx.Options.PreserveMTimes(),
224 }
225
226 for _, arg := range cmd {
227 if arg == "--sender" {
228 err := rsyncsender.ClientRun(logger, optsCtx.Options, session, fileHandler, []string{fileHandler.root}, true)
229 if err != nil {
230 logger.Error("error running rsync sender", "err", err)
231 }
232 return err
233 }
234 }
235
236 err = rsyncreceiver.ClientRun(logger, optsCtx.Options, session, fileHandler, []string{fileHandler.root}, true)
237 if err != nil {
238 logger.Error("error running rsync receiver", "err", err)
239 }
240
241 return err
242 }
243 }
244}