repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / apps / pgs
Eric Bower  ·  2025-05-25

ssh_test.go

  1package pgs
  2
  3import (
  4	"crypto/ed25519"
  5	"crypto/rand"
  6	"encoding/pem"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"os/exec"
 12	"path/filepath"
 13	"strings"
 14	"testing"
 15	"time"
 16
 17	pgsdb "github.com/picosh/pico/pkg/apps/pgs/db"
 18	"github.com/picosh/pico/pkg/db"
 19	"github.com/picosh/pico/pkg/shared/storage"
 20	"github.com/picosh/utils"
 21	"github.com/pkg/sftp"
 22	"github.com/prometheus/client_golang/prometheus"
 23	"golang.org/x/crypto/ssh"
 24)
 25
 26func TestSshServerSftp(t *testing.T) {
 27	opts := &slog.HandlerOptions{
 28		AddSource: true,
 29		Level:     slog.LevelDebug,
 30	}
 31	logger := slog.New(
 32		slog.NewTextHandler(os.Stdout, opts),
 33	)
 34	slog.SetDefault(logger)
 35	dbpool := pgsdb.NewDBMemory(logger)
 36	// setup test data
 37	dbpool.SetupTestData()
 38	st, err := storage.NewStorageMemory(map[string]map[string]string{})
 39	if err != nil {
 40		panic(err)
 41	}
 42	cfg := NewPgsConfig(logger, dbpool, st)
 43	done := make(chan error)
 44	prometheus.DefaultRegisterer = prometheus.NewRegistry()
 45	go StartSshServer(cfg, done)
 46	// Hack to wait for startup
 47	time.Sleep(time.Millisecond * 100)
 48
 49	user := GenerateUser()
 50	// add user's pubkey to the default test account
 51	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
 52		ID:     "nice-pubkey",
 53		UserID: dbpool.Users[0].ID,
 54		Key:    utils.KeyForKeyText(user.signer.PublicKey()),
 55	})
 56
 57	client, err := user.NewClient()
 58	if err != nil {
 59		t.Error(err)
 60		return
 61	}
 62	defer func() {
 63		_ = client.Close()
 64	}()
 65
 66	_, err = WriteFileWithSftp(cfg, client)
 67	if err != nil {
 68		t.Error(err)
 69		return
 70	}
 71
 72	_, err = WriteFilesMultProjectsWithSftp(cfg, client)
 73	if err != nil {
 74		t.Error(err)
 75		return
 76	}
 77
 78	projects, err := dbpool.FindProjectsByUser(dbpool.Users[0].ID)
 79	if err != nil {
 80		t.Error(err)
 81		return
 82	}
 83
 84	names := ""
 85	for _, proj := range projects {
 86		names += "_" + proj.Name
 87	}
 88
 89	if names != "_test_mult_mult2" {
 90		t.Errorf("not all projects created: %s", names)
 91		return
 92	}
 93
 94	close(done)
 95
 96	p, err := os.FindProcess(os.Getpid())
 97	if err != nil {
 98		t.Fatal(err)
 99		return
100	}
101
102	err = p.Signal(os.Interrupt)
103	if err != nil {
104		t.Fatal(err)
105		return
106	}
107	<-time.After(10 * time.Millisecond)
108}
109
110func TestSshServerRsync(t *testing.T) {
111	opts := &slog.HandlerOptions{
112		AddSource: true,
113		Level:     slog.LevelDebug,
114	}
115	logger := slog.New(
116		slog.NewTextHandler(os.Stdout, opts),
117	)
118	slog.SetDefault(logger)
119	dbpool := pgsdb.NewDBMemory(logger)
120	// setup test data
121	dbpool.SetupTestData()
122	st, err := storage.NewStorageMemory(map[string]map[string]string{})
123	if err != nil {
124		panic(err)
125	}
126	cfg := NewPgsConfig(logger, dbpool, st)
127	done := make(chan error)
128	prometheus.DefaultRegisterer = prometheus.NewRegistry()
129	go StartSshServer(cfg, done)
130	// Hack to wait for startup
131	time.Sleep(time.Millisecond * 100)
132
133	user := GenerateUser()
134	key := utils.KeyForKeyText(user.signer.PublicKey())
135	// add user's pubkey to the default test account
136	dbpool.Pubkeys = append(dbpool.Pubkeys, &db.PublicKey{
137		ID:     "nice-pubkey",
138		UserID: dbpool.Users[0].ID,
139		Key:    key,
140	})
141
142	conn, err := user.NewClient()
143	if err != nil {
144		t.Error(err)
145		return
146	}
147	defer func() {
148		_ = conn.Close()
149	}()
150
151	// open an SFTP session over an existing ssh connection.
152	client, err := sftp.NewClient(conn)
153	if err != nil {
154		cfg.Logger.Error("could not create sftp client", "err", err)
155		panic(err)
156	}
157	defer func() {
158		_ = client.Close()
159	}()
160
161	name, err := os.MkdirTemp("", "rsync-")
162	if err != nil {
163		panic(err)
164	}
165
166	// remove the temporary directory at the end of the program
167	// defer os.RemoveAll(name)
168
169	block := &pem.Block{
170		Type:  "OPENSSH PRIVATE KEY",
171		Bytes: user.privateKey,
172	}
173	keyFile := filepath.Join(name, "id_ed25519")
174	err = os.WriteFile(
175		keyFile,
176		pem.EncodeToMemory(block), 0600,
177	)
178	if err != nil {
179		t.Fatal(err)
180	}
181
182	index := "<!doctype html><html><body>index</body></html>"
183	err = os.WriteFile(
184		filepath.Join(name, "index.html"),
185		[]byte(index), 0666,
186	)
187	if err != nil {
188		t.Fatal(err)
189	}
190
191	about := "<!doctype html><html><body>about</body></html>"
192	aboutFile := filepath.Join(name, "about.html")
193	err = os.WriteFile(
194		aboutFile,
195		[]byte(about), 0666,
196	)
197	if err != nil {
198		t.Fatal(err)
199	}
200
201	contact := "<!doctype html><html><body>contact</body></html>"
202	err = os.WriteFile(
203		filepath.Join(name, "contact.html"),
204		[]byte(contact), 0666,
205	)
206	if err != nil {
207		t.Fatal(err)
208	}
209
210	eCmd := fmt.Sprintf(
211		"ssh -p 2222 -o IdentitiesOnly=yes -i %s -o StrictHostKeyChecking=no",
212		keyFile,
213	)
214
215	// copy files
216	cmd := exec.Command("rsync", "-rv", "-e", eCmd, name+"/", "localhost:/test")
217	result, err := cmd.CombinedOutput()
218	if err != nil {
219		cfg.Logger.Error("cannot upload", "err", err, "result", string(result))
220		t.Error(err)
221		return
222	}
223
224	// check it's there
225	fi, err := client.Lstat("/test/about.html")
226	if err != nil {
227		cfg.Logger.Error("could not get stat for file", "err", err)
228		t.Error("about.html not found")
229		return
230	}
231	if fi.Size() != 46 {
232		cfg.Logger.Error("about.html wrong size", "size", fi.Size())
233		t.Error("about.html wrong size")
234		return
235	}
236
237	// remove about file
238	_ = os.Remove(aboutFile)
239
240	// copy files with delete
241	delCmd := exec.Command("rsync", "-rv", "--delete", "-e", eCmd, name+"/", "localhost:/test")
242	result, err = delCmd.CombinedOutput()
243	if err != nil {
244		cfg.Logger.Error("cannot upload with delete", "err", err, "result", string(result))
245		t.Error(err)
246		return
247	}
248
249	// check it's not there
250	_, err = client.Lstat("/test/about.html")
251	if err == nil {
252		cfg.Logger.Error("file still exists")
253		t.Error("about.html found")
254		return
255	}
256
257	dlName, err := os.MkdirTemp("", "rsync-download")
258	if err != nil {
259		panic(err)
260	}
261	defer func() {
262		_ = os.RemoveAll(dlName)
263	}()
264	// download files
265	downloadCmd := exec.Command("rsync", "-rvvv", "-e", eCmd, "localhost:/test/", dlName+"/")
266	result, err = downloadCmd.CombinedOutput()
267	if err != nil {
268		cfg.Logger.Error("cannot download files", "err", err, "result", string(result))
269		t.Error(err)
270		return
271	}
272	// check contents
273	idx, err := os.ReadFile(filepath.Join(dlName, "index.html"))
274	if err != nil {
275		cfg.Logger.Error("cannot open file", "file", "index.html", "err", err)
276		t.Error(err)
277		return
278	}
279	if string(idx) != index {
280		t.Error("downloaded index.html file does not match original")
281		return
282	}
283	_, err = os.ReadFile(filepath.Join(dlName, "about.html"))
284	if err == nil {
285		cfg.Logger.Error("about file should not exist", "file", "about.html")
286		t.Error(err)
287		return
288	}
289	cnt, err := os.ReadFile(filepath.Join(dlName, "contact.html"))
290	if err != nil {
291		cfg.Logger.Error("cannot open file", "file", "contact.html", "err", err)
292		t.Error(err)
293		return
294	}
295	if string(cnt) != contact {
296		t.Error("downloaded contact.html file does not match original")
297		return
298	}
299
300	close(done)
301
302	p, err := os.FindProcess(os.Getpid())
303	if err != nil {
304		t.Fatal(err)
305		return
306	}
307
308	err = p.Signal(os.Interrupt)
309	if err != nil {
310		t.Fatal(err)
311		return
312	}
313	<-time.After(10 * time.Millisecond)
314}
315
316type UserSSH struct {
317	username   string
318	signer     ssh.Signer
319	privateKey []byte
320}
321
322func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
323	return &UserSSH{
324		username: username,
325		signer:   signer,
326	}
327}
328
329func (s UserSSH) Public() string {
330	pubkey := s.signer.PublicKey()
331	return string(ssh.MarshalAuthorizedKey(pubkey))
332}
333
334func (s UserSSH) MustCmd(client *ssh.Client, patch []byte, cmd string) string {
335	res, err := s.Cmd(client, patch, cmd)
336	if err != nil {
337		panic(err)
338	}
339	return res
340}
341
342func (s UserSSH) NewClient() (*ssh.Client, error) {
343	host := "localhost:2222"
344
345	config := &ssh.ClientConfig{
346		User: s.username,
347		Auth: []ssh.AuthMethod{
348			ssh.PublicKeys(s.signer),
349		},
350		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
351	}
352
353	client, err := ssh.Dial("tcp", host, config)
354	return client, err
355}
356
357func (s UserSSH) Cmd(client *ssh.Client, patch []byte, cmd string) (string, error) {
358	session, err := client.NewSession()
359	if err != nil {
360		return "", err
361	}
362	defer func() {
363		_ = session.Close()
364	}()
365
366	stdinPipe, err := session.StdinPipe()
367	if err != nil {
368		return "", err
369	}
370
371	stdoutPipe, err := session.StdoutPipe()
372	if err != nil {
373		return "", err
374	}
375
376	if err := session.Start(cmd); err != nil {
377		return "", err
378	}
379
380	if patch != nil {
381		_, err = stdinPipe.Write(patch)
382		if err != nil {
383			return "", err
384		}
385	}
386
387	err = stdinPipe.Close()
388	if err != nil {
389		return "", err
390	}
391
392	if err := session.Wait(); err != nil {
393		return "", err
394	}
395
396	buf := new(strings.Builder)
397	_, err = io.Copy(buf, stdoutPipe)
398	if err != nil {
399		return "", err
400	}
401
402	return buf.String(), nil
403}
404
405func GenerateUser() UserSSH {
406	_, userKey, err := ed25519.GenerateKey(rand.Reader)
407	if err != nil {
408		panic(err)
409	}
410
411	b, err := ssh.MarshalPrivateKey(userKey, "")
412	if err != nil {
413		panic(err)
414	}
415
416	userSigner, err := ssh.NewSignerFromKey(userKey)
417	if err != nil {
418		panic(err)
419	}
420
421	return UserSSH{
422		username:   "testuser",
423		signer:     userSigner,
424		privateKey: b.Bytes,
425	}
426}
427
428func WriteFileWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
429	// open an SFTP session over an existing ssh connection.
430	client, err := sftp.NewClient(conn)
431	if err != nil {
432		cfg.Logger.Error("could not create sftp client", "err", err)
433		return nil, err
434	}
435	defer func() {
436		_ = client.Close()
437	}()
438
439	f, err := client.Create("test/hello.txt")
440	if err != nil {
441		cfg.Logger.Error("could not create file", "err", err)
442		return nil, err
443	}
444	if _, err := f.Write([]byte("Hello world!")); err != nil {
445		cfg.Logger.Error("could not write to file", "err", err)
446		return nil, err
447	}
448
449	cfg.Logger.Info("closing", "err", f.Close())
450
451	// check it's there
452	fi, err := client.Lstat("test/hello.txt")
453	if err != nil {
454		cfg.Logger.Error("could not get stat for file", "err", err)
455		return nil, err
456	}
457
458	return &fi, nil
459}
460
461func WriteFilesMultProjectsWithSftp(cfg *PgsConfig, conn *ssh.Client) (*os.FileInfo, error) {
462	// open an SFTP session over an existing ssh connection.
463	client, err := sftp.NewClient(conn)
464	if err != nil {
465		cfg.Logger.Error("could not create sftp client", "err", err)
466		return nil, err
467	}
468	defer func() {
469		_ = client.Close()
470	}()
471
472	f, err := client.Create("mult/hello.txt")
473	if err != nil {
474		cfg.Logger.Error("could not create file", "err", err)
475		return nil, err
476	}
477	if _, err := f.Write([]byte("Hello world!")); err != nil {
478		cfg.Logger.Error("could not write to file", "err", err)
479		return nil, err
480	}
481
482	f, err = client.Create("mult2/hello.txt")
483	if err != nil {
484		cfg.Logger.Error("could not create file", "err", err)
485		return nil, err
486	}
487	if _, err := f.Write([]byte("Hello world!")); err != nil {
488		cfg.Logger.Error("could not write to file", "err", err)
489		return nil, err
490	}
491
492	cfg.Logger.Info("closing", "err", f.Close())
493
494	// check it's there
495	fi, err := client.Lstat("test/hello.txt")
496	if err != nil {
497		cfg.Logger.Error("could not get stat for file", "err", err)
498		return nil, err
499	}
500
501	return &fi, nil
502}