repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2025-04-04

ssh_test.go

  1package pgs
  2
  3import (
  4	"crypto/ed25519"
  5	"crypto/rand"
  6	"encoding/pem"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"os/exec"
 12	"path/filepath"
 13	"strings"
 14	"testing"
 15	"time"
 16
 17	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 18	"github.com/picosh/pico/pkg/db"
 19	"github.com/picosh/pico/pkg/shared/storage"
 20	"github.com/picosh/utils"
 21	"github.com/pkg/sftp"
 22	"github.com/prometheus/client_golang/prometheus"
 23	"golang.org/x/crypto/ssh"
 24)
 25
 26func TestSshServerSftp(t *testing.T) {
 27	opts := &slog.HandlerOptions{
 28		AddSource: true,
 29		Level:     slog.LevelDebug,
 30	}
 31	logger := slog.New(
 32		slog.NewTextHandler(os.Stdout, opts),
 33	)
 34	slog.SetDefault(logger)
 35	dbpool := pgsdb.NewDBMemory(logger)
 36	// setup test data
 37	dbpool.SetupTestData()
 38	st, err := storage.NewStorageMemory(map[string]map[string]string{})
 39	if err != nil {
 40		panic(err)
 41	}
 42	cfg := NewPgsConfig(logger, dbpool, st)
 43	done := make(chan error)
 44	prometheus.DefaultRegisterer = prometheus.NewRegistry()
 45	go StartSshServer(cfg, done)
 46	// Hack to wait for startup
 47	time.Sleep(time.Millisecond * 100)
 48
 49	user := GenerateUser()
 50	// add user's pubkey to the default test account
 51	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
 52		ID:     "nice-pubkey",
 53		UserID: dbpool.Users[0].ID,
 54		Key:    utils.KeyForKeyText(user.signer.PublicKey()),
 55	})
 56
 57	client, err := user.NewClient()
 58	if err != nil {
 59		t.Error(err)
 60		return
 61	}
 62	defer client.Close()
 63
 64	_, err = WriteFileWithSftp(cfg, client)
 65	if err != nil {
 66		t.Error(err)
 67		return
 68	}
 69
 70	_, err = WriteFilesMultProjectsWithSftp(cfg, client)
 71	if err != nil {
 72		t.Error(err)
 73		return
 74	}
 75
 76	projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
 77	if err != nil {
 78		t.Error(err)
 79		return
 80	}
 81
 82	names := ""
 83	for _, proj := range projects {
 84		names += "_" + proj.Name
 85	}
 86
 87	if names != "_test_mult_mult2" {
 88		t.Errorf("not all projects created: %s", names)
 89		return
 90	}
 91
 92	close(done)
 93
 94	p, err := os.FindProcess(os.Getpid())
 95	if err != nil {
 96		t.Fatal(err)
 97		return
 98	}
 99
100	err = p.Signal(os.Interrupt)
101	if err != nil {
102		t.Fatal(err)
103		return
104	}
105	<-time.After(10 * time.Millisecond)
106}
107
108func TestSshServerRsync(t *testing.T) {
109	opts := &slog.HandlerOptions{
110		AddSource: true,
111		Level:     slog.LevelDebug,
112	}
113	logger := slog.New(
114		slog.NewTextHandler(os.Stdout, opts),
115	)
116	slog.SetDefault(logger)
117	dbpool := pgsdb.NewDBMemory(logger)
118	// setup test data
119	dbpool.SetupTestData()
120	st, err := storage.NewStorageMemory(map[string]map[string]string{})
121	if err != nil {
122		panic(err)
123	}
124	cfg := NewPgsConfig(logger, dbpool, st)
125	done := make(chan error)
126	prometheus.DefaultRegisterer = prometheus.NewRegistry()
127	go StartSshServer(cfg, done)
128	// Hack to wait for startup
129	time.Sleep(time.Millisecond * 100)
130
131	user := GenerateUser()
132	key := utils.KeyForKeyText(user.signer.PublicKey())
133	// add user's pubkey to the default test account
134	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
135		ID:     "nice-pubkey",
136		UserID: dbpool.Users[0].ID,
137		Key:    key,
138	})
139
140	conn, err := user.NewClient()
141	if err != nil {
142		t.Error(err)
143		return
144	}
145	defer conn.Close()
146
147	// open an SFTP session over an existing ssh connection.
148	client, err := sftp.NewClient(conn)
149	if err != nil {
150		cfg.Logger.Error("could not create sftp client", "err", err)
151		panic(err)
152	}
153	defer client.Close()
154
155	name, err := os.MkdirTemp("", "rsync-")
156	if err != nil {
157		panic(err)
158	}
159
160	// remove the temporary directory at the end of the program
161	// defer os.RemoveAll(name)
162
163	block := &pem.Block{
164		Type:  "OPENSSH PRIVATE KEY",
165		Bytes: user.privateKey,
166	}
167	keyFile := filepath.Join(name, "id_ed25519")
168	err = os.WriteFile(
169		keyFile,
170		pem.EncodeToMemory(block), 0600,
171	)
172	if err != nil {
173		t.Fatal(err)
174	}
175
176	index := "<!doctype html><html><body>index</body></html>"
177	err = os.WriteFile(
178		filepath.Join(name, "index.html"),
179		[]byte(index), 0666,
180	)
181	if err != nil {
182		t.Fatal(err)
183	}
184
185	about := "<!doctype html><html><body>about</body></html>"
186	aboutFile := filepath.Join(name, "about.html")
187	err = os.WriteFile(
188		aboutFile,
189		[]byte(about), 0666,
190	)
191	if err != nil {
192		t.Fatal(err)
193	}
194
195	contact := "<!doctype html><html><body>contact</body></html>"
196	err = os.WriteFile(
197		filepath.Join(name, "contact.html"),
198		[]byte(contact), 0666,
199	)
200	if err != nil {
201		t.Fatal(err)
202	}
203
204	eCmd := fmt.Sprintf(
205		"ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
206		keyFile,
207	)
208
209	// copy files
210	cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
211	result, err := cmd.CombinedOutput()
212	if err != nil {
213		cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
214		t.Error(err)
215		return
216	}
217
218	// check it's there
219	fi, err := client.Lstat("/test/about.html")
220	if err != nil {
221		cfg.Logger.Error("could not get stat for file", "err", err)
222		t.Error("about.html not found")
223		return
224	}
225	if fi.Size() != 46 {
226		cfg.Logger.Error("about.html wrong size", "size", fi.Size())
227		t.Error("about.html wrong size")
228		return
229	}
230
231	// remove about file
232	os.Remove(aboutFile)
233
234	// copy files with delete
235	delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
236	result, err = delCmd.CombinedOutput()
237	if err != nil {
238		cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
239		t.Error(err)
240		return
241	}
242
243	// check it's not there
244	_, err = client.Lstat("/test/about.html")
245	if err == nil {
246		cfg.Logger.Error("file still exists")
247		t.Error("about.html found")
248		return
249	}
250
251	dlName, err := os.MkdirTemp("", "rsync-download")
252	if err != nil {
253		panic(err)
254	}
255	defer os.RemoveAll(dlName)
256	// download files
257	downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
258	result, err = downloadCmd.CombinedOutput()
259	if err != nil {
260		cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
261		t.Error(err)
262		return
263	}
264	// check contents
265	idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
266	if err != nil {
267		cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
268		t.Error(err)
269		return
270	}
271	if string(idx) != index {
272		t.Error("downloaded index.html file does not match original")
273		return
274	}
275	_, err = os.ReadFile(filepath.Join(dlName, "about.html"))
276	if err == nil {
277		cfg.Logger.Error("about file should not exist", "file", "about.html")
278		t.Error(err)
279		return
280	}
281	cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
282	if err != nil {
283		cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
284		t.Error(err)
285		return
286	}
287	if string(cnt) != contact {
288		t.Error("downloaded contact.html file does not match original")
289		return
290	}
291
292	close(done)
293
294	p, err := os.FindProcess(os.Getpid())
295	if err != nil {
296		t.Fatal(err)
297		return
298	}
299
300	err = p.Signal(os.Interrupt)
301	if err != nil {
302		t.Fatal(err)
303		return
304	}
305	<-time.After(10 * time.Millisecond)
306}
307
308type UserSSH struct {
309	username   string
310	signer     ssh.Signer
311	privateKey []byte
312}
313
314func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
315	return &UserSSH{
316		username: username,
317		signer:   signer,
318	}
319}
320
321func (s UserSSH) Public() string {
322	pubkey := s.signer.PublicKey()
323	return string(ssh.MarshalAuthorizedKey(pubkey))
324}
325
326func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
327	res, err := s.Cmd(client, patch, cmd)
328	if err != nil {
329		panic(err)
330	}
331	return res
332}
333
334func (s UserSSH) NewClient() (*ssh.Client, error) {
335	host := "localhost:2222"
336
337	config := &ssh.ClientConfig{
338		User: s.username,
339		Auth: []ssh.AuthMethod{
340			ssh.PublicKeys(s.signer),
341		},
342		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
343	}
344
345	client, err := ssh.Dial("tcp", host, config)
346	return client, err
347}
348
349func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
350	session, err := client.NewSession()
351	if err != nil {
352		return "", err
353	}
354	defer session.Close()
355
356	stdinPipe, err := session.StdinPipe()
357	if err != nil {
358		return "", err
359	}
360
361	stdoutPipe, err := session.StdoutPipe()
362	if err != nil {
363		return "", err
364	}
365
366	if err := session.Start(cmd); err != nil {
367		return "", err
368	}
369
370	if patch != nil {
371		_, err = stdinPipe.Write(patch)
372		if err != nil {
373			return "", err
374		}
375	}
376
377	stdinPipe.Close()
378
379	if err := session.Wait(); err != nil {
380		return "", err
381	}
382
383	buf := new(strings.Builder)
384	_, err = io.Copy(buf, stdoutPipe)
385	if err != nil {
386		return "", err
387	}
388
389	return buf.String(), nil
390}
391
392func GenerateUser() UserSSH {
393	_, userKey, err := ed25519.GenerateKey(rand.Reader)
394	if err != nil {
395		panic(err)
396	}
397
398	b, err := ssh.MarshalPrivateKey(userKey, "")
399	if err != nil {
400		panic(err)
401	}
402
403	userSigner, err := ssh.NewSignerFromKey(userKey)
404	if err != nil {
405		panic(err)
406	}
407
408	return UserSSH{
409		username:   "testuser",
410		signer:     userSigner,
411		privateKey: b.Bytes,
412	}
413}
414
415func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
416	// open an SFTP session over an existing ssh connection.
417	client, err := sftp.NewClient(conn)
418	if err != nil {
419		cfg.Logger.Error("could not create sftp client", "err", err)
420		return nil, err
421	}
422	defer client.Close()
423
424	f, err := client.Create("test/hello.txt")
425	if err != nil {
426		cfg.Logger.Error("could not create file", "err", err)
427		return nil, err
428	}
429	if _, err := f.Write([]byte("Hello world!")); err != nil {
430		cfg.Logger.Error("could not write to file", "err", err)
431		return nil, err
432	}
433
434	cfg.Logger.Info("closing", "err", f.Close())
435
436	// check it's there
437	fi, err := client.Lstat("test/hello.txt")
438	if err != nil {
439		cfg.Logger.Error("could not get stat for file", "err", err)
440		return nil, err
441	}
442
443	return &fi, nil
444}
445
446func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
447	// open an SFTP session over an existing ssh connection.
448	client, err := sftp.NewClient(conn)
449	if err != nil {
450		cfg.Logger.Error("could not create sftp client", "err", err)
451		return nil, err
452	}
453	defer client.Close()
454
455	f, err := client.Create("mult/hello.txt")
456	if err != nil {
457		cfg.Logger.Error("could not create file", "err", err)
458		return nil, err
459	}
460	if _, err := f.Write([]byte("Hello world!")); err != nil {
461		cfg.Logger.Error("could not write to file", "err", err)
462		return nil, err
463	}
464
465	f, err = client.Create("mult2/hello.txt")
466	if err != nil {
467		cfg.Logger.Error("could not create file", "err", err)
468		return nil, err
469	}
470	if _, err := f.Write([]byte("Hello world!")); err != nil {
471		cfg.Logger.Error("could not write to file", "err", err)
472		return nil, err
473	}
474
475	cfg.Logger.Info("closing", "err", f.Close())
476
477	// check it's there
478	fi, err := client.Lstat("test/hello.txt")
479	if err != nil {
480		cfg.Logger.Error("could not get stat for file", "err", err)
481		return nil, err
482	}
483
484	return &fi, nil
485}