repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

commit
47413a9
parent
1d500f9
author
Eric Bower
date
2025-12-25 23:32:29 -0500 EST
chore(pipe): added tests
1 files changed,  +661, -0
A pkg/apps/pipe/ssh_test.go
+661, -0
  1@@ -0,0 +1,661 @@
  2+package pipe
  3+
  4+import (
  5+	"context"
  6+	"crypto/ed25519"
  7+	"crypto/rand"
  8+	"fmt"
  9+	"io"
 10+	"log/slog"
 11+	"os"
 12+	"strings"
 13+	"testing"
 14+	"time"
 15+
 16+	"github.com/antoniomika/syncmap"
 17+	"github.com/picosh/pico/pkg/db"
 18+	"github.com/picosh/pico/pkg/db/stub"
 19+	"github.com/picosh/pico/pkg/pssh"
 20+	"github.com/picosh/pico/pkg/shared"
 21+	psub "github.com/picosh/pubsub"
 22+	"github.com/picosh/utils"
 23+	"github.com/prometheus/client_golang/prometheus"
 24+	"golang.org/x/crypto/ssh"
 25+)
 26+
 27+type TestDB struct {
 28+	*stub.StubDB
 29+	Users    []*db.User
 30+	Pubkeys  []*db.PublicKey
 31+	Features []*db.FeatureFlag
 32+}
 33+
 34+func NewTestDB(logger *slog.Logger) *TestDB {
 35+	return &TestDB{
 36+		StubDB: stub.NewStubDB(logger),
 37+	}
 38+}
 39+
 40+func (t *TestDB) FindUserByPubkey(key string) (*db.User, error) {
 41+	for _, pk := range t.Pubkeys {
 42+		if pk.Key == key {
 43+			return t.FindUser(pk.UserID)
 44+		}
 45+	}
 46+	return nil, fmt.Errorf("user not found for pubkey")
 47+}
 48+
 49+func (t *TestDB) FindUser(userID string) (*db.User, error) {
 50+	for _, user := range t.Users {
 51+		if user.ID == userID {
 52+			return user, nil
 53+		}
 54+	}
 55+	return nil, fmt.Errorf("user not found")
 56+}
 57+
 58+func (t *TestDB) FindUserByName(name string) (*db.User, error) {
 59+	for _, user := range t.Users {
 60+		if user.Name == name {
 61+			return user, nil
 62+		}
 63+	}
 64+	return nil, fmt.Errorf("user not found")
 65+}
 66+
 67+func (t *TestDB) FindFeature(userID, name string) (*db.FeatureFlag, error) {
 68+	for _, ff := range t.Features {
 69+		if ff.UserID == userID && ff.Name == name {
 70+			return ff, nil
 71+		}
 72+	}
 73+	return nil, fmt.Errorf("feature not found")
 74+}
 75+
 76+func (t *TestDB) HasFeatureByUser(userID string, feature string) bool {
 77+	ff, err := t.FindFeature(userID, feature)
 78+	if err != nil {
 79+		return false
 80+	}
 81+	return ff.IsValid()
 82+}
 83+
 84+func (t *TestDB) InsertAccessLog(_ *db.AccessLog) error {
 85+	return nil
 86+}
 87+
 88+func (t *TestDB) Close() error {
 89+	return nil
 90+}
 91+
 92+func (t *TestDB) AddUser(user *db.User) {
 93+	t.Users = append(t.Users, user)
 94+}
 95+
 96+func (t *TestDB) AddPubkey(pubkey *db.PublicKey) {
 97+	t.Pubkeys = append(t.Pubkeys, pubkey)
 98+}
 99+
100+type TestSSHServer struct {
101+	Cfg    *shared.ConfigSite
102+	DBPool *TestDB
103+	Cancel context.CancelFunc
104+}
105+
106+func NewTestSSHServer(t *testing.T) *TestSSHServer {
107+	t.Helper()
108+
109+	opts := &slog.HandlerOptions{
110+		AddSource: true,
111+		Level:     slog.LevelDebug,
112+	}
113+	logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
114+
115+	dbpool := NewTestDB(logger)
116+
117+	cfg := &shared.ConfigSite{
118+		Domain:       "pipe.test",
119+		Port:         "2222",
120+		PortOverride: "2222",
121+		Protocol:     "ssh",
122+		Logger:       logger,
123+		Space:        "pipe",
124+	}
125+
126+	ctx, cancel := context.WithCancel(context.Background())
127+
128+	pubsub := psub.NewMulticast(logger)
129+	handler := &CliHandler{
130+		Logger:  logger,
131+		DBPool:  dbpool,
132+		PubSub:  pubsub,
133+		Cfg:     cfg,
134+		Waiters: syncmap.New[string, []string](),
135+		Access:  syncmap.New[string, []string](),
136+	}
137+
138+	sshAuth := shared.NewSshAuthHandler(dbpool, logger, "pipe")
139+
140+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
141+
142+	server, err := pssh.NewSSHServerWithConfig(
143+		ctx,
144+		logger,
145+		"pipe-ssh-test",
146+		"localhost",
147+		"2222",
148+		"9222",
149+		"../../ssh_data/term_info_ed25519",
150+		func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
151+			perms, _ := sshAuth.PubkeyAuthHandler(conn, key)
152+			if perms == nil {
153+				perms = &ssh.Permissions{
154+					Extensions: map[string]string{
155+						"pubkey": utils.KeyForKeyText(key),
156+					},
157+				}
158+			}
159+			return perms, nil
160+		},
161+		[]pssh.SSHServerMiddleware{
162+			Middleware(handler),
163+			pssh.LogMiddleware(handler, dbpool),
164+		},
165+		nil,
166+		nil,
167+	)
168+
169+	if err != nil {
170+		t.Fatalf("failed to create ssh server: %v", err)
171+	}
172+
173+	go func() {
174+		if err := server.ListenAndServe(); err != nil {
175+			logger.Error("serve", "err", err.Error())
176+		}
177+	}()
178+
179+	time.Sleep(100 * time.Millisecond)
180+
181+	return &TestSSHServer{
182+		Cfg:    cfg,
183+		DBPool: dbpool,
184+		Cancel: cancel,
185+	}
186+}
187+
188+func (s *TestSSHServer) Shutdown() {
189+	s.Cancel()
190+	time.Sleep(10 * time.Millisecond)
191+}
192+
193+type UserSSH struct {
194+	username   string
195+	signer     ssh.Signer
196+	privateKey []byte
197+}
198+
199+func GenerateUser(username string) UserSSH {
200+	_, userKey, err := ed25519.GenerateKey(rand.Reader)
201+	if err != nil {
202+		panic(err)
203+	}
204+
205+	b, err := ssh.MarshalPrivateKey(userKey, "")
206+	if err != nil {
207+		panic(err)
208+	}
209+
210+	userSigner, err := ssh.NewSignerFromKey(userKey)
211+	if err != nil {
212+		panic(err)
213+	}
214+
215+	return UserSSH{
216+		username:   username,
217+		signer:     userSigner,
218+		privateKey: b.Bytes,
219+	}
220+}
221+
222+func (u UserSSH) PublicKey() string {
223+	return utils.KeyForKeyText(u.signer.PublicKey())
224+}
225+
226+func (u UserSSH) NewClient() (*ssh.Client, error) {
227+	config := &ssh.ClientConfig{
228+		User: u.username,
229+		Auth: []ssh.AuthMethod{
230+			ssh.PublicKeys(u.signer),
231+		},
232+		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
233+	}
234+
235+	return ssh.Dial("tcp", "localhost:2222", config)
236+}
237+
238+func (u UserSSH) RunCommand(client *ssh.Client, cmd string) (string, error) {
239+	session, err := client.NewSession()
240+	if err != nil {
241+		return "", err
242+	}
243+	defer func() { _ = session.Close() }()
244+
245+	stdoutPipe, err := session.StdoutPipe()
246+	if err != nil {
247+		return "", err
248+	}
249+
250+	stderrPipe, err := session.StderrPipe()
251+	if err != nil {
252+		return "", err
253+	}
254+
255+	if err := session.Start(cmd); err != nil {
256+		return "", err
257+	}
258+
259+	stdout := new(strings.Builder)
260+	stderr := new(strings.Builder)
261+	_, _ = io.Copy(stdout, stdoutPipe)
262+	_, _ = io.Copy(stderr, stderrPipe)
263+
264+	_ = session.Wait()
265+	return stdout.String() + stderr.String(), nil
266+}
267+
268+func (u UserSSH) RunCommandWithStdin(client *ssh.Client, cmd string, stdin string) (string, error) {
269+	session, err := client.NewSession()
270+	if err != nil {
271+		return "", err
272+	}
273+	defer func() { _ = session.Close() }()
274+
275+	stdinPipe, err := session.StdinPipe()
276+	if err != nil {
277+		return "", err
278+	}
279+
280+	stdoutPipe, err := session.StdoutPipe()
281+	if err != nil {
282+		return "", err
283+	}
284+
285+	if err := session.Start(cmd); err != nil {
286+		return "", err
287+	}
288+
289+	_, err = stdinPipe.Write([]byte(stdin))
290+	if err != nil {
291+		return "", err
292+	}
293+	_ = stdinPipe.Close()
294+
295+	buf := new(strings.Builder)
296+	_, err = io.Copy(buf, stdoutPipe)
297+	if err != nil {
298+		return "", err
299+	}
300+
301+	_ = session.Wait()
302+	return buf.String(), nil
303+}
304+
305+func RegisterUserWithServer(server *TestSSHServer, user UserSSH) {
306+	dbUser := &db.User{
307+		ID:   user.username + "-id",
308+		Name: user.username,
309+	}
310+	server.DBPool.AddUser(dbUser)
311+	server.DBPool.AddPubkey(&db.PublicKey{
312+		ID:     user.username + "-pubkey-id",
313+		UserID: dbUser.ID,
314+		Key:    user.PublicKey(),
315+	})
316+}
317+
318+func TestLs_UnauthenticatedUserDenied(t *testing.T) {
319+	server := NewTestSSHServer(t)
320+	defer server.Shutdown()
321+
322+	user := GenerateUser("anonymous")
323+
324+	client, err := user.NewClient()
325+	if err != nil {
326+		t.Fatalf("failed to connect: %v", err)
327+	}
328+	defer func() { _ = client.Close() }()
329+
330+	output, err := user.RunCommand(client, "ls")
331+	if err != nil {
332+		t.Logf("command error (expected): %v", err)
333+	}
334+
335+	if !strings.Contains(output, "access denied") {
336+		t.Errorf("expected 'access denied', got: %s", output)
337+	}
338+}
339+
340+func TestLs_AuthenticatedUser(t *testing.T) {
341+	server := NewTestSSHServer(t)
342+	defer server.Shutdown()
343+
344+	user := GenerateUser("alice")
345+	RegisterUserWithServer(server, user)
346+
347+	client, err := user.NewClient()
348+	if err != nil {
349+		t.Fatalf("failed to connect: %v", err)
350+	}
351+	defer func() { _ = client.Close() }()
352+
353+	output, err := user.RunCommand(client, "ls")
354+	if err != nil {
355+		t.Logf("command completed with: %v", err)
356+	}
357+
358+	if strings.Contains(output, "access denied") {
359+		t.Errorf("authenticated user should not get access denied, got: %s", output)
360+	}
361+
362+	if !strings.Contains(output, "no pubsub channels found") {
363+		t.Errorf("expected 'no pubsub channels found' for empty state, got: %s", output)
364+	}
365+}
366+
367+func TestPubSub_BasicFlow(t *testing.T) {
368+	server := NewTestSSHServer(t)
369+	defer server.Shutdown()
370+
371+	user := GenerateUser("alice")
372+	RegisterUserWithServer(server, user)
373+
374+	subClient, err := user.NewClient()
375+	if err != nil {
376+		t.Fatalf("failed to connect subscriber: %v", err)
377+	}
378+	defer func() { _ = subClient.Close() }()
379+
380+	pubClient, err := user.NewClient()
381+	if err != nil {
382+		t.Fatalf("failed to connect publisher: %v", err)
383+	}
384+	defer func() { _ = pubClient.Close() }()
385+
386+	subSession, err := subClient.NewSession()
387+	if err != nil {
388+		t.Fatalf("failed to create sub session: %v", err)
389+	}
390+	defer func() { _ = subSession.Close() }()
391+
392+	subStdout, err := subSession.StdoutPipe()
393+	if err != nil {
394+		t.Fatalf("failed to get sub stdout: %v", err)
395+	}
396+
397+	if err := subSession.Start("sub testtopic -c"); err != nil {
398+		t.Fatalf("failed to start sub: %v", err)
399+	}
400+
401+	time.Sleep(100 * time.Millisecond)
402+
403+	testMessage := "hello from pub"
404+	_, err = user.RunCommandWithStdin(pubClient, "pub testtopic -c", testMessage)
405+	if err != nil {
406+		t.Logf("pub command completed: %v", err)
407+	}
408+
409+	received := make([]byte, len(testMessage)+10)
410+	n, err := subStdout.Read(received)
411+	if err != nil && err != io.EOF {
412+		t.Logf("read error: %v", err)
413+	}
414+
415+	receivedStr := string(received[:n])
416+	if !strings.Contains(receivedStr, testMessage) {
417+		t.Errorf("subscriber did not receive message, got: %q, want: %q", receivedStr, testMessage)
418+	}
419+}
420+
421+func TestPubSub_PublicTopic(t *testing.T) {
422+	server := NewTestSSHServer(t)
423+	defer server.Shutdown()
424+
425+	alice := GenerateUser("alice")
426+	bob := GenerateUser("bob")
427+	RegisterUserWithServer(server, alice)
428+	RegisterUserWithServer(server, bob)
429+
430+	subClient, err := bob.NewClient()
431+	if err != nil {
432+		t.Fatalf("failed to connect subscriber: %v", err)
433+	}
434+	defer func() { _ = subClient.Close() }()
435+
436+	pubClient, err := alice.NewClient()
437+	if err != nil {
438+		t.Fatalf("failed to connect publisher: %v", err)
439+	}
440+	defer func() { _ = pubClient.Close() }()
441+
442+	subSession, err := subClient.NewSession()
443+	if err != nil {
444+		t.Fatalf("failed to create sub session: %v", err)
445+	}
446+	defer func() { _ = subSession.Close() }()
447+
448+	subStdout, err := subSession.StdoutPipe()
449+	if err != nil {
450+		t.Fatalf("failed to get sub stdout: %v", err)
451+	}
452+
453+	if err := subSession.Start("sub publictopic -p -c"); err != nil {
454+		t.Fatalf("failed to start sub: %v", err)
455+	}
456+
457+	time.Sleep(100 * time.Millisecond)
458+
459+	testMessage := "public message"
460+	_, err = alice.RunCommandWithStdin(pubClient, "pub publictopic -p -c", testMessage)
461+	if err != nil {
462+		t.Logf("pub command completed: %v", err)
463+	}
464+
465+	received := make([]byte, len(testMessage)+10)
466+	n, err := subStdout.Read(received)
467+	if err != nil && err != io.EOF {
468+		t.Logf("read error: %v", err)
469+	}
470+
471+	receivedStr := string(received[:n])
472+	if !strings.Contains(receivedStr, testMessage) {
473+		t.Errorf("subscriber did not receive public message, got: %q, want: %q", receivedStr, testMessage)
474+	}
475+}
476+
477+func TestPipe_Bidirectional(t *testing.T) {
478+	server := NewTestSSHServer(t)
479+	defer server.Shutdown()
480+
481+	alice := GenerateUser("alice")
482+	bob := GenerateUser("bob")
483+	RegisterUserWithServer(server, alice)
484+	RegisterUserWithServer(server, bob)
485+
486+	aliceClient, err := alice.NewClient()
487+	if err != nil {
488+		t.Fatalf("failed to connect alice: %v", err)
489+	}
490+	defer func() { _ = aliceClient.Close() }()
491+
492+	bobClient, err := bob.NewClient()
493+	if err != nil {
494+		t.Fatalf("failed to connect bob: %v", err)
495+	}
496+	defer func() { _ = bobClient.Close() }()
497+
498+	aliceSession, err := aliceClient.NewSession()
499+	if err != nil {
500+		t.Fatalf("failed to create alice session: %v", err)
501+	}
502+	defer func() { _ = aliceSession.Close() }()
503+
504+	aliceStdin, err := aliceSession.StdinPipe()
505+	if err != nil {
506+		t.Fatalf("failed to get alice stdin: %v", err)
507+	}
508+
509+	aliceStdout, err := aliceSession.StdoutPipe()
510+	if err != nil {
511+		t.Fatalf("failed to get alice stdout: %v", err)
512+	}
513+
514+	if err := aliceSession.Start("pipe pipetopic -p -c"); err != nil {
515+		t.Fatalf("failed to start alice pipe: %v", err)
516+	}
517+
518+	time.Sleep(100 * time.Millisecond)
519+
520+	bobSession, err := bobClient.NewSession()
521+	if err != nil {
522+		t.Fatalf("failed to create bob session: %v", err)
523+	}
524+	defer func() { _ = bobSession.Close() }()
525+
526+	bobStdin, err := bobSession.StdinPipe()
527+	if err != nil {
528+		t.Fatalf("failed to get bob stdin: %v", err)
529+	}
530+
531+	bobStdout, err := bobSession.StdoutPipe()
532+	if err != nil {
533+		t.Fatalf("failed to get bob stdout: %v", err)
534+	}
535+
536+	if err := bobSession.Start("pipe pipetopic -p -c"); err != nil {
537+		t.Fatalf("failed to start bob pipe: %v", err)
538+	}
539+
540+	time.Sleep(100 * time.Millisecond)
541+
542+	aliceMsg := "hello from alice\n"
543+	_, err = aliceStdin.Write([]byte(aliceMsg))
544+	if err != nil {
545+		t.Fatalf("alice failed to write: %v", err)
546+	}
547+
548+	bobReceived := make([]byte, 100)
549+	n, err := bobStdout.Read(bobReceived)
550+	if err != nil && err != io.EOF {
551+		t.Logf("bob read error: %v", err)
552+	}
553+	if !strings.Contains(string(bobReceived[:n]), "hello from alice") {
554+		t.Errorf("bob did not receive alice's message, got: %q", string(bobReceived[:n]))
555+	}
556+
557+	bobMsg := "hello from bob\n"
558+	_, err = bobStdin.Write([]byte(bobMsg))
559+	if err != nil {
560+		t.Fatalf("bob failed to write: %v", err)
561+	}
562+
563+	aliceReceived := make([]byte, 100)
564+	n, err = aliceStdout.Read(aliceReceived)
565+	if err != nil && err != io.EOF {
566+		t.Logf("alice read error: %v", err)
567+	}
568+	if !strings.Contains(string(aliceReceived[:n]), "hello from bob") {
569+		t.Errorf("alice did not receive bob's message, got: %q", string(aliceReceived[:n]))
570+	}
571+}
572+
573+func TestPipe_AutoGeneratedTopic(t *testing.T) {
574+	server := NewTestSSHServer(t)
575+	defer server.Shutdown()
576+
577+	user := GenerateUser("alice")
578+	RegisterUserWithServer(server, user)
579+
580+	client, err := user.NewClient()
581+	if err != nil {
582+		t.Fatalf("failed to connect: %v", err)
583+	}
584+	defer func() { _ = client.Close() }()
585+
586+	session, err := client.NewSession()
587+	if err != nil {
588+		t.Fatalf("failed to create session: %v", err)
589+	}
590+	defer func() { _ = session.Close() }()
591+
592+	stdout, err := session.StdoutPipe()
593+	if err != nil {
594+		t.Fatalf("failed to get stdout: %v", err)
595+	}
596+
597+	if err := session.Start("pipe"); err != nil {
598+		t.Fatalf("failed to start pipe: %v", err)
599+	}
600+
601+	received := make([]byte, 200)
602+	n, err := stdout.Read(received)
603+	if err != nil && err != io.EOF {
604+		t.Logf("read error: %v", err)
605+	}
606+
607+	output := string(received[:n])
608+	if !strings.Contains(output, "subscribe to this topic") {
609+		t.Errorf("expected topic subscription instructions, got: %q", output)
610+	}
611+}
612+
613+func TestAccessControl_AllowedUserViaFullPath(t *testing.T) {
614+	server := NewTestSSHServer(t)
615+	defer server.Shutdown()
616+
617+	alice := GenerateUser("alice")
618+	bob := GenerateUser("bob")
619+	RegisterUserWithServer(server, alice)
620+	RegisterUserWithServer(server, bob)
621+
622+	aliceClient, err := alice.NewClient()
623+	if err != nil {
624+		t.Fatalf("failed to connect alice: %v", err)
625+	}
626+	defer func() { _ = aliceClient.Close() }()
627+
628+	aliceSession, err := aliceClient.NewSession()
629+	if err != nil {
630+		t.Fatalf("failed to create alice session: %v", err)
631+	}
632+	defer func() { _ = aliceSession.Close() }()
633+
634+	aliceStdout, err := aliceSession.StdoutPipe()
635+	if err != nil {
636+		t.Fatalf("failed to get alice stdout: %v", err)
637+	}
638+
639+	if err := aliceSession.Start("sub sharedtopic -a alice,bob -c"); err != nil {
640+		t.Fatalf("failed to start alice sub: %v", err)
641+	}
642+
643+	time.Sleep(100 * time.Millisecond)
644+
645+	bobClient, err := bob.NewClient()
646+	if err != nil {
647+		t.Fatalf("failed to connect bob: %v", err)
648+	}
649+	defer func() { _ = bobClient.Close() }()
650+
651+	_, err = bob.RunCommandWithStdin(bobClient, "pub alice/sharedtopic -c", "bob allowed")
652+	if err != nil {
653+		t.Logf("bob pub completed: %v", err)
654+	}
655+
656+	aliceReceived := make([]byte, 100)
657+	n, _ := aliceStdout.Read(aliceReceived)
658+
659+	if !strings.Contains(string(aliceReceived[:n]), "bob allowed") {
660+		t.Errorf("alice should receive bob's message on shared topic, got: %q", string(aliceReceived[:n]))
661+	}
662+}