Eric Bower
·
2025-05-25
ssh_test.go
1package pgs
2
3import (
4 "crypto/ed25519"
5 "crypto/rand"
6 "encoding/pem"
7 "fmt"
8 "io"
9 "log/slog"
10 "os"
11 "os/exec"
12 "path/filepath"
13 "strings"
14 "testing"
15 "time"
16
17 pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
18 "github.com/picosh/pico/pkg/db"
19 "github.com/picosh/pico/pkg/shared/storage"
20 "github.com/picosh/utils"
21 "github.com/pkg/sftp"
22 "github.com/prometheus/client_golang/prometheus"
23 "golang.org/x/crypto/ssh"
24)
25
26func TestSshServerSftp(t *testing.T) {
27 opts := &slog.HandlerOptions{
28 AddSource: true,
29 Level: slog.LevelDebug,
30 }
31 logger := slog.New(
32 slog.NewTextHandler(os.Stdout, opts),
33 )
34 slog.SetDefault(logger)
35 dbpool := pgsdb.NewDBMemory(logger)
36 // setup test data
37 dbpool.SetupTestData()
38 st, err := storage.NewStorageMemory(map[string]map[string]string{})
39 if err != nil {
40 panic(err)
41 }
42 cfg := NewPgsConfig(logger, dbpool, st)
43 done := make(chan error)
44 prometheus.DefaultRegisterer = prometheus.NewRegistry()
45 go StartSshServer(cfg, done)
46 // Hack to wait for startup
47 time.Sleep(time.Millisecond * 100)
48
49 user := GenerateUser()
50 // add user's pubkey to the default test account
51 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
52 ID: "nice-pubkey",
53 UserID: dbpool.Users[0].ID,
54 Key: utils.KeyForKeyText(user.signer.PublicKey()),
55 })
56
57 client, err := user.NewClient()
58 if err != nil {
59 t.Error(err)
60 return
61 }
62 defer func() {
63 _ = client.Close()
64 }()
65
66 _, err = WriteFileWithSftp(cfg, client)
67 if err != nil {
68 t.Error(err)
69 return
70 }
71
72 _, err = WriteFilesMultProjectsWithSftp(cfg, client)
73 if err != nil {
74 t.Error(err)
75 return
76 }
77
78 projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
79 if err != nil {
80 t.Error(err)
81 return
82 }
83
84 names := ""
85 for _, proj := range projects {
86 names += "_" + proj.Name
87 }
88
89 if names != "_test_mult_mult2" {
90 t.Errorf("not all projects created: %s", names)
91 return
92 }
93
94 close(done)
95
96 p, err := os.FindProcess(os.Getpid())
97 if err != nil {
98 t.Fatal(err)
99 return
100 }
101
102 err = p.Signal(os.Interrupt)
103 if err != nil {
104 t.Fatal(err)
105 return
106 }
107 <-time.After(10 * time.Millisecond)
108}
109
110func TestSshServerRsync(t *testing.T) {
111 opts := &slog.HandlerOptions{
112 AddSource: true,
113 Level: slog.LevelDebug,
114 }
115 logger := slog.New(
116 slog.NewTextHandler(os.Stdout, opts),
117 )
118 slog.SetDefault(logger)
119 dbpool := pgsdb.NewDBMemory(logger)
120 // setup test data
121 dbpool.SetupTestData()
122 st, err := storage.NewStorageMemory(map[string]map[string]string{})
123 if err != nil {
124 panic(err)
125 }
126 cfg := NewPgsConfig(logger, dbpool, st)
127 done := make(chan error)
128 prometheus.DefaultRegisterer = prometheus.NewRegistry()
129 go StartSshServer(cfg, done)
130 // Hack to wait for startup
131 time.Sleep(time.Millisecond * 100)
132
133 user := GenerateUser()
134 key := utils.KeyForKeyText(user.signer.PublicKey())
135 // add user's pubkey to the default test account
136 dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
137 ID: "nice-pubkey",
138 UserID: dbpool.Users[0].ID,
139 Key: key,
140 })
141
142 conn, err := user.NewClient()
143 if err != nil {
144 t.Error(err)
145 return
146 }
147 defer func() {
148 _ = conn.Close()
149 }()
150
151 // open an SFTP session over an existing ssh connection.
152 client, err := sftp.NewClient(conn)
153 if err != nil {
154 cfg.Logger.Error("could not create sftp client", "err", err)
155 panic(err)
156 }
157 defer func() {
158 _ = client.Close()
159 }()
160
161 name, err := os.MkdirTemp("", "rsync-")
162 if err != nil {
163 panic(err)
164 }
165
166 // remove the temporary directory at the end of the program
167 // defer os.RemoveAll(name)
168
169 block := &pem.Block{
170 Type: "OPENSSH PRIVATE KEY",
171 Bytes: user.privateKey,
172 }
173 keyFile := filepath.Join(name, "id_ed25519")
174 err = os.WriteFile(
175 keyFile,
176 pem.EncodeToMemory(block), 0600,
177 )
178 if err != nil {
179 t.Fatal(err)
180 }
181
182 index := "<!doctype html><html><body>index</body></html>"
183 err = os.WriteFile(
184 filepath.Join(name, "index.html"),
185 []byte(index), 0666,
186 )
187 if err != nil {
188 t.Fatal(err)
189 }
190
191 about := "<!doctype html><html><body>about</body></html>"
192 aboutFile := filepath.Join(name, "about.html")
193 err = os.WriteFile(
194 aboutFile,
195 []byte(about), 0666,
196 )
197 if err != nil {
198 t.Fatal(err)
199 }
200
201 contact := "<!doctype html><html><body>contact</body></html>"
202 err = os.WriteFile(
203 filepath.Join(name, "contact.html"),
204 []byte(contact), 0666,
205 )
206 if err != nil {
207 t.Fatal(err)
208 }
209
210 eCmd := fmt.Sprintf(
211 "ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
212 keyFile,
213 )
214
215 // copy files
216 cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
217 result, err := cmd.CombinedOutput()
218 if err != nil {
219 cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
220 t.Error(err)
221 return
222 }
223
224 // check it's there
225 fi, err := client.Lstat("/test/about.html")
226 if err != nil {
227 cfg.Logger.Error("could not get stat for file", "err", err)
228 t.Error("about.html not found")
229 return
230 }
231 if fi.Size() != 46 {
232 cfg.Logger.Error("about.html wrong size", "size", fi.Size())
233 t.Error("about.html wrong size")
234 return
235 }
236
237 // remove about file
238 _ = os.Remove(aboutFile)
239
240 // copy files with delete
241 delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
242 result, err = delCmd.CombinedOutput()
243 if err != nil {
244 cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
245 t.Error(err)
246 return
247 }
248
249 // check it's not there
250 _, err = client.Lstat("/test/about.html")
251 if err == nil {
252 cfg.Logger.Error("file still exists")
253 t.Error("about.html found")
254 return
255 }
256
257 dlName, err := os.MkdirTemp("", "rsync-download")
258 if err != nil {
259 panic(err)
260 }
261 defer func() {
262 _ = os.RemoveAll(dlName)
263 }()
264 // download files
265 downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
266 result, err = downloadCmd.CombinedOutput()
267 if err != nil {
268 cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
269 t.Error(err)
270 return
271 }
272 // check contents
273 idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
274 if err != nil {
275 cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
276 t.Error(err)
277 return
278 }
279 if string(idx) != index {
280 t.Error("downloaded index.html file does not match original")
281 return
282 }
283 _, err = os.ReadFile(filepath.Join(dlName, "about.html"))
284 if err == nil {
285 cfg.Logger.Error("about file should not exist", "file", "about.html")
286 t.Error(err)
287 return
288 }
289 cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
290 if err != nil {
291 cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
292 t.Error(err)
293 return
294 }
295 if string(cnt) != contact {
296 t.Error("downloaded contact.html file does not match original")
297 return
298 }
299
300 close(done)
301
302 p, err := os.FindProcess(os.Getpid())
303 if err != nil {
304 t.Fatal(err)
305 return
306 }
307
308 err = p.Signal(os.Interrupt)
309 if err != nil {
310 t.Fatal(err)
311 return
312 }
313 <-time.After(10 * time.Millisecond)
314}
315
// UserSSH is an SSH client identity used by these tests to connect to the
// pgs SSH server.
type UserSSH struct {
	// username is sent as the SSH user name when dialing (see NewClient).
	username string
	// signer authenticates the user with public-key auth.
	signer ssh.Signer
	// privateKey holds the body of the OpenSSH private-key PEM block as
	// returned by ssh.MarshalPrivateKey. Only GenerateUser populates it;
	// NewUserSSH leaves it empty.
	privateKey []byte
}
321
322func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
323 return &UserSSH{
324 username: username,
325 signer: signer,
326 }
327}
328
329func (s UserSSH) Public() string {
330 pubkey := s.signer.PublicKey()
331 return string(ssh.MarshalAuthorizedKey(pubkey))
332}
333
334func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
335 res, err := s.Cmd(client, patch, cmd)
336 if err != nil {
337 panic(err)
338 }
339 return res
340}
341
342func (s UserSSH) NewClient() (*ssh.Client, error) {
343 host := "localhost:2222"
344
345 config := &ssh.ClientConfig{
346 User: s.username,
347 Auth: []ssh.AuthMethod{
348 ssh.PublicKeys(s.signer),
349 },
350 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
351 }
352
353 client, err := ssh.Dial("tcp", host, config)
354 return client, err
355}
356
357func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
358 session, err := client.NewSession()
359 if err != nil {
360 return "", err
361 }
362 defer func() {
363 _ = session.Close()
364 }()
365
366 stdinPipe, err := session.StdinPipe()
367 if err != nil {
368 return "", err
369 }
370
371 stdoutPipe, err := session.StdoutPipe()
372 if err != nil {
373 return "", err
374 }
375
376 if err := session.Start(cmd); err != nil {
377 return "", err
378 }
379
380 if patch != nil {
381 _, err = stdinPipe.Write(patch)
382 if err != nil {
383 return "", err
384 }
385 }
386
387 err = stdinPipe.Close()
388 if err != nil {
389 return "", err
390 }
391
392 if err := session.Wait(); err != nil {
393 return "", err
394 }
395
396 buf := new(strings.Builder)
397 _, err = io.Copy(buf, stdoutPipe)
398 if err != nil {
399 return "", err
400 }
401
402 return buf.String(), nil
403}
404
405func GenerateUser() UserSSH {
406 _, userKey, err := ed25519.GenerateKey(rand.Reader)
407 if err != nil {
408 panic(err)
409 }
410
411 b, err := ssh.MarshalPrivateKey(userKey, "")
412 if err != nil {
413 panic(err)
414 }
415
416 userSigner, err := ssh.NewSignerFromKey(userKey)
417 if err != nil {
418 panic(err)
419 }
420
421 return UserSSH{
422 username: "testuser",
423 signer: userSigner,
424 privateKey: b.Bytes,
425 }
426}
427
428func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
429 // open an SFTP session over an existing ssh connection.
430 client, err := sftp.NewClient(conn)
431 if err != nil {
432 cfg.Logger.Error("could not create sftp client", "err", err)
433 return nil, err
434 }
435 defer func() {
436 _ = client.Close()
437 }()
438
439 f, err := client.Create("test/hello.txt")
440 if err != nil {
441 cfg.Logger.Error("could not create file", "err", err)
442 return nil, err
443 }
444 if _, err := f.Write([]byte("Hello world!")); err != nil {
445 cfg.Logger.Error("could not write to file", "err", err)
446 return nil, err
447 }
448
449 cfg.Logger.Info("closing", "err", f.Close())
450
451 // check it's there
452 fi, err := client.Lstat("test/hello.txt")
453 if err != nil {
454 cfg.Logger.Error("could not get stat for file", "err", err)
455 return nil, err
456 }
457
458 return &fi, nil
459}
460
461func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
462 // open an SFTP session over an existing ssh connection.
463 client, err := sftp.NewClient(conn)
464 if err != nil {
465 cfg.Logger.Error("could not create sftp client", "err", err)
466 return nil, err
467 }
468 defer func() {
469 _ = client.Close()
470 }()
471
472 f, err := client.Create("mult/hello.txt")
473 if err != nil {
474 cfg.Logger.Error("could not create file", "err", err)
475 return nil, err
476 }
477 if _, err := f.Write([]byte("Hello world!")); err != nil {
478 cfg.Logger.Error("could not write to file", "err", err)
479 return nil, err
480 }
481
482 f, err = client.Create("mult2/hello.txt")
483 if err != nil {
484 cfg.Logger.Error("could not create file", "err", err)
485 return nil, err
486 }
487 if _, err := f.Write([]byte("Hello world!")); err != nil {
488 cfg.Logger.Error("could not write to file", "err", err)
489 return nil, err
490 }
491
492 cfg.Logger.Info("closing", "err", f.Close())
493
494 // check it's there
495 fi, err := client.Lstat("test/hello.txt")
496 if err != nil {
497 cfg.Logger.Error("could not get stat for file", "err", err)
498 return nil, err
499 }
500
501 return &fi, nil
502}