Eric Bower
·
2025-04-04
ssh_test.go
1package pgs
2
3import (
4 "crypto/ed25519"
5 "crypto/rand"
6 "encoding/pem"
7 "fmt"
8 "io"
9 "log/slog"
10 "os"
11 "os/exec"
12 "path/filepath"
13 "strings"
14 "testing"
15 "time"
16
17 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
18 "github.com/picosh/pico/pkg/db"
19 "github.com/picosh/pico/pkg/shared/storage"
20 "github.com/picosh/utils"
21 "github.com/pkg/sftp"
22 "github.com/prometheus/client_golang/prometheus"
23 "golang.org/x/crypto/ssh"
24)
25
26func TestSshServerSftp(t *testing.T) {
27 opts := &slog.HandlerOptions{
28 AddSource: true,
29 Level: slog.LevelDebug,
30 }
31 logger := slog.New(
32 slog.NewTextHandler(os.Stdout, opts),
33 )
34 slog.SetDefault(logger)
35 dbpool := pgsdb.NewDBMemory(logger)
36 // setup test data
37 dbpool.SetupTestData()
38 st, err := storage.NewStorageMemory(map[string]map[string]string{})
39 if err != nil {
40 panic(err)
41 }
42 cfg := NewPgsConfig(logger, dbpool, st)
43 done := make(chan error)
44 prometheus.DefaultRegisterer = prometheus.NewRegistry()
45 go StartSshServer(cfg, done)
46 // Hack to wait for startup
47 time.Sleep(time.Millisecond * 100)
48
49 user := GenerateUser()
50 // add user's pubkey to the default test account
51 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
52 ID: "nice-pubkey",
53 UserID: dbpool.Users[0].ID,
54 Key: utils.KeyForKeyText(user.signer.PublicKey()),
55 })
56
57 client, err := user.NewClient()
58 if err != nil {
59 t.Error(err)
60 return
61 }
62 defer client.Close()
63
64 _, err = WriteFileWithSftp(cfg, client)
65 if err != nil {
66 t.Error(err)
67 return
68 }
69
70 _, err = WriteFilesMultProjectsWithSftp(cfg, client)
71 if err != nil {
72 t.Error(err)
73 return
74 }
75
76 projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
77 if err != nil {
78 t.Error(err)
79 return
80 }
81
82 names := ""
83 for _, proj := range projects {
84 names += "_" + proj.Name
85 }
86
87 if names != "_test_mult_mult2" {
88 t.Errorf("not all projects created: %s", names)
89 return
90 }
91
92 close(done)
93
94 p, err := os.FindProcess(os.Getpid())
95 if err != nil {
96 t.Fatal(err)
97 return
98 }
99
100 err = p.Signal(os.Interrupt)
101 if err != nil {
102 t.Fatal(err)
103 return
104 }
105 <-time.After(10 * time.Millisecond)
106}
107
108func TestSshServerRsync(t *testing.T) {
109 opts := &slog.HandlerOptions{
110 AddSource: true,
111 Level: slog.LevelDebug,
112 }
113 logger := slog.New(
114 slog.NewTextHandler(os.Stdout, opts),
115 )
116 slog.SetDefault(logger)
117 dbpool := pgsdb.NewDBMemory(logger)
118 // setup test data
119 dbpool.SetupTestData()
120 st, err := storage.NewStorageMemory(map[string]map[string]string{})
121 if err != nil {
122 panic(err)
123 }
124 cfg := NewPgsConfig(logger, dbpool, st)
125 done := make(chan error)
126 prometheus.DefaultRegisterer = prometheus.NewRegistry()
127 go StartSshServer(cfg, done)
128 // Hack to wait for startup
129 time.Sleep(time.Millisecond * 100)
130
131 user := GenerateUser()
132 key := utils.KeyForKeyText(user.signer.PublicKey())
133 // add user's pubkey to the default test account
134 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
135 ID: "nice-pubkey",
136 UserID: dbpool.Users[0].ID,
137 Key: key,
138 })
139
140 conn, err := user.NewClient()
141 if err != nil {
142 t.Error(err)
143 return
144 }
145 defer conn.Close()
146
147 // open an SFTP session over an existing ssh connection.
148 client, err := sftp.NewClient(conn)
149 if err != nil {
150 cfg.Logger.Error("could not create sftp client", "err", err)
151 panic(err)
152 }
153 defer client.Close()
154
155 name, err := os.MkdirTemp("", "rsync-")
156 if err != nil {
157 panic(err)
158 }
159
160 // remove the temporary directory at the end of the program
161 // defer os.RemoveAll(name)
162
163 block := &pem.Block{
164 Type: "OPENSSH PRIVATE KEY",
165 Bytes: user.privateKey,
166 }
167 keyFile := filepath.Join(name, "id_ed25519")
168 err = os.WriteFile(
169 keyFile,
170 pem.EncodeToMemory(block), 0600,
171 )
172 if err != nil {
173 t.Fatal(err)
174 }
175
176 index := "<!doctype html><html><body>index</body></html>"
177 err = os.WriteFile(
178 filepath.Join(name, "index.html"),
179 []byte(index), 0666,
180 )
181 if err != nil {
182 t.Fatal(err)
183 }
184
185 about := "<!doctype html><html><body>about</body></html>"
186 aboutFile := filepath.Join(name, "about.html")
187 err = os.WriteFile(
188 aboutFile,
189 []byte(about), 0666,
190 )
191 if err != nil {
192 t.Fatal(err)
193 }
194
195 contact := "<!doctype html><html><body>contact</body></html>"
196 err = os.WriteFile(
197 filepath.Join(name, "contact.html"),
198 []byte(contact), 0666,
199 )
200 if err != nil {
201 t.Fatal(err)
202 }
203
204 eCmd := fmt.Sprintf(
205 "ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
206 keyFile,
207 )
208
209 // copy files
210 cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
211 result, err := cmd.CombinedOutput()
212 if err != nil {
213 cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
214 t.Error(err)
215 return
216 }
217
218 // check it's there
219 fi, err := client.Lstat("/test/about.html")
220 if err != nil {
221 cfg.Logger.Error("could not get stat for file", "err", err)
222 t.Error("about.html not found")
223 return
224 }
225 if fi.Size() != 46 {
226 cfg.Logger.Error("about.html wrong size", "size", fi.Size())
227 t.Error("about.html wrong size")
228 return
229 }
230
231 // remove about file
232 os.Remove(aboutFile)
233
234 // copy files with delete
235 delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
236 result, err = delCmd.CombinedOutput()
237 if err != nil {
238 cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
239 t.Error(err)
240 return
241 }
242
243 // check it's not there
244 _, err = client.Lstat("/test/about.html")
245 if err == nil {
246 cfg.Logger.Error("file still exists")
247 t.Error("about.html found")
248 return
249 }
250
251 dlName, err := os.MkdirTemp("", "rsync-download")
252 if err != nil {
253 panic(err)
254 }
255 defer os.RemoveAll(dlName)
256 // download files
257 downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
258 result, err = downloadCmd.CombinedOutput()
259 if err != nil {
260 cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
261 t.Error(err)
262 return
263 }
264 // check contents
265 idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
266 if err != nil {
267 cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
268 t.Error(err)
269 return
270 }
271 if string(idx) != index {
272 t.Error("downloaded index.html file does not match original")
273 return
274 }
275 _, err = os.ReadFile(filepath.Join(dlName, "about.html"))
276 if err == nil {
277 cfg.Logger.Error("about file should not exist", "file", "about.html")
278 t.Error(err)
279 return
280 }
281 cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
282 if err != nil {
283 cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
284 t.Error(err)
285 return
286 }
287 if string(cnt) != contact {
288 t.Error("downloaded contact.html file does not match original")
289 return
290 }
291
292 close(done)
293
294 p, err := os.FindProcess(os.Getpid())
295 if err != nil {
296 t.Fatal(err)
297 return
298 }
299
300 err = p.Signal(os.Interrupt)
301 if err != nil {
302 t.Fatal(err)
303 return
304 }
305 <-time.After(10 * time.Millisecond)
306}
307
308type UserSSH struct {
309 username string
310 signer ssh.Signer
311 privateKey []byte
312}
313
314func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
315 return &UserSSH{
316 username: username,
317 signer: signer,
318 }
319}
320
321func (s UserSSH) Public() string {
322 pubkey := s.signer.PublicKey()
323 return string(ssh.MarshalAuthorizedKey(pubkey))
324}
325
326func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
327 res, err := s.Cmd(client, patch, cmd)
328 if err != nil {
329 panic(err)
330 }
331 return res
332}
333
334func (s UserSSH) NewClient() (*ssh.Client, error) {
335 host := "localhost:2222"
336
337 config := &ssh.ClientConfig{
338 User: s.username,
339 Auth: []ssh.AuthMethod{
340 ssh.PublicKeys(s.signer),
341 },
342 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
343 }
344
345 client, err := ssh.Dial("tcp", host, config)
346 return client, err
347}
348
349func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
350 session, err := client.NewSession()
351 if err != nil {
352 return "", err
353 }
354 defer session.Close()
355
356 stdinPipe, err := session.StdinPipe()
357 if err != nil {
358 return "", err
359 }
360
361 stdoutPipe, err := session.StdoutPipe()
362 if err != nil {
363 return "", err
364 }
365
366 if err := session.Start(cmd); err != nil {
367 return "", err
368 }
369
370 if patch != nil {
371 _, err = stdinPipe.Write(patch)
372 if err != nil {
373 return "", err
374 }
375 }
376
377 stdinPipe.Close()
378
379 if err := session.Wait(); err != nil {
380 return "", err
381 }
382
383 buf := new(strings.Builder)
384 _, err = io.Copy(buf, stdoutPipe)
385 if err != nil {
386 return "", err
387 }
388
389 return buf.String(), nil
390}
391
392func GenerateUser() UserSSH {
393 _, userKey, err := ed25519.GenerateKey(rand.Reader)
394 if err != nil {
395 panic(err)
396 }
397
398 b, err := ssh.MarshalPrivateKey(userKey, "")
399 if err != nil {
400 panic(err)
401 }
402
403 userSigner, err := ssh.NewSignerFromKey(userKey)
404 if err != nil {
405 panic(err)
406 }
407
408 return UserSSH{
409 username: "testuser",
410 signer: userSigner,
411 privateKey: b.Bytes,
412 }
413}
414
415func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
416 // open an SFTP session over an existing ssh connection.
417 client, err := sftp.NewClient(conn)
418 if err != nil {
419 cfg.Logger.Error("could not create sftp client", "err", err)
420 return nil, err
421 }
422 defer client.Close()
423
424 f, err := client.Create("test/hello.txt")
425 if err != nil {
426 cfg.Logger.Error("could not create file", "err", err)
427 return nil, err
428 }
429 if _, err := f.Write([]byte("Hello world!")); err != nil {
430 cfg.Logger.Error("could not write to file", "err", err)
431 return nil, err
432 }
433
434 cfg.Logger.Info("closing", "err", f.Close())
435
436 // check it's there
437 fi, err := client.Lstat("test/hello.txt")
438 if err != nil {
439 cfg.Logger.Error("could not get stat for file", "err", err)
440 return nil, err
441 }
442
443 return &fi, nil
444}
445
446func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
447 // open an SFTP session over an existing ssh connection.
448 client, err := sftp.NewClient(conn)
449 if err != nil {
450 cfg.Logger.Error("could not create sftp client", "err", err)
451 return nil, err
452 }
453 defer client.Close()
454
455 f, err := client.Create("mult/hello.txt")
456 if err != nil {
457 cfg.Logger.Error("could not create file", "err", err)
458 return nil, err
459 }
460 if _, err := f.Write([]byte("Hello world!")); err != nil {
461 cfg.Logger.Error("could not write to file", "err", err)
462 return nil, err
463 }
464
465 f, err = client.Create("mult2/hello.txt")
466 if err != nil {
467 cfg.Logger.Error("could not create file", "err", err)
468 return nil, err
469 }
470 if _, err := f.Write([]byte("Hello world!")); err != nil {
471 cfg.Logger.Error("could not write to file", "err", err)
472 return nil, err
473 }
474
475 cfg.Logger.Info("closing", "err", f.Close())
476
477 // check it's there
478 fi, err := client.Lstat("test/hello.txt")
479 if err != nil {
480 cfg.Logger.Error("could not get stat for file", "err", err)
481 return nil, err
482 }
483
484 return &fi, nil
485}