Eric Bower
·
2025-07-04
ssh_test.go
1package pgs
2
3import (
4 "crypto/ed25519"
5 "crypto/rand"
6 "encoding/pem"
7 "fmt"
8 "io"
9 "log/slog"
10 "os"
11 "os/exec"
12 "path/filepath"
13 "strings"
14 "testing"
15 "time"
16
17 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
18 "github.com/picosh/pico/pkg/db"
19 "github.com/picosh/pico/pkg/shared/storage"
20 "github.com/picosh/utils"
21 "github.com/pkg/sftp"
22 "github.com/prometheus/client_golang/prometheus"
23 "golang.org/x/crypto/ssh"
24)
25
26func TestSshServerSftp(t *testing.T) {
27 opts := &slog.HandlerOptions{
28 AddSource: true,
29 Level: slog.LevelDebug,
30 }
31 logger := slog.New(
32 slog.NewTextHandler(os.Stdout, opts),
33 )
34 slog.SetDefault(logger)
35 dbpool := pgsdb.NewDBMemory(logger)
36 // setup test data
37 dbpool.SetupTestData()
38 st, err := storage.NewStorageMemory(map[string]map[string]string{})
39 if err != nil {
40 panic(err)
41 }
42 pubsub := NewPubsubChan()
43 defer func() {
44 _ = pubsub.Close()
45 }()
46 cfg := NewPgsConfig(logger, dbpool, st, pubsub)
47 done := make(chan error)
48 prometheus.DefaultRegisterer = prometheus.NewRegistry()
49 go StartSshServer(cfg, done)
50 // Hack to wait for startup
51 time.Sleep(time.Millisecond * 100)
52
53 user := GenerateUser()
54 // add user's pubkey to the default test account
55 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
56 ID: "nice-pubkey",
57 UserID: dbpool.Users[0].ID,
58 Key: utils.KeyForKeyText(user.signer.PublicKey()),
59 })
60
61 client, err := user.NewClient()
62 if err != nil {
63 t.Error(err)
64 return
65 }
66 defer func() {
67 _ = client.Close()
68 }()
69
70 _, err = WriteFileWithSftp(cfg, client)
71 if err != nil {
72 t.Error(err)
73 return
74 }
75
76 _, err = WriteFilesMultProjectsWithSftp(cfg, client)
77 if err != nil {
78 t.Error(err)
79 return
80 }
81
82 projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
83 if err != nil {
84 t.Error(err)
85 return
86 }
87
88 names := ""
89 for _, proj := range projects {
90 names += "_" + proj.Name
91 }
92
93 if names != "_test_mult_mult2" {
94 t.Errorf("not all projects created: %s", names)
95 return
96 }
97
98 close(done)
99
100 p, err := os.FindProcess(os.Getpid())
101 if err != nil {
102 t.Fatal(err)
103 return
104 }
105
106 err = p.Signal(os.Interrupt)
107 if err != nil {
108 t.Fatal(err)
109 return
110 }
111 <-time.After(10 * time.Millisecond)
112}
113
114func TestSshServerRsync(t *testing.T) {
115 opts := &slog.HandlerOptions{
116 AddSource: true,
117 Level: slog.LevelDebug,
118 }
119 logger := slog.New(
120 slog.NewTextHandler(os.Stdout, opts),
121 )
122 slog.SetDefault(logger)
123 dbpool := pgsdb.NewDBMemory(logger)
124 // setup test data
125 dbpool.SetupTestData()
126 st, err := storage.NewStorageMemory(map[string]map[string]string{})
127 if err != nil {
128 panic(err)
129 }
130 pubsub := NewPubsubChan()
131 defer func() {
132 _ = pubsub.Close()
133 }()
134 cfg := NewPgsConfig(logger, dbpool, st, pubsub)
135 done := make(chan error)
136 prometheus.DefaultRegisterer = prometheus.NewRegistry()
137 go StartSshServer(cfg, done)
138 // Hack to wait for startup
139 time.Sleep(time.Millisecond * 100)
140
141 user := GenerateUser()
142 key := utils.KeyForKeyText(user.signer.PublicKey())
143 // add user's pubkey to the default test account
144 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
145 ID: "nice-pubkey",
146 UserID: dbpool.Users[0].ID,
147 Key: key,
148 })
149
150 conn, err := user.NewClient()
151 if err != nil {
152 t.Error(err)
153 return
154 }
155 defer func() {
156 _ = conn.Close()
157 }()
158
159 // open an SFTP session over an existing ssh connection.
160 client, err := sftp.NewClient(conn)
161 if err != nil {
162 cfg.Logger.Error("could not create sftp client", "err", err)
163 panic(err)
164 }
165 defer func() {
166 _ = client.Close()
167 }()
168
169 name, err := os.MkdirTemp("", "rsync-")
170 if err != nil {
171 panic(err)
172 }
173
174 // remove the temporary directory at the end of the program
175 // defer os.RemoveAll(name)
176
177 block := &pem.Block{
178 Type: "OPENSSH PRIVATE KEY",
179 Bytes: user.privateKey,
180 }
181 keyFile := filepath.Join(name, "id_ed25519")
182 err = os.WriteFile(
183 keyFile,
184 pem.EncodeToMemory(block), 0600,
185 )
186 if err != nil {
187 t.Fatal(err)
188 }
189
190 index := "<!doctype html><html><body>index</body></html>"
191 err = os.WriteFile(
192 filepath.Join(name, "index.html"),
193 []byte(index), 0666,
194 )
195 if err != nil {
196 t.Fatal(err)
197 }
198
199 about := "<!doctype html><html><body>about</body></html>"
200 aboutFile := filepath.Join(name, "about.html")
201 err = os.WriteFile(
202 aboutFile,
203 []byte(about), 0666,
204 )
205 if err != nil {
206 t.Fatal(err)
207 }
208
209 contact := "<!doctype html><html><body>contact</body></html>"
210 err = os.WriteFile(
211 filepath.Join(name, "contact.html"),
212 []byte(contact), 0666,
213 )
214 if err != nil {
215 t.Fatal(err)
216 }
217
218 eCmd := fmt.Sprintf(
219 "ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
220 keyFile,
221 )
222
223 // copy files
224 cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
225 result, err := cmd.CombinedOutput()
226 if err != nil {
227 cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
228 t.Error(err)
229 return
230 }
231
232 // check it's there
233 fi, err := client.Lstat("/test/about.html")
234 if err != nil {
235 cfg.Logger.Error("could not get stat for file", "err", err)
236 t.Error("about.html not found")
237 return
238 }
239 if fi.Size() != 46 {
240 cfg.Logger.Error("about.html wrong size", "size", fi.Size())
241 t.Error("about.html wrong size")
242 return
243 }
244
245 // remove about file
246 _ = os.Remove(aboutFile)
247
248 // copy files with delete
249 delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
250 result, err = delCmd.CombinedOutput()
251 if err != nil {
252 cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
253 t.Error(err)
254 return
255 }
256
257 // check it's not there
258 _, err = client.Lstat("/test/about.html")
259 if err == nil {
260 cfg.Logger.Error("file still exists")
261 t.Error("about.html found")
262 return
263 }
264
265 dlName, err := os.MkdirTemp("", "rsync-download")
266 if err != nil {
267 panic(err)
268 }
269 defer func() {
270 _ = os.RemoveAll(dlName)
271 }()
272 // download files
273 downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
274 result, err = downloadCmd.CombinedOutput()
275 if err != nil {
276 cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
277 t.Error(err)
278 return
279 }
280 // check contents
281 idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
282 if err != nil {
283 cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
284 t.Error(err)
285 return
286 }
287 if string(idx) != index {
288 t.Error("downloaded index.html file does not match original")
289 return
290 }
291 _, err = os.ReadFile(filepath.Join(dlName, "about.html"))
292 if err == nil {
293 cfg.Logger.Error("about file should not exist", "file", "about.html")
294 t.Error(err)
295 return
296 }
297 cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
298 if err != nil {
299 cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
300 t.Error(err)
301 return
302 }
303 if string(cnt) != contact {
304 t.Error("downloaded contact.html file does not match original")
305 return
306 }
307
308 close(done)
309
310 p, err := os.FindProcess(os.Getpid())
311 if err != nil {
312 t.Fatal(err)
313 return
314 }
315
316 err = p.Signal(os.Interrupt)
317 if err != nil {
318 t.Fatal(err)
319 return
320 }
321 <-time.After(10 * time.Millisecond)
322}
323
// UserSSH is an SSH identity used by these tests to connect to the server.
type UserSSH struct {
	// username is sent as the SSH user name by NewClient.
	username string
	// signer provides the public-key auth used by NewClient.
	signer ssh.Signer
	// privateKey holds the Bytes of the PEM block from ssh.MarshalPrivateKey.
	// Only GenerateUser fills it in; NewUserSSH leaves it nil.
	privateKey []byte
}
329
330func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
331 return &UserSSH{
332 username: username,
333 signer: signer,
334 }
335}
336
337func (s UserSSH) Public() string {
338 pubkey := s.signer.PublicKey()
339 return string(ssh.MarshalAuthorizedKey(pubkey))
340}
341
342func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
343 res, err := s.Cmd(client, patch, cmd)
344 if err != nil {
345 panic(err)
346 }
347 return res
348}
349
350func (s UserSSH) NewClient() (*ssh.Client, error) {
351 host := "localhost:2222"
352
353 config := &ssh.ClientConfig{
354 User: s.username,
355 Auth: []ssh.AuthMethod{
356 ssh.PublicKeys(s.signer),
357 },
358 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
359 }
360
361 client, err := ssh.Dial("tcp", host, config)
362 return client, err
363}
364
365func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
366 session, err := client.NewSession()
367 if err != nil {
368 return "", err
369 }
370 defer func() {
371 _ = session.Close()
372 }()
373
374 stdinPipe, err := session.StdinPipe()
375 if err != nil {
376 return "", err
377 }
378
379 stdoutPipe, err := session.StdoutPipe()
380 if err != nil {
381 return "", err
382 }
383
384 if err := session.Start(cmd); err != nil {
385 return "", err
386 }
387
388 if patch != nil {
389 _, err = stdinPipe.Write(patch)
390 if err != nil {
391 return "", err
392 }
393 }
394
395 err = stdinPipe.Close()
396 if err != nil {
397 return "", err
398 }
399
400 if err := session.Wait(); err != nil {
401 return "", err
402 }
403
404 buf := new(strings.Builder)
405 _, err = io.Copy(buf, stdoutPipe)
406 if err != nil {
407 return "", err
408 }
409
410 return buf.String(), nil
411}
412
413func GenerateUser() UserSSH {
414 _, userKey, err := ed25519.GenerateKey(rand.Reader)
415 if err != nil {
416 panic(err)
417 }
418
419 b, err := ssh.MarshalPrivateKey(userKey, "")
420 if err != nil {
421 panic(err)
422 }
423
424 userSigner, err := ssh.NewSignerFromKey(userKey)
425 if err != nil {
426 panic(err)
427 }
428
429 return UserSSH{
430 username: "testuser",
431 signer: userSigner,
432 privateKey: b.Bytes,
433 }
434}
435
436func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
437 // open an SFTP session over an existing ssh connection.
438 client, err := sftp.NewClient(conn)
439 if err != nil {
440 cfg.Logger.Error("could not create sftp client", "err", err)
441 return nil, err
442 }
443 defer func() {
444 _ = client.Close()
445 }()
446
447 f, err := client.Create("test/hello.txt")
448 if err != nil {
449 cfg.Logger.Error("could not create file", "err", err)
450 return nil, err
451 }
452 if _, err := f.Write([]byte("Hello world!")); err != nil {
453 cfg.Logger.Error("could not write to file", "err", err)
454 return nil, err
455 }
456
457 cfg.Logger.Info("closing", "err", f.Close())
458
459 // check it's there
460 fi, err := client.Lstat("test/hello.txt")
461 if err != nil {
462 cfg.Logger.Error("could not get stat for file", "err", err)
463 return nil, err
464 }
465
466 return &fi, nil
467}
468
469func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
470 // open an SFTP session over an existing ssh connection.
471 client, err := sftp.NewClient(conn)
472 if err != nil {
473 cfg.Logger.Error("could not create sftp client", "err", err)
474 return nil, err
475 }
476 defer func() {
477 _ = client.Close()
478 }()
479
480 f, err := client.Create("mult/hello.txt")
481 if err != nil {
482 cfg.Logger.Error("could not create file", "err", err)
483 return nil, err
484 }
485 if _, err := f.Write([]byte("Hello world!")); err != nil {
486 cfg.Logger.Error("could not write to file", "err", err)
487 return nil, err
488 }
489
490 f, err = client.Create("mult2/hello.txt")
491 if err != nil {
492 cfg.Logger.Error("could not create file", "err", err)
493 return nil, err
494 }
495 if _, err := f.Write([]byte("Hello world!")); err != nil {
496 cfg.Logger.Error("could not write to file", "err", err)
497 return nil, err
498 }
499
500 cfg.Logger.Info("closing", "err", f.Close())
501
502 // check it's there
503 fi, err := client.Lstat("test/hello.txt")
504 if err != nil {
505 cfg.Logger.Error("could not get stat for file", "err", err)
506 return nil, err
507 }
508
509 return &fi, nil
510}