repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2026-02-26

ssh_test.go

  1package pgs
  2
  3import (
  4	"context"
  5	"crypto/ed25519"
  6	"crypto/rand"
  7	"encoding/pem"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"net"
 12	"os"
 13	"os/exec"
 14	"path/filepath"
 15	"strings"
 16	"testing"
 17	"time"
 18
 19	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 20	"github.com/picosh/pico/pkg/db"
 21	"github.com/picosh/pico/pkg/pssh"
 22	"github.com/picosh/pico/pkg/shared"
 23	"github.com/picosh/pico/pkg/shared/storage"
 24	"github.com/pkg/sftp"
 25	"github.com/prometheus/client_golang/prometheus"
 26	"golang.org/x/crypto/ssh"
 27)
 28
 29func StartSshServerForTesting(cfg *PgsConfig, killCh chan error, readyCh chan *pssh.SSHServer) {
 30	ctx, cancel := context.WithCancel(context.Background())
 31	defer func() {
 32		// Cancel is deferred to avoid being called in error path
 33		// It will be called from the cleanup goroutine
 34		_ = cancel
 35	}()
 36
 37	cacheClearingQueue := make(chan string, 100)
 38	logger := cfg.Logger
 39
 40	server, err := createSshServer(cfg, ctx, cacheClearingQueue)
 41	if err != nil {
 42		logger.Error("failed to create ssh server", "err", err.Error())
 43		cancel() // Clean up if server creation fails
 44		readyCh <- nil
 45		return
 46	}
 47
 48	logger.Info("Starting SSH server", "addr", server.Config.ListenAddr)
 49
 50	// Signal that server is ready once ListenAndServe starts
 51	go func() {
 52		if err = server.ListenAndServe(); err != nil {
 53			logger.Error("serve", "err", err.Error())
 54		}
 55	}()
 56
 57	// Send server when listener is created (happens early in ListenAndServe)
 58	readyCh <- server
 59
 60	go func() {
 61		// Wait for kill signal and clean up
 62		<-killCh
 63		logger.Info("stopping ssh server")
 64		cancel()
 65	}()
 66}
 67
 68func TestSshServerSftp(t *testing.T) {
 69	opts := &slog.HandlerOptions{
 70		AddSource: true,
 71		Level:     slog.LevelDebug,
 72	}
 73	logger := slog.New(
 74		slog.NewTextHandler(os.Stdout, opts),
 75	)
 76	slog.SetDefault(logger)
 77	dbpool := pgsdb.NewDBMemory(logger)
 78	// setup test data
 79	dbpool.SetupTestData()
 80	st, err := storage.NewStorageMemory(map[string]map[string]string{})
 81	if err != nil {
 82		panic(err)
 83	}
 84	pubsub := NewPubsubChan()
 85	defer func() {
 86		_ = pubsub.Close()
 87	}()
 88
 89	// Use dynamic port for tests to avoid port conflicts
 90	_ = os.Setenv("PGS_SSH_PORT", "0")
 91
 92	cfg := NewPgsConfig(logger, dbpool, st, pubsub)
 93	done := make(chan error)
 94	readyCh := make(chan *pssh.SSHServer)
 95	prometheus.DefaultRegisterer = prometheus.NewRegistry()
 96
 97	go StartSshServerForTesting(cfg, done, readyCh)
 98
 99	// Wait for server to be ready
100	server := <-readyCh
101	if server == nil {
102		t.Fatal("failed to create ssh server")
103	}
104
105	// Wait for listener to be created
106	var actualAddr string
107	for i := 0; i < 100; i++ {
108		server.Mu.Lock()
109		listener := server.Listener
110		server.Mu.Unlock()
111		if listener != nil {
112			actualAddr = listener.Addr().String()
113			break
114		}
115		time.Sleep(10 * time.Millisecond)
116	}
117
118	if actualAddr == "" {
119		t.Fatal("server listener not ready")
120	}
121
122	user := GenerateUser()
123	// add user's pubkey to the default test account
124	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
125		ID:     "nice-pubkey",
126		UserID: dbpool.Users[0].ID,
127		Key:    shared.KeyForKeyText(user.signer.PublicKey()),
128	})
129
130	client, err := user.NewClientAddr(actualAddr)
131	if err != nil {
132		t.Error(err)
133		return
134	}
135	defer func() {
136		_ = client.Close()
137	}()
138
139	_, err = WriteFileWithSftp(cfg, client)
140	if err != nil {
141		t.Error(err)
142		return
143	}
144
145	_, err = WriteFilesMultProjectsWithSftp(cfg, client)
146	if err != nil {
147		t.Error(err)
148		return
149	}
150
151	projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
152	if err != nil {
153		t.Error(err)
154		return
155	}
156
157	names := ""
158	for _, proj := range projects {
159		names += "_" + proj.Name
160	}
161
162	if names != "_test_mult_mult2" {
163		t.Errorf("not all projects created: %s", names)
164		return
165	}
166
167	close(done)
168	time.Sleep(100 * time.Millisecond)
169}
170
171func TestSshServerRsync(t *testing.T) {
172	opts := &slog.HandlerOptions{
173		AddSource: true,
174		Level:     slog.LevelDebug,
175	}
176	logger := slog.New(
177		slog.NewTextHandler(os.Stdout, opts),
178	)
179	slog.SetDefault(logger)
180	dbpool := pgsdb.NewDBMemory(logger)
181	// setup test data
182	dbpool.SetupTestData()
183	st, err := storage.NewStorageMemory(map[string]map[string]string{})
184	if err != nil {
185		panic(err)
186	}
187	pubsub := NewPubsubChan()
188	defer func() {
189		_ = pubsub.Close()
190	}()
191
192	// Use dynamic port for tests to avoid port conflicts
193	_ = os.Setenv("PGS_SSH_PORT", "0")
194
195	cfg := NewPgsConfig(logger, dbpool, st, pubsub)
196	done := make(chan error)
197	readyCh := make(chan *pssh.SSHServer)
198	prometheus.DefaultRegisterer = prometheus.NewRegistry()
199
200	go StartSshServerForTesting(cfg, done, readyCh)
201
202	// Wait for server to be ready
203	server := <-readyCh
204	if server == nil {
205		t.Fatal("failed to create ssh server")
206	}
207
208	// Wait for listener to be created
209	var actualAddr string
210	for i := 0; i < 100; i++ {
211		server.Mu.Lock()
212		listener := server.Listener
213		server.Mu.Unlock()
214		if listener != nil {
215			actualAddr = listener.Addr().String()
216			break
217		}
218		time.Sleep(10 * time.Millisecond)
219	}
220
221	if actualAddr == "" {
222		t.Fatal("server listener not ready")
223	}
224
225	user := GenerateUser()
226	key := shared.KeyForKeyText(user.signer.PublicKey())
227	// add user's pubkey to the default test account
228	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
229		ID:     "nice-pubkey",
230		UserID: dbpool.Users[0].ID,
231		Key:    key,
232	})
233
234	conn, err := user.NewClientAddr(actualAddr)
235	if err != nil {
236		t.Error(err)
237		return
238	}
239	defer func() {
240		_ = conn.Close()
241	}()
242
243	// open an SFTP session over an existing ssh connection.
244	client, err := sftp.NewClient(conn)
245	if err != nil {
246		cfg.Logger.Error("could not create sftp client", "err", err)
247		panic(err)
248	}
249	defer func() {
250		_ = client.Close()
251	}()
252
253	name, err := os.MkdirTemp("", "rsync-")
254	if err != nil {
255		panic(err)
256	}
257
258	// remove the temporary directory at the end of the program
259	// defer os.RemoveAll(name)
260
261	block := &pem.Block{
262		Type:  "OPENSSH PRIVATE KEY",
263		Bytes: user.privateKey,
264	}
265	keyFile := filepath.Join(name, "id_ed25519")
266	err = os.WriteFile(
267		keyFile,
268		pem.EncodeToMemory(block), 0600,
269	)
270	if err != nil {
271		t.Fatal(err)
272	}
273
274	index := "<!doctype html><html><body>index</body></html>"
275	err = os.WriteFile(
276		filepath.Join(name, "index.html"),
277		[]byte(index), 0666,
278	)
279	if err != nil {
280		t.Fatal(err)
281	}
282
283	about := "<!doctype html><html><body>about</body></html>"
284	aboutFile := filepath.Join(name, "about.html")
285	err = os.WriteFile(
286		aboutFile,
287		[]byte(about), 0666,
288	)
289	if err != nil {
290		t.Fatal(err)
291	}
292
293	contact := "<!doctype html><html><body>contact</body></html>"
294	err = os.WriteFile(
295		filepath.Join(name, "contact.html"),
296		[]byte(contact), 0666,
297	)
298	if err != nil {
299		t.Fatal(err)
300	}
301
302	// Extract port from actualAddr (format: "0.0.0.0:XXXXX")
303	_, port, err := net.SplitHostPort(actualAddr)
304	if err != nil {
305		t.Fatalf("failed to parse server address: %v", err)
306	}
307	// Use localhost for rsync (works regardless of IPv4/IPv6 binding)
308	host := "localhost"
309
310	eCmd := fmt.Sprintf(
311		"ssh -p %s -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
312		port,
313		keyFile,
314	)
315
316	// copy files
317	cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", host+":/test")
318	result, err := cmd.CombinedOutput()
319	if err != nil {
320		cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
321		t.Error(err)
322		return
323	}
324
325	// check it's there
326	fi, err := client.Lstat("/test/about.html")
327	if err != nil {
328		cfg.Logger.Error("could not get stat for file", "err", err)
329		t.Error("about.html not found")
330		return
331	}
332	if fi.Size() != 46 {
333		cfg.Logger.Error("about.html wrong size", "size", fi.Size())
334		t.Error("about.html wrong size")
335		return
336	}
337
338	// remove about file
339	_ = os.Remove(aboutFile)
340
341	// copy files with delete
342	delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
343	result, err = delCmd.CombinedOutput()
344	if err != nil {
345		cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
346		t.Error(err)
347		return
348	}
349
350	// check it's not there
351	_, err = client.Lstat("/test/about.html")
352	if err == nil {
353		cfg.Logger.Error("file still exists")
354		t.Error("about.html found")
355		return
356	}
357
358	dlName, err := os.MkdirTemp("", "rsync-download")
359	if err != nil {
360		panic(err)
361	}
362	defer func() {
363		_ = os.RemoveAll(dlName)
364	}()
365	// download files
366	downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, host+":/test/", dlName+"/")
367	result, err = downloadCmd.CombinedOutput()
368	if err != nil {
369		cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
370		t.Error(err)
371		return
372	}
373	// check contents
374	idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
375	if err != nil {
376		cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
377		t.Error(err)
378		return
379	}
380	if string(idx) != index {
381		t.Error("downloaded index.html file does not match original")
382		return
383	}
384	_, err = os.ReadFile(filepath.Join(dlName, "about.html"))
385	if err == nil {
386		cfg.Logger.Error("about file should not exist", "file", "about.html")
387		t.Error(err)
388		return
389	}
390	cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
391	if err != nil {
392		cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
393		t.Error(err)
394		return
395	}
396	if string(cnt) != contact {
397		t.Error("downloaded contact.html file does not match original")
398		return
399	}
400
401	close(done)
402	time.Sleep(100 * time.Millisecond)
403}
404
405type UserSSH struct {
406	username   string
407	signer     ssh.Signer
408	privateKey []byte
409}
410
411func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
412	return &UserSSH{
413		username: username,
414		signer:   signer,
415	}
416}
417
418func (s UserSSH) Public() string {
419	pubkey := s.signer.PublicKey()
420	return string(ssh.MarshalAuthorizedKey(pubkey))
421}
422
423func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
424	res, err := s.Cmd(client, patch, cmd)
425	if err != nil {
426		panic(err)
427	}
428	return res
429}
430
431func (s UserSSH) NewClientAddr(addr string) (*ssh.Client, error) {
432	config := &ssh.ClientConfig{
433		User: s.username,
434		Auth: []ssh.AuthMethod{
435			ssh.PublicKeys(s.signer),
436		},
437		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
438	}
439
440	client, err := ssh.Dial("tcp", addr, config)
441	return client, err
442}
443
444func (s UserSSH) NewClient() (*ssh.Client, error) {
445	// Default to localhost:2222 for backward compatibility
446	return s.NewClientAddr("localhost:2222")
447}
448
449func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
450	session, err := client.NewSession()
451	if err != nil {
452		return "", err
453	}
454	defer func() {
455		_ = session.Close()
456	}()
457
458	stdinPipe, err := session.StdinPipe()
459	if err != nil {
460		return "", err
461	}
462
463	stdoutPipe, err := session.StdoutPipe()
464	if err != nil {
465		return "", err
466	}
467
468	if err := session.Start(cmd); err != nil {
469		return "", err
470	}
471
472	if patch != nil {
473		_, err = stdinPipe.Write(patch)
474		if err != nil {
475			return "", err
476		}
477	}
478
479	err = stdinPipe.Close()
480	if err != nil {
481		return "", err
482	}
483
484	if err := session.Wait(); err != nil {
485		return "", err
486	}
487
488	buf := new(strings.Builder)
489	_, err = io.Copy(buf, stdoutPipe)
490	if err != nil {
491		return "", err
492	}
493
494	return buf.String(), nil
495}
496
497func GenerateUser() UserSSH {
498	_, userKey, err := ed25519.GenerateKey(rand.Reader)
499	if err != nil {
500		panic(err)
501	}
502
503	b, err := ssh.MarshalPrivateKey(userKey, "")
504	if err != nil {
505		panic(err)
506	}
507
508	userSigner, err := ssh.NewSignerFromKey(userKey)
509	if err != nil {
510		panic(err)
511	}
512
513	return UserSSH{
514		username:   "testuser",
515		signer:     userSigner,
516		privateKey: b.Bytes,
517	}
518}
519
520func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
521	// open an SFTP session over an existing ssh connection.
522	client, err := sftp.NewClient(conn)
523	if err != nil {
524		cfg.Logger.Error("could not create sftp client", "err", err)
525		return nil, err
526	}
527	defer func() {
528		_ = client.Close()
529	}()
530
531	f, err := client.Create("test/hello.txt")
532	if err != nil {
533		cfg.Logger.Error("could not create file", "err", err)
534		return nil, err
535	}
536	if _, err := f.Write([]byte("Hello world!")); err != nil {
537		cfg.Logger.Error("could not write to file", "err", err)
538		return nil, err
539	}
540
541	cfg.Logger.Info("closing", "err", f.Close())
542
543	// check it's there
544	fi, err := client.Lstat("test/hello.txt")
545	if err != nil {
546		cfg.Logger.Error("could not get stat for file", "err", err)
547		return nil, err
548	}
549
550	return &fi, nil
551}
552
553func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
554	// open an SFTP session over an existing ssh connection.
555	client, err := sftp.NewClient(conn)
556	if err != nil {
557		cfg.Logger.Error("could not create sftp client", "err", err)
558		return nil, err
559	}
560	defer func() {
561		_ = client.Close()
562	}()
563
564	f, err := client.Create("mult/hello.txt")
565	if err != nil {
566		cfg.Logger.Error("could not create file", "err", err)
567		return nil, err
568	}
569	if _, err := f.Write([]byte("Hello world!")); err != nil {
570		cfg.Logger.Error("could not write to file", "err", err)
571		return nil, err
572	}
573
574	f, err = client.Create("mult2/hello.txt")
575	if err != nil {
576		cfg.Logger.Error("could not create file", "err", err)
577		return nil, err
578	}
579	if _, err := f.Write([]byte("Hello world!")); err != nil {
580		cfg.Logger.Error("could not write to file", "err", err)
581		return nil, err
582	}
583
584	cfg.Logger.Info("closing", "err", f.Close())
585
586	// check it's there
587	fi, err := client.Lstat("test/hello.txt")
588	if err != nil {
589		cfg.Logger.Error("could not get stat for file", "err", err)
590		return nil, err
591	}
592
593	return &fi, nil
594}