- commit
- da9a246
- parent
- 391c4f9
- author
- Eric Bower
- date
- 2026-02-26 09:04:53 -0500 EST
refactor(pgs): use dynamic port for ssh tests
2 files changed,
+119,
-46
+12,
-5
1@@ -17,16 +17,12 @@ import (
2 "github.com/picosh/pico/pkg/tunkit"
3 )
4
5-func StartSshServer(cfg *PgsConfig, killCh chan error) {
6+func createSshServer(cfg *PgsConfig, ctx context.Context, cacheClearingQueue chan string) (*pssh.SSHServer, error) {
7 host := shared.GetEnv("PGS_HOST", "0.0.0.0")
8 port := shared.GetEnv("PGS_SSH_PORT", "2222")
9 promPort := shared.GetEnv("PGS_PROM_PORT", "9222")
10 logger := cfg.Logger
11
12- ctx, cancel := context.WithCancel(context.Background())
13- defer cancel()
14-
15- cacheClearingQueue := make(chan string, 100)
16 handler := NewUploadAssetHandler(
17 cfg,
18 cacheClearingQueue,
19@@ -68,6 +64,17 @@ func StartSshServer(cfg *PgsConfig, killCh chan error) {
20 },
21 )
22
23+ return server, err
24+}
25+
26+func StartSshServer(cfg *PgsConfig, killCh chan error) {
27+ ctx, cancel := context.WithCancel(context.Background())
28+ defer cancel()
29+
30+ cacheClearingQueue := make(chan string, 100)
31+ logger := cfg.Logger
32+
33+ server, err := createSshServer(cfg, ctx, cacheClearingQueue)
34 if err != nil {
35 logger.Error("failed to create ssh server", "err", err.Error())
36 os.Exit(1)
+107,
-41
1@@ -1,12 +1,14 @@
2 package pgs
3
4 import (
5+ "context"
6 "crypto/ed25519"
7 "crypto/rand"
8 "encoding/pem"
9 "fmt"
10 "io"
11 "log/slog"
12+ "net"
13 "os"
14 "os/exec"
15 "path/filepath"
16@@ -16,6 +18,7 @@ import (
17
18 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
19 "github.com/picosh/pico/pkg/db"
20+ "github.com/picosh/pico/pkg/pssh"
21 "github.com/picosh/pico/pkg/shared"
22 "github.com/picosh/pico/pkg/shared/storage"
23 "github.com/pkg/sftp"
24@@ -23,6 +26,41 @@ import (
25 "golang.org/x/crypto/ssh"
26 )
27
28+func StartSshServerForTesting(cfg *PgsConfig, killCh chan error) *pssh.SSHServer {
29+ ctx, cancel := context.WithCancel(context.Background())
30+ defer func() {
31+ // Cancel is deferred to avoid being called in error path
32+ // It will be called from the cleanup goroutine
33+ _ = cancel
34+ }()
35+
36+ cacheClearingQueue := make(chan string, 100)
37+ logger := cfg.Logger
38+
39+ server, err := createSshServer(cfg, ctx, cacheClearingQueue)
40+ if err != nil {
41+ logger.Error("failed to create ssh server", "err", err.Error())
42+ cancel() // Clean up if server creation fails
43+ return nil
44+ }
45+
46+ logger.Info("Starting SSH server", "addr", server.Config.ListenAddr)
47+ go func() {
48+ if err = server.ListenAndServe(); err != nil {
49+ logger.Error("serve", "err", err.Error())
50+ }
51+ }()
52+
53+ go func() {
54+ // Wait for kill signal and clean up
55+ <-killCh
56+ logger.Info("stopping ssh server")
57+ cancel()
58+ }()
59+
60+ return server
61+}
62+
63 func TestSshServerSftp(t *testing.T) {
64 opts := &slog.HandlerOptions{
65 AddSource: true,
66@@ -43,12 +81,32 @@ func TestSshServerSftp(t *testing.T) {
67 defer func() {
68 _ = pubsub.Close()
69 }()
70+
71+ // Use dynamic port for tests to avoid port conflicts
72+ _ = os.Setenv("PGS_SSH_PORT", "0")
73+
74 cfg := NewPgsConfig(logger, dbpool, st, pubsub)
75 done := make(chan error)
76 prometheus.DefaultRegisterer = prometheus.NewRegistry()
77- go StartSshServer(cfg, done)
78- // Hack to wait for startup
79- time.Sleep(time.Millisecond * 100)
80+
81+ var server *pssh.SSHServer
82+ go func() {
83+ server = StartSshServerForTesting(cfg, done)
84+ }()
85+
86+ // Wait for server to be ready and get the actual listening address
87+ var actualAddr string
88+ for i := 0; i < 100; i++ {
89+ if server != nil && server.Listener != nil {
90+ actualAddr = server.Listener.Addr().String()
91+ break
92+ }
93+ time.Sleep(20 * time.Millisecond)
94+ }
95+
96+ if actualAddr == "" {
97+ t.Fatal("server listener not ready")
98+ }
99
100 user := GenerateUser()
101 // add user's pubkey to the default test account
102@@ -58,7 +116,7 @@ func TestSshServerSftp(t *testing.T) {
103 Key: shared.KeyForKeyText(user.signer.PublicKey()),
104 })
105
106- client, err := user.NewClient()
107+ client, err := user.NewClientAddr(actualAddr)
108 if err != nil {
109 t.Error(err)
110 return
111@@ -96,19 +154,7 @@ func TestSshServerSftp(t *testing.T) {
112 }
113
114 close(done)
115-
116- p, err := os.FindProcess(os.Getpid())
117- if err != nil {
118- t.Fatal(err)
119- return
120- }
121-
122- err = p.Signal(os.Interrupt)
123- if err != nil {
124- t.Fatal(err)
125- return
126- }
127- <-time.After(10 * time.Millisecond)
128+ time.Sleep(100 * time.Millisecond)
129 }
130
131 func TestSshServerRsync(t *testing.T) {
132@@ -131,12 +177,32 @@ func TestSshServerRsync(t *testing.T) {
133 defer func() {
134 _ = pubsub.Close()
135 }()
136+
137+ // Use dynamic port for tests to avoid port conflicts
138+ _ = os.Setenv("PGS_SSH_PORT", "0")
139+
140 cfg := NewPgsConfig(logger, dbpool, st, pubsub)
141 done := make(chan error)
142 prometheus.DefaultRegisterer = prometheus.NewRegistry()
143- go StartSshServer(cfg, done)
144- // Hack to wait for startup
145- time.Sleep(time.Millisecond * 100)
146+
147+ var server *pssh.SSHServer
148+ go func() {
149+ server = StartSshServerForTesting(cfg, done)
150+ }()
151+
152+ // Wait for server to be ready and get the actual listening address
153+ var actualAddr string
154+ for i := 0; i < 100; i++ {
155+ if server != nil && server.Listener != nil {
156+ actualAddr = server.Listener.Addr().String()
157+ break
158+ }
159+ time.Sleep(20 * time.Millisecond)
160+ }
161+
162+ if actualAddr == "" {
163+ t.Fatal("server listener not ready")
164+ }
165
166 user := GenerateUser()
167 key := shared.KeyForKeyText(user.signer.PublicKey())
168@@ -147,7 +213,7 @@ func TestSshServerRsync(t *testing.T) {
169 Key: key,
170 })
171
172- conn, err := user.NewClient()
173+ conn, err := user.NewClientAddr(actualAddr)
174 if err != nil {
175 t.Error(err)
176 return
177@@ -215,13 +281,22 @@ func TestSshServerRsync(t *testing.T) {
178 t.Fatal(err)
179 }
180
181+ // Extract port from actualAddr (format: "0.0.0.0:XXXXX")
182+ _, port, err := net.SplitHostPort(actualAddr)
183+ if err != nil {
184+ t.Fatalf("failed to parse server address: %v", err)
185+ }
186+ // Use localhost for rsync (works regardless of IPv4/IPv6 binding)
187+ host := "localhost"
188+
189 eCmd := fmt.Sprintf(
190- "ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
191+ "ssh -p %s -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
192+ port,
193 keyFile,
194 )
195
196 // copy files
197- cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
198+ cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", host+":/test")
199 result, err := cmd.CombinedOutput()
200 if err != nil {
201 cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
202@@ -270,7 +345,7 @@ func TestSshServerRsync(t *testing.T) {
203 _ = os.RemoveAll(dlName)
204 }()
205 // download files
206- downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
207+ downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, host+":/test/", dlName+"/")
208 result, err = downloadCmd.CombinedOutput()
209 if err != nil {
210 cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
211@@ -306,19 +381,7 @@ func TestSshServerRsync(t *testing.T) {
212 }
213
214 close(done)
215-
216- p, err := os.FindProcess(os.Getpid())
217- if err != nil {
218- t.Fatal(err)
219- return
220- }
221-
222- err = p.Signal(os.Interrupt)
223- if err != nil {
224- t.Fatal(err)
225- return
226- }
227- <-time.After(10 * time.Millisecond)
228+ time.Sleep(100 * time.Millisecond)
229 }
230
231 type UserSSH struct {
232@@ -347,9 +410,7 @@ func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
233 return res
234 }
235
236-func (s UserSSH) NewClient() (*ssh.Client, error) {
237- host := "localhost:2222"
238-
239+func (s UserSSH) NewClientAddr(addr string) (*ssh.Client, error) {
240 config := &ssh.ClientConfig{
241 User: s.username,
242 Auth: []ssh.AuthMethod{
243@@ -358,10 +419,15 @@ func (s UserSSH) NewClient() (*ssh.Client, error) {
244 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
245 }
246
247- client, err := ssh.Dial("tcp", host, config)
248+ client, err := ssh.Dial("tcp", addr, config)
249 return client, err
250 }
251
252+func (s UserSSH) NewClient() (*ssh.Client, error) {
253+ // Default to localhost:2222 for backward compatibility
254+ return s.NewClientAddr("localhost:2222")
255+}
256+
257 func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
258 session, err := client.NewSession()
259 if err != nil {