repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

commit
da9a246
parent
391c4f9
author
Eric Bower
date
2026-02-26 09:04:53 -0500 EST
refactor(pgs): use dynamic port for ssh tests
2 files changed,  +119, -46
M pkg/apps/pgs/ssh.go
+12, -5
 1@@ -17,16 +17,12 @@ import (
 2 	"github.com/picosh/pico/pkg/tunkit"
 3 )
 4 
 5-func StartSshServer(cfg *PgsConfig, killCh chan error) {
 6+func createSshServer(cfg *PgsConfig, ctx context.Context, cacheClearingQueue chan string) (*pssh.SSHServer, error) {
 7 	host := shared.GetEnv("PGS_HOST", "0.0.0.0")
 8 	port := shared.GetEnv("PGS_SSH_PORT", "2222")
 9 	promPort := shared.GetEnv("PGS_PROM_PORT", "9222")
10 	logger := cfg.Logger
11 
12-	ctx, cancel := context.WithCancel(context.Background())
13-	defer cancel()
14-
15-	cacheClearingQueue := make(chan string, 100)
16 	handler := NewUploadAssetHandler(
17 		cfg,
18 		cacheClearingQueue,
19@@ -68,6 +64,17 @@ func StartSshServer(cfg *PgsConfig, killCh chan error) {
20 		},
21 	)
22 
23+	return server, err
24+}
25+
26+func StartSshServer(cfg *PgsConfig, killCh chan error) {
27+	ctx, cancel := context.WithCancel(context.Background())
28+	defer cancel()
29+
30+	cacheClearingQueue := make(chan string, 100)
31+	logger := cfg.Logger
32+
33+	server, err := createSshServer(cfg, ctx, cacheClearingQueue)
34 	if err != nil {
35 		logger.Error("failed to create ssh server", "err", err.Error())
36 		os.Exit(1)
M pkg/apps/pgs/ssh_test.go
+107, -41
  1@@ -1,12 +1,14 @@
  2 package pgs
  3 
  4 import (
  5+	"context"
  6 	"crypto/ed25519"
  7 	"crypto/rand"
  8 	"encoding/pem"
  9 	"fmt"
 10 	"io"
 11 	"log/slog"
 12+	"net"
 13 	"os"
 14 	"os/exec"
 15 	"path/filepath"
 16@@ -16,6 +18,7 @@ import (
 17 
 18 	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 19 	"github.com/picosh/pico/pkg/db"
 20+	"github.com/picosh/pico/pkg/pssh"
 21 	"github.com/picosh/pico/pkg/shared"
 22 	"github.com/picosh/pico/pkg/shared/storage"
 23 	"github.com/pkg/sftp"
 24@@ -23,6 +26,41 @@ import (
 25 	"golang.org/x/crypto/ssh"
 26 )
 27 
 28+func StartSshServerForTesting(cfg *PgsConfig, killCh chan error) *pssh.SSHServer {
 29+	ctx, cancel := context.WithCancel(context.Background())
 30+	defer func() {
 31+		// Cancel is deferred to avoid being called in error path
 32+		// It will be called from the cleanup goroutine
 33+		_ = cancel
 34+	}()
 35+
 36+	cacheClearingQueue := make(chan string, 100)
 37+	logger := cfg.Logger
 38+
 39+	server, err := createSshServer(cfg, ctx, cacheClearingQueue)
 40+	if err != nil {
 41+		logger.Error("failed to create ssh server", "err", err.Error())
 42+		cancel() // Clean up if server creation fails
 43+		return nil
 44+	}
 45+
 46+	logger.Info("Starting SSH server", "addr", server.Config.ListenAddr)
 47+	go func() {
 48+		if err = server.ListenAndServe(); err != nil {
 49+			logger.Error("serve", "err", err.Error())
 50+		}
 51+	}()
 52+
 53+	go func() {
 54+		// Wait for kill signal and clean up
 55+		<-killCh
 56+		logger.Info("stopping ssh server")
 57+		cancel()
 58+	}()
 59+
 60+	return server
 61+}
 62+
 63 func TestSshServerSftp(t *testing.T) {
 64 	opts := &slog.HandlerOptions{
 65 		AddSource: true,
 66@@ -43,12 +81,32 @@ func TestSshServerSftp(t *testing.T) {
 67 	defer func() {
 68 		_ = pubsub.Close()
 69 	}()
 70+
 71+	// Use dynamic port for tests to avoid port conflicts
 72+	_ = os.Setenv("PGS_SSH_PORT", "0")
 73+
 74 	cfg := NewPgsConfig(logger, dbpool, st, pubsub)
 75 	done := make(chan error)
 76 	prometheus.DefaultRegisterer = prometheus.NewRegistry()
 77-	go StartSshServer(cfg, done)
 78-	// Hack to wait for startup
 79-	time.Sleep(time.Millisecond * 100)
 80+
 81+	var server *pssh.SSHServer
 82+	go func() {
 83+		server = StartSshServerForTesting(cfg, done)
 84+	}()
 85+
 86+	// Wait for server to be ready and get the actual listening address
 87+	var actualAddr string
 88+	for i := 0; i < 100; i++ {
 89+		if server != nil && server.Listener != nil {
 90+			actualAddr = server.Listener.Addr().String()
 91+			break
 92+		}
 93+		time.Sleep(20 * time.Millisecond)
 94+	}
 95+
 96+	if actualAddr == "" {
 97+		t.Fatal("server listener not ready")
 98+	}
 99 
100 	user := GenerateUser()
101 	// add user's pubkey to the default test account
102@@ -58,7 +116,7 @@ func TestSshServerSftp(t *testing.T) {
103 		Key:    shared.KeyForKeyText(user.signer.PublicKey()),
104 	})
105 
106-	client, err := user.NewClient()
107+	client, err := user.NewClientAddr(actualAddr)
108 	if err != nil {
109 		t.Error(err)
110 		return
111@@ -96,19 +154,7 @@ func TestSshServerSftp(t *testing.T) {
112 	}
113 
114 	close(done)
115-
116-	p, err := os.FindProcess(os.Getpid())
117-	if err != nil {
118-		t.Fatal(err)
119-		return
120-	}
121-
122-	err = p.Signal(os.Interrupt)
123-	if err != nil {
124-		t.Fatal(err)
125-		return
126-	}
127-	<-time.After(10 * time.Millisecond)
128+	time.Sleep(100 * time.Millisecond)
129 }
130 
131 func TestSshServerRsync(t *testing.T) {
132@@ -131,12 +177,32 @@ func TestSshServerRsync(t *testing.T) {
133 	defer func() {
134 		_ = pubsub.Close()
135 	}()
136+
137+	// Use dynamic port for tests to avoid port conflicts
138+	_ = os.Setenv("PGS_SSH_PORT", "0")
139+
140 	cfg := NewPgsConfig(logger, dbpool, st, pubsub)
141 	done := make(chan error)
142 	prometheus.DefaultRegisterer = prometheus.NewRegistry()
143-	go StartSshServer(cfg, done)
144-	// Hack to wait for startup
145-	time.Sleep(time.Millisecond * 100)
146+
147+	var server *pssh.SSHServer
148+	go func() {
149+		server = StartSshServerForTesting(cfg, done)
150+	}()
151+
152+	// Wait for server to be ready and get the actual listening address
153+	var actualAddr string
154+	for i := 0; i < 100; i++ {
155+		if server != nil && server.Listener != nil {
156+			actualAddr = server.Listener.Addr().String()
157+			break
158+		}
159+		time.Sleep(20 * time.Millisecond)
160+	}
161+
162+	if actualAddr == "" {
163+		t.Fatal("server listener not ready")
164+	}
165 
166 	user := GenerateUser()
167 	key := shared.KeyForKeyText(user.signer.PublicKey())
168@@ -147,7 +213,7 @@ func TestSshServerRsync(t *testing.T) {
169 		Key:    key,
170 	})
171 
172-	conn, err := user.NewClient()
173+	conn, err := user.NewClientAddr(actualAddr)
174 	if err != nil {
175 		t.Error(err)
176 		return
177@@ -215,13 +281,22 @@ func TestSshServerRsync(t *testing.T) {
178 		t.Fatal(err)
179 	}
180 
181+	// Extract port from actualAddr (format: "0.0.0.0:XXXXX")
182+	_, port, err := net.SplitHostPort(actualAddr)
183+	if err != nil {
184+		t.Fatalf("failed to parse server address: %v", err)
185+	}
186+	// Use localhost for rsync (works regardless of IPv4/IPv6 binding)
187+	host := "localhost"
188+
189 	eCmd := fmt.Sprintf(
190-		"ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
191+		"ssh -p %s -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
192+		port,
193 		keyFile,
194 	)
195 
196 	// copy files
197-	cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
198+	cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", host+":/test")
199 	result, err := cmd.CombinedOutput()
200 	if err != nil {
201 		cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
202@@ -270,7 +345,7 @@ func TestSshServerRsync(t *testing.T) {
203 		_ = os.RemoveAll(dlName)
204 	}()
205 	// download files
206-	downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
207+	downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, host+":/test/", dlName+"/")
208 	result, err = downloadCmd.CombinedOutput()
209 	if err != nil {
210 		cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
211@@ -306,19 +381,7 @@ func TestSshServerRsync(t *testing.T) {
212 	}
213 
214 	close(done)
215-
216-	p, err := os.FindProcess(os.Getpid())
217-	if err != nil {
218-		t.Fatal(err)
219-		return
220-	}
221-
222-	err = p.Signal(os.Interrupt)
223-	if err != nil {
224-		t.Fatal(err)
225-		return
226-	}
227-	<-time.After(10 * time.Millisecond)
228+	time.Sleep(100 * time.Millisecond)
229 }
230 
231 type UserSSH struct {
232@@ -347,9 +410,7 @@ func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
233 	return res
234 }
235 
236-func (s UserSSH) NewClient() (*ssh.Client, error) {
237-	host := "localhost:2222"
238-
239+func (s UserSSH) NewClientAddr(addr string) (*ssh.Client, error) {
240 	config := &ssh.ClientConfig{
241 		User: s.username,
242 		Auth: []ssh.AuthMethod{
243@@ -358,10 +419,15 @@ func (s UserSSH) NewClient() (*ssh.Client, error) {
244 		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
245 	}
246 
247-	client, err := ssh.Dial("tcp", host, config)
248+	client, err := ssh.Dial("tcp", addr, config)
249 	return client, err
250 }
251 
252+func (s UserSSH) NewClient() (*ssh.Client, error) {
253+	// Default to localhost:2222 for backward compatibility
254+	return s.NewClientAddr("localhost:2222")
255+}
256+
257 func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
258 	session, err := client.NewSession()
259 	if err != nil {