repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2025-07-04

ssh_test.go

  1package pgs
  2
  3import (
  4	"crypto/ed25519"
  5	"crypto/rand"
  6	"encoding/pem"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"os/exec"
 12	"path/filepath"
 13	"strings"
 14	"testing"
 15	"time"
 16
 17	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 18	"github.com/picosh/pico/pkg/db"
 19	"github.com/picosh/pico/pkg/shared/storage"
 20	"github.com/picosh/utils"
 21	"github.com/pkg/sftp"
 22	"github.com/prometheus/client_golang/prometheus"
 23	"golang.org/x/crypto/ssh"
 24)
 25
 26func TestSshServerSftp(t *testing.T) {
 27	opts := &slog.HandlerOptions{
 28		AddSource: true,
 29		Level:     slog.LevelDebug,
 30	}
 31	logger := slog.New(
 32		slog.NewTextHandler(os.Stdout, opts),
 33	)
 34	slog.SetDefault(logger)
 35	dbpool := pgsdb.NewDBMemory(logger)
 36	// setup test data
 37	dbpool.SetupTestData()
 38	st, err := storage.NewStorageMemory(map[string]map[string]string{})
 39	if err != nil {
 40		panic(err)
 41	}
 42	pubsub := NewPubsubChan()
 43	defer func() {
 44		_ = pubsub.Close()
 45	}()
 46	cfg := NewPgsConfig(logger, dbpool, st, pubsub)
 47	done := make(chan error)
 48	prometheus.DefaultRegisterer = prometheus.NewRegistry()
 49	go StartSshServer(cfg, done)
 50	// Hack to wait for startup
 51	time.Sleep(time.Millisecond * 100)
 52
 53	user := GenerateUser()
 54	// add user's pubkey to the default test account
 55	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
 56		ID:     "nice-pubkey",
 57		UserID: dbpool.Users[0].ID,
 58		Key:    utils.KeyForKeyText(user.signer.PublicKey()),
 59	})
 60
 61	client, err := user.NewClient()
 62	if err != nil {
 63		t.Error(err)
 64		return
 65	}
 66	defer func() {
 67		_ = client.Close()
 68	}()
 69
 70	_, err = WriteFileWithSftp(cfg, client)
 71	if err != nil {
 72		t.Error(err)
 73		return
 74	}
 75
 76	_, err = WriteFilesMultProjectsWithSftp(cfg, client)
 77	if err != nil {
 78		t.Error(err)
 79		return
 80	}
 81
 82	projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
 83	if err != nil {
 84		t.Error(err)
 85		return
 86	}
 87
 88	names := ""
 89	for _, proj := range projects {
 90		names += "_" + proj.Name
 91	}
 92
 93	if names != "_test_mult_mult2" {
 94		t.Errorf("not all projects created: %s", names)
 95		return
 96	}
 97
 98	close(done)
 99
100	p, err := os.FindProcess(os.Getpid())
101	if err != nil {
102		t.Fatal(err)
103		return
104	}
105
106	err = p.Signal(os.Interrupt)
107	if err != nil {
108		t.Fatal(err)
109		return
110	}
111	<-time.After(10 * time.Millisecond)
112}
113
114func TestSshServerRsync(t *testing.T) {
115	opts := &slog.HandlerOptions{
116		AddSource: true,
117		Level:     slog.LevelDebug,
118	}
119	logger := slog.New(
120		slog.NewTextHandler(os.Stdout, opts),
121	)
122	slog.SetDefault(logger)
123	dbpool := pgsdb.NewDBMemory(logger)
124	// setup test data
125	dbpool.SetupTestData()
126	st, err := storage.NewStorageMemory(map[string]map[string]string{})
127	if err != nil {
128		panic(err)
129	}
130	pubsub := NewPubsubChan()
131	defer func() {
132		_ = pubsub.Close()
133	}()
134	cfg := NewPgsConfig(logger, dbpool, st, pubsub)
135	done := make(chan error)
136	prometheus.DefaultRegisterer = prometheus.NewRegistry()
137	go StartSshServer(cfg, done)
138	// Hack to wait for startup
139	time.Sleep(time.Millisecond * 100)
140
141	user := GenerateUser()
142	key := utils.KeyForKeyText(user.signer.PublicKey())
143	// add user's pubkey to the default test account
144	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
145		ID:     "nice-pubkey",
146		UserID: dbpool.Users[0].ID,
147		Key:    key,
148	})
149
150	conn, err := user.NewClient()
151	if err != nil {
152		t.Error(err)
153		return
154	}
155	defer func() {
156		_ = conn.Close()
157	}()
158
159	// open an SFTP session over an existing ssh connection.
160	client, err := sftp.NewClient(conn)
161	if err != nil {
162		cfg.Logger.Error("could not create sftp client", "err", err)
163		panic(err)
164	}
165	defer func() {
166		_ = client.Close()
167	}()
168
169	name, err := os.MkdirTemp("", "rsync-")
170	if err != nil {
171		panic(err)
172	}
173
174	// remove the temporary directory at the end of the program
175	// defer os.RemoveAll(name)
176
177	block := &pem.Block{
178		Type:  "OPENSSH PRIVATE KEY",
179		Bytes: user.privateKey,
180	}
181	keyFile := filepath.Join(name, "id_ed25519")
182	err = os.WriteFile(
183		keyFile,
184		pem.EncodeToMemory(block), 0600,
185	)
186	if err != nil {
187		t.Fatal(err)
188	}
189
190	index := "<!doctype html><html><body>index</body></html>"
191	err = os.WriteFile(
192		filepath.Join(name, "index.html"),
193		[]byte(index), 0666,
194	)
195	if err != nil {
196		t.Fatal(err)
197	}
198
199	about := "<!doctype html><html><body>about</body></html>"
200	aboutFile := filepath.Join(name, "about.html")
201	err = os.WriteFile(
202		aboutFile,
203		[]byte(about), 0666,
204	)
205	if err != nil {
206		t.Fatal(err)
207	}
208
209	contact := "<!doctype html><html><body>contact</body></html>"
210	err = os.WriteFile(
211		filepath.Join(name, "contact.html"),
212		[]byte(contact), 0666,
213	)
214	if err != nil {
215		t.Fatal(err)
216	}
217
218	eCmd := fmt.Sprintf(
219		"ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
220		keyFile,
221	)
222
223	// copy files
224	cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
225	result, err := cmd.CombinedOutput()
226	if err != nil {
227		cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
228		t.Error(err)
229		return
230	}
231
232	// check it's there
233	fi, err := client.Lstat("/test/about.html")
234	if err != nil {
235		cfg.Logger.Error("could not get stat for file", "err", err)
236		t.Error("about.html not found")
237		return
238	}
239	if fi.Size() != 46 {
240		cfg.Logger.Error("about.html wrong size", "size", fi.Size())
241		t.Error("about.html wrong size")
242		return
243	}
244
245	// remove about file
246	_ = os.Remove(aboutFile)
247
248	// copy files with delete
249	delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
250	result, err = delCmd.CombinedOutput()
251	if err != nil {
252		cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
253		t.Error(err)
254		return
255	}
256
257	// check it's not there
258	_, err = client.Lstat("/test/about.html")
259	if err == nil {
260		cfg.Logger.Error("file still exists")
261		t.Error("about.html found")
262		return
263	}
264
265	dlName, err := os.MkdirTemp("", "rsync-download")
266	if err != nil {
267		panic(err)
268	}
269	defer func() {
270		_ = os.RemoveAll(dlName)
271	}()
272	// download files
273	downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
274	result, err = downloadCmd.CombinedOutput()
275	if err != nil {
276		cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
277		t.Error(err)
278		return
279	}
280	// check contents
281	idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
282	if err != nil {
283		cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
284		t.Error(err)
285		return
286	}
287	if string(idx) != index {
288		t.Error("downloaded index.html file does not match original")
289		return
290	}
291	_, err = os.ReadFile(filepath.Join(dlName, "about.html"))
292	if err == nil {
293		cfg.Logger.Error("about file should not exist", "file", "about.html")
294		t.Error(err)
295		return
296	}
297	cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
298	if err != nil {
299		cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
300		t.Error(err)
301		return
302	}
303	if string(cnt) != contact {
304		t.Error("downloaded contact.html file does not match original")
305		return
306	}
307
308	close(done)
309
310	p, err := os.FindProcess(os.Getpid())
311	if err != nil {
312		t.Fatal(err)
313		return
314	}
315
316	err = p.Signal(os.Interrupt)
317	if err != nil {
318		t.Fatal(err)
319		return
320	}
321	<-time.After(10 * time.Millisecond)
322}
323
// UserSSH is the client identity a test uses to talk to the SSH server.
type UserSSH struct {
	username   string     // SSH user name sent by NewClient
	signer     ssh.Signer // used for public-key auth in NewClient
	privateKey []byte     // OpenSSH-encoded private key; set by GenerateUser, left nil by NewUserSSH
}
329
330func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
331	return &UserSSH{
332		username: username,
333		signer:   signer,
334	}
335}
336
337func (s UserSSH) Public() string {
338	pubkey := s.signer.PublicKey()
339	return string(ssh.MarshalAuthorizedKey(pubkey))
340}
341
342func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
343	res, err := s.Cmd(client, patch, cmd)
344	if err != nil {
345		panic(err)
346	}
347	return res
348}
349
350func (s UserSSH) NewClient() (*ssh.Client, error) {
351	host := "localhost:2222"
352
353	config := &ssh.ClientConfig{
354		User: s.username,
355		Auth: []ssh.AuthMethod{
356			ssh.PublicKeys(s.signer),
357		},
358		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
359	}
360
361	client, err := ssh.Dial("tcp", host, config)
362	return client, err
363}
364
365func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
366	session, err := client.NewSession()
367	if err != nil {
368		return "", err
369	}
370	defer func() {
371		_ = session.Close()
372	}()
373
374	stdinPipe, err := session.StdinPipe()
375	if err != nil {
376		return "", err
377	}
378
379	stdoutPipe, err := session.StdoutPipe()
380	if err != nil {
381		return "", err
382	}
383
384	if err := session.Start(cmd); err != nil {
385		return "", err
386	}
387
388	if patch != nil {
389		_, err = stdinPipe.Write(patch)
390		if err != nil {
391			return "", err
392		}
393	}
394
395	err = stdinPipe.Close()
396	if err != nil {
397		return "", err
398	}
399
400	if err := session.Wait(); err != nil {
401		return "", err
402	}
403
404	buf := new(strings.Builder)
405	_, err = io.Copy(buf, stdoutPipe)
406	if err != nil {
407		return "", err
408	}
409
410	return buf.String(), nil
411}
412
413func GenerateUser() UserSSH {
414	_, userKey, err := ed25519.GenerateKey(rand.Reader)
415	if err != nil {
416		panic(err)
417	}
418
419	b, err := ssh.MarshalPrivateKey(userKey, "")
420	if err != nil {
421		panic(err)
422	}
423
424	userSigner, err := ssh.NewSignerFromKey(userKey)
425	if err != nil {
426		panic(err)
427	}
428
429	return UserSSH{
430		username:   "testuser",
431		signer:     userSigner,
432		privateKey: b.Bytes,
433	}
434}
435
436func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
437	// open an SFTP session over an existing ssh connection.
438	client, err := sftp.NewClient(conn)
439	if err != nil {
440		cfg.Logger.Error("could not create sftp client", "err", err)
441		return nil, err
442	}
443	defer func() {
444		_ = client.Close()
445	}()
446
447	f, err := client.Create("test/hello.txt")
448	if err != nil {
449		cfg.Logger.Error("could not create file", "err", err)
450		return nil, err
451	}
452	if _, err := f.Write([]byte("Hello world!")); err != nil {
453		cfg.Logger.Error("could not write to file", "err", err)
454		return nil, err
455	}
456
457	cfg.Logger.Info("closing", "err", f.Close())
458
459	// check it's there
460	fi, err := client.Lstat("test/hello.txt")
461	if err != nil {
462		cfg.Logger.Error("could not get stat for file", "err", err)
463		return nil, err
464	}
465
466	return &fi, nil
467}
468
469func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
470	// open an SFTP session over an existing ssh connection.
471	client, err := sftp.NewClient(conn)
472	if err != nil {
473		cfg.Logger.Error("could not create sftp client", "err", err)
474		return nil, err
475	}
476	defer func() {
477		_ = client.Close()
478	}()
479
480	f, err := client.Create("mult/hello.txt")
481	if err != nil {
482		cfg.Logger.Error("could not create file", "err", err)
483		return nil, err
484	}
485	if _, err := f.Write([]byte("Hello world!")); err != nil {
486		cfg.Logger.Error("could not write to file", "err", err)
487		return nil, err
488	}
489
490	f, err = client.Create("mult2/hello.txt")
491	if err != nil {
492		cfg.Logger.Error("could not create file", "err", err)
493		return nil, err
494	}
495	if _, err := f.Write([]byte("Hello world!")); err != nil {
496		cfg.Logger.Error("could not write to file", "err", err)
497		return nil, err
498	}
499
500	cfg.Logger.Info("closing", "err", f.Close())
501
502	// check it's there
503	fi, err := client.Lstat("test/hello.txt")
504	if err != nil {
505		cfg.Logger.Error("could not get stat for file", "err", err)
506		return nil, err
507	}
508
509	return &fi, nil
510}