Eric Bower
·
2026-02-26
ssh_test.go
1package pgs
2
3import (
4 "context"
5 "crypto/ed25519"
6 "crypto/rand"
7 "encoding/pem"
8 "fmt"
9 "io"
10 "log/slog"
11 "net"
12 "os"
13 "os/exec"
14 "path/filepath"
15 "strings"
16 "testing"
17 "time"
18
19 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
20 "github.com/picosh/pico/pkg/db"
21 "github.com/picosh/pico/pkg/pssh"
22 "github.com/picosh/pico/pkg/shared"
23 "github.com/picosh/pico/pkg/shared/storage"
24 "github.com/pkg/sftp"
25 "github.com/prometheus/client_golang/prometheus"
26 "golang.org/x/crypto/ssh"
27)
28
29func StartSshServerForTesting(cfg *PgsConfig, killCh chan error, readyCh chan *pssh.SSHServer) {
30 ctx, cancel := context.WithCancel(context.Background())
31 defer func() {
32 // Cancel is deferred to avoid being called in error path
33 // It will be called from the cleanup goroutine
34 _ = cancel
35 }()
36
37 cacheClearingQueue := make(chan string, 100)
38 logger := cfg.Logger
39
40 server, err := createSshServer(cfg, ctx, cacheClearingQueue)
41 if err != nil {
42 logger.Error("failed to create ssh server", "err", err.Error())
43 cancel() // Clean up if server creation fails
44 readyCh <- nil
45 return
46 }
47
48 logger.Info("Starting SSH server", "addr", server.Config.ListenAddr)
49
50 // Signal that server is ready once ListenAndServe starts
51 go func() {
52 if err = server.ListenAndServe(); err != nil {
53 logger.Error("serve", "err", err.Error())
54 }
55 }()
56
57 // Send server when listener is created (happens early in ListenAndServe)
58 readyCh <- server
59
60 go func() {
61 // Wait for kill signal and clean up
62 <-killCh
63 logger.Info("stopping ssh server")
64 cancel()
65 }()
66}
67
68func TestSshServerSftp(t *testing.T) {
69 opts := &slog.HandlerOptions{
70 AddSource: true,
71 Level: slog.LevelDebug,
72 }
73 logger := slog.New(
74 slog.NewTextHandler(os.Stdout, opts),
75 )
76 slog.SetDefault(logger)
77 dbpool := pgsdb.NewDBMemory(logger)
78 // setup test data
79 dbpool.SetupTestData()
80 st, err := storage.NewStorageMemory(map[string]map[string]string{})
81 if err != nil {
82 panic(err)
83 }
84 pubsub := NewPubsubChan()
85 defer func() {
86 _ = pubsub.Close()
87 }()
88
89 // Use dynamic port for tests to avoid port conflicts
90 _ = os.Setenv("PGS_SSH_PORT", "0")
91
92 cfg := NewPgsConfig(logger, dbpool, st, pubsub)
93 done := make(chan error)
94 readyCh := make(chan *pssh.SSHServer)
95 prometheus.DefaultRegisterer = prometheus.NewRegistry()
96
97 go StartSshServerForTesting(cfg, done, readyCh)
98
99 // Wait for server to be ready
100 server := <-readyCh
101 if server == nil {
102 t.Fatal("failed to create ssh server")
103 }
104
105 // Wait for listener to be created
106 var actualAddr string
107 for i := 0; i < 100; i++ {
108 server.Mu.Lock()
109 listener := server.Listener
110 server.Mu.Unlock()
111 if listener != nil {
112 actualAddr = listener.Addr().String()
113 break
114 }
115 time.Sleep(10 * time.Millisecond)
116 }
117
118 if actualAddr == "" {
119 t.Fatal("server listener not ready")
120 }
121
122 user := GenerateUser()
123 // add user's pubkey to the default test account
124 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
125 ID: "nice-pubkey",
126 UserID: dbpool.Users[0].ID,
127 Key: shared.KeyForKeyText(user.signer.PublicKey()),
128 })
129
130 client, err := user.NewClientAddr(actualAddr)
131 if err != nil {
132 t.Error(err)
133 return
134 }
135 defer func() {
136 _ = client.Close()
137 }()
138
139 _, err = WriteFileWithSftp(cfg, client)
140 if err != nil {
141 t.Error(err)
142 return
143 }
144
145 _, err = WriteFilesMultProjectsWithSftp(cfg, client)
146 if err != nil {
147 t.Error(err)
148 return
149 }
150
151 projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
152 if err != nil {
153 t.Error(err)
154 return
155 }
156
157 names := ""
158 for _, proj := range projects {
159 names += "_" + proj.Name
160 }
161
162 if names != "_test_mult_mult2" {
163 t.Errorf("not all projects created: %s", names)
164 return
165 }
166
167 close(done)
168 time.Sleep(100 * time.Millisecond)
169}
170
171func TestSshServerRsync(t *testing.T) {
172 opts := &slog.HandlerOptions{
173 AddSource: true,
174 Level: slog.LevelDebug,
175 }
176 logger := slog.New(
177 slog.NewTextHandler(os.Stdout, opts),
178 )
179 slog.SetDefault(logger)
180 dbpool := pgsdb.NewDBMemory(logger)
181 // setup test data
182 dbpool.SetupTestData()
183 st, err := storage.NewStorageMemory(map[string]map[string]string{})
184 if err != nil {
185 panic(err)
186 }
187 pubsub := NewPubsubChan()
188 defer func() {
189 _ = pubsub.Close()
190 }()
191
192 // Use dynamic port for tests to avoid port conflicts
193 _ = os.Setenv("PGS_SSH_PORT", "0")
194
195 cfg := NewPgsConfig(logger, dbpool, st, pubsub)
196 done := make(chan error)
197 readyCh := make(chan *pssh.SSHServer)
198 prometheus.DefaultRegisterer = prometheus.NewRegistry()
199
200 go StartSshServerForTesting(cfg, done, readyCh)
201
202 // Wait for server to be ready
203 server := <-readyCh
204 if server == nil {
205 t.Fatal("failed to create ssh server")
206 }
207
208 // Wait for listener to be created
209 var actualAddr string
210 for i := 0; i < 100; i++ {
211 server.Mu.Lock()
212 listener := server.Listener
213 server.Mu.Unlock()
214 if listener != nil {
215 actualAddr = listener.Addr().String()
216 break
217 }
218 time.Sleep(10 * time.Millisecond)
219 }
220
221 if actualAddr == "" {
222 t.Fatal("server listener not ready")
223 }
224
225 user := GenerateUser()
226 key := shared.KeyForKeyText(user.signer.PublicKey())
227 // add user's pubkey to the default test account
228 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
229 ID: "nice-pubkey",
230 UserID: dbpool.Users[0].ID,
231 Key: key,
232 })
233
234 conn, err := user.NewClientAddr(actualAddr)
235 if err != nil {
236 t.Error(err)
237 return
238 }
239 defer func() {
240 _ = conn.Close()
241 }()
242
243 // open an SFTP session over an existing ssh connection.
244 client, err := sftp.NewClient(conn)
245 if err != nil {
246 cfg.Logger.Error("could not create sftp client", "err", err)
247 panic(err)
248 }
249 defer func() {
250 _ = client.Close()
251 }()
252
253 name, err := os.MkdirTemp("", "rsync-")
254 if err != nil {
255 panic(err)
256 }
257
258 // remove the temporary directory at the end of the program
259 // defer os.RemoveAll(name)
260
261 block := &pem.Block{
262 Type: "OPENSSH PRIVATE KEY",
263 Bytes: user.privateKey,
264 }
265 keyFile := filepath.Join(name, "id_ed25519")
266 err = os.WriteFile(
267 keyFile,
268 pem.EncodeToMemory(block), 0600,
269 )
270 if err != nil {
271 t.Fatal(err)
272 }
273
274 index := "<!doctype html><html><body>index</body></html>"
275 err = os.WriteFile(
276 filepath.Join(name, "index.html"),
277 []byte(index), 0666,
278 )
279 if err != nil {
280 t.Fatal(err)
281 }
282
283 about := "<!doctype html><html><body>about</body></html>"
284 aboutFile := filepath.Join(name, "about.html")
285 err = os.WriteFile(
286 aboutFile,
287 []byte(about), 0666,
288 )
289 if err != nil {
290 t.Fatal(err)
291 }
292
293 contact := "<!doctype html><html><body>contact</body></html>"
294 err = os.WriteFile(
295 filepath.Join(name, "contact.html"),
296 []byte(contact), 0666,
297 )
298 if err != nil {
299 t.Fatal(err)
300 }
301
302 // Extract port from actualAddr (format: "0.0.0.0:XXXXX")
303 _, port, err := net.SplitHostPort(actualAddr)
304 if err != nil {
305 t.Fatalf("failed to parse server address: %v", err)
306 }
307 // Use localhost for rsync (works regardless of IPv4/IPv6 binding)
308 host := "localhost"
309
310 eCmd := fmt.Sprintf(
311 "ssh -p %s -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
312 port,
313 keyFile,
314 )
315
316 // copy files
317 cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", host+":/test")
318 result, err := cmd.CombinedOutput()
319 if err != nil {
320 cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
321 t.Error(err)
322 return
323 }
324
325 // check it's there
326 fi, err := client.Lstat("/test/about.html")
327 if err != nil {
328 cfg.Logger.Error("could not get stat for file", "err", err)
329 t.Error("about.html not found")
330 return
331 }
332 if fi.Size() != 46 {
333 cfg.Logger.Error("about.html wrong size", "size", fi.Size())
334 t.Error("about.html wrong size")
335 return
336 }
337
338 // remove about file
339 _ = os.Remove(aboutFile)
340
341 // copy files with delete
342 delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
343 result, err = delCmd.CombinedOutput()
344 if err != nil {
345 cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
346 t.Error(err)
347 return
348 }
349
350 // check it's not there
351 _, err = client.Lstat("/test/about.html")
352 if err == nil {
353 cfg.Logger.Error("file still exists")
354 t.Error("about.html found")
355 return
356 }
357
358 dlName, err := os.MkdirTemp("", "rsync-download")
359 if err != nil {
360 panic(err)
361 }
362 defer func() {
363 _ = os.RemoveAll(dlName)
364 }()
365 // download files
366 downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, host+":/test/", dlName+"/")
367 result, err = downloadCmd.CombinedOutput()
368 if err != nil {
369 cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
370 t.Error(err)
371 return
372 }
373 // check contents
374 idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
375 if err != nil {
376 cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
377 t.Error(err)
378 return
379 }
380 if string(idx) != index {
381 t.Error("downloaded index.html file does not match original")
382 return
383 }
384 _, err = os.ReadFile(filepath.Join(dlName, "about.html"))
385 if err == nil {
386 cfg.Logger.Error("about file should not exist", "file", "about.html")
387 t.Error(err)
388 return
389 }
390 cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
391 if err != nil {
392 cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
393 t.Error(err)
394 return
395 }
396 if string(cnt) != contact {
397 t.Error("downloaded contact.html file does not match original")
398 return
399 }
400
401 close(done)
402 time.Sleep(100 * time.Millisecond)
403}
404
405type UserSSH struct {
406 username string
407 signer ssh.Signer
408 privateKey []byte
409}
410
411func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
412 return &UserSSH{
413 username: username,
414 signer: signer,
415 }
416}
417
418func (s UserSSH) Public() string {
419 pubkey := s.signer.PublicKey()
420 return string(ssh.MarshalAuthorizedKey(pubkey))
421}
422
423func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
424 res, err := s.Cmd(client, patch, cmd)
425 if err != nil {
426 panic(err)
427 }
428 return res
429}
430
431func (s UserSSH) NewClientAddr(addr string) (*ssh.Client, error) {
432 config := &ssh.ClientConfig{
433 User: s.username,
434 Auth: []ssh.AuthMethod{
435 ssh.PublicKeys(s.signer),
436 },
437 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
438 }
439
440 client, err := ssh.Dial("tcp", addr, config)
441 return client, err
442}
443
444func (s UserSSH) NewClient() (*ssh.Client, error) {
445 // Default to localhost:2222 for backward compatibility
446 return s.NewClientAddr("localhost:2222")
447}
448
449func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
450 session, err := client.NewSession()
451 if err != nil {
452 return "", err
453 }
454 defer func() {
455 _ = session.Close()
456 }()
457
458 stdinPipe, err := session.StdinPipe()
459 if err != nil {
460 return "", err
461 }
462
463 stdoutPipe, err := session.StdoutPipe()
464 if err != nil {
465 return "", err
466 }
467
468 if err := session.Start(cmd); err != nil {
469 return "", err
470 }
471
472 if patch != nil {
473 _, err = stdinPipe.Write(patch)
474 if err != nil {
475 return "", err
476 }
477 }
478
479 err = stdinPipe.Close()
480 if err != nil {
481 return "", err
482 }
483
484 if err := session.Wait(); err != nil {
485 return "", err
486 }
487
488 buf := new(strings.Builder)
489 _, err = io.Copy(buf, stdoutPipe)
490 if err != nil {
491 return "", err
492 }
493
494 return buf.String(), nil
495}
496
497func GenerateUser() UserSSH {
498 _, userKey, err := ed25519.GenerateKey(rand.Reader)
499 if err != nil {
500 panic(err)
501 }
502
503 b, err := ssh.MarshalPrivateKey(userKey, "")
504 if err != nil {
505 panic(err)
506 }
507
508 userSigner, err := ssh.NewSignerFromKey(userKey)
509 if err != nil {
510 panic(err)
511 }
512
513 return UserSSH{
514 username: "testuser",
515 signer: userSigner,
516 privateKey: b.Bytes,
517 }
518}
519
520func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
521 // open an SFTP session over an existing ssh connection.
522 client, err := sftp.NewClient(conn)
523 if err != nil {
524 cfg.Logger.Error("could not create sftp client", "err", err)
525 return nil, err
526 }
527 defer func() {
528 _ = client.Close()
529 }()
530
531 f, err := client.Create("test/hello.txt")
532 if err != nil {
533 cfg.Logger.Error("could not create file", "err", err)
534 return nil, err
535 }
536 if _, err := f.Write([]byte("Hello world!")); err != nil {
537 cfg.Logger.Error("could not write to file", "err", err)
538 return nil, err
539 }
540
541 cfg.Logger.Info("closing", "err", f.Close())
542
543 // check it's there
544 fi, err := client.Lstat("test/hello.txt")
545 if err != nil {
546 cfg.Logger.Error("could not get stat for file", "err", err)
547 return nil, err
548 }
549
550 return &fi, nil
551}
552
553func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
554 // open an SFTP session over an existing ssh connection.
555 client, err := sftp.NewClient(conn)
556 if err != nil {
557 cfg.Logger.Error("could not create sftp client", "err", err)
558 return nil, err
559 }
560 defer func() {
561 _ = client.Close()
562 }()
563
564 f, err := client.Create("mult/hello.txt")
565 if err != nil {
566 cfg.Logger.Error("could not create file", "err", err)
567 return nil, err
568 }
569 if _, err := f.Write([]byte("Hello world!")); err != nil {
570 cfg.Logger.Error("could not write to file", "err", err)
571 return nil, err
572 }
573
574 f, err = client.Create("mult2/hello.txt")
575 if err != nil {
576 cfg.Logger.Error("could not create file", "err", err)
577 return nil, err
578 }
579 if _, err := f.Write([]byte("Hello world!")); err != nil {
580 cfg.Logger.Error("could not write to file", "err", err)
581 return nil, err
582 }
583
584 cfg.Logger.Info("closing", "err", f.Close())
585
586 // check it's there
587 fi, err := client.Lstat("test/hello.txt")
588 if err != nil {
589 cfg.Logger.Error("could not get stat for file", "err", err)
590 return nil, err
591 }
592
593 return &fi, nil
594}