- commit
- 47413a9
- parent
- 1d500f9
- author
- Eric Bower
- date
- 2025-12-25 23:32:29 -0500 EST
chore(pipe): added tests
1 files changed,
+661,
-0
+661,
-0
1@@ -0,0 +1,661 @@
2+package pipe
3+
4+import (
5+ "context"
6+ "crypto/ed25519"
7+ "crypto/rand"
8+ "fmt"
9+ "io"
10+ "log/slog"
11+ "os"
12+ "strings"
13+ "testing"
14+ "time"
15+
16+ "github.com/antoniomika/syncmap"
17+ "github.com/picosh/pico/pkg/db"
18+ "github.com/picosh/pico/pkg/db/stub"
19+ "github.com/picosh/pico/pkg/pssh"
20+ "github.com/picosh/pico/pkg/shared"
21+ psub "github.com/picosh/pubsub"
22+ "github.com/picosh/utils"
23+ "github.com/prometheus/client_golang/prometheus"
24+ "golang.org/x/crypto/ssh"
25+)
26+
27+type TestDB struct {
28+ *stub.StubDB
29+ Users []*db.User
30+ Pubkeys []*db.PublicKey
31+ Features []*db.FeatureFlag
32+}
33+
34+func NewTestDB(logger *slog.Logger) *TestDB {
35+ return &TestDB{
36+ StubDB: stub.NewStubDB(logger),
37+ }
38+}
39+
40+func (t *TestDB) FindUserByPubkey(key string) (*db.User, error) {
41+ for _, pk := range t.Pubkeys {
42+ if pk.Key == key {
43+ return t.FindUser(pk.UserID)
44+ }
45+ }
46+ return nil, fmt.Errorf("user not found for pubkey")
47+}
48+
49+func (t *TestDB) FindUser(userID string) (*db.User, error) {
50+ for _, user := range t.Users {
51+ if user.ID == userID {
52+ return user, nil
53+ }
54+ }
55+ return nil, fmt.Errorf("user not found")
56+}
57+
58+func (t *TestDB) FindUserByName(name string) (*db.User, error) {
59+ for _, user := range t.Users {
60+ if user.Name == name {
61+ return user, nil
62+ }
63+ }
64+ return nil, fmt.Errorf("user not found")
65+}
66+
67+func (t *TestDB) FindFeature(userID, name string) (*db.FeatureFlag, error) {
68+ for _, ff := range t.Features {
69+ if ff.UserID == userID && ff.Name == name {
70+ return ff, nil
71+ }
72+ }
73+ return nil, fmt.Errorf("feature not found")
74+}
75+
76+func (t *TestDB) HasFeatureByUser(userID string, feature string) bool {
77+ ff, err := t.FindFeature(userID, feature)
78+ if err != nil {
79+ return false
80+ }
81+ return ff.IsValid()
82+}
83+
84+func (t *TestDB) InsertAccessLog(_ *db.AccessLog) error {
85+ return nil
86+}
87+
88+func (t *TestDB) Close() error {
89+ return nil
90+}
91+
92+func (t *TestDB) AddUser(user *db.User) {
93+ t.Users = append(t.Users, user)
94+}
95+
96+func (t *TestDB) AddPubkey(pubkey *db.PublicKey) {
97+ t.Pubkeys = append(t.Pubkeys, pubkey)
98+}
99+
100+type TestSSHServer struct {
101+ Cfg *shared.ConfigSite
102+ DBPool *TestDB
103+ Cancel context.CancelFunc
104+}
105+
106+func NewTestSSHServer(t *testing.T) *TestSSHServer {
107+ t.Helper()
108+
109+ opts := &slog.HandlerOptions{
110+ AddSource: true,
111+ Level: slog.LevelDebug,
112+ }
113+ logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
114+
115+ dbpool := NewTestDB(logger)
116+
117+ cfg := &shared.ConfigSite{
118+ Domain: "pipe.test",
119+ Port: "2222",
120+ PortOverride: "2222",
121+ Protocol: "ssh",
122+ Logger: logger,
123+ Space: "pipe",
124+ }
125+
126+ ctx, cancel := context.WithCancel(context.Background())
127+
128+ pubsub := psub.NewMulticast(logger)
129+ handler := &CliHandler{
130+ Logger: logger,
131+ DBPool: dbpool,
132+ PubSub: pubsub,
133+ Cfg: cfg,
134+ Waiters: syncmap.New[string, []string](),
135+ Access: syncmap.New[string, []string](),
136+ }
137+
138+ sshAuth := shared.NewSshAuthHandler(dbpool, logger, "pipe")
139+
140+ prometheus.DefaultRegisterer = prometheus.NewRegistry()
141+
142+ server, err := pssh.NewSSHServerWithConfig(
143+ ctx,
144+ logger,
145+ "pipe-ssh-test",
146+ "localhost",
147+ "2222",
148+ "9222",
149+ "../../ssh_data/term_info_ed25519",
150+ func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
151+ perms, _ := sshAuth.PubkeyAuthHandler(conn, key)
152+ if perms == nil {
153+ perms = &ssh.Permissions{
154+ Extensions: map[string]string{
155+ "pubkey": utils.KeyForKeyText(key),
156+ },
157+ }
158+ }
159+ return perms, nil
160+ },
161+ []pssh.SSHServerMiddleware{
162+ Middleware(handler),
163+ pssh.LogMiddleware(handler, dbpool),
164+ },
165+ nil,
166+ nil,
167+ )
168+
169+ if err != nil {
170+ t.Fatalf("failed to create ssh server: %v", err)
171+ }
172+
173+ go func() {
174+ if err := server.ListenAndServe(); err != nil {
175+ logger.Error("serve", "err", err.Error())
176+ }
177+ }()
178+
179+ time.Sleep(100 * time.Millisecond)
180+
181+ return &TestSSHServer{
182+ Cfg: cfg,
183+ DBPool: dbpool,
184+ Cancel: cancel,
185+ }
186+}
187+
188+func (s *TestSSHServer) Shutdown() {
189+ s.Cancel()
190+ time.Sleep(10 * time.Millisecond)
191+}
192+
193+type UserSSH struct {
194+ username string
195+ signer ssh.Signer
196+ privateKey []byte
197+}
198+
199+func GenerateUser(username string) UserSSH {
200+ _, userKey, err := ed25519.GenerateKey(rand.Reader)
201+ if err != nil {
202+ panic(err)
203+ }
204+
205+ b, err := ssh.MarshalPrivateKey(userKey, "")
206+ if err != nil {
207+ panic(err)
208+ }
209+
210+ userSigner, err := ssh.NewSignerFromKey(userKey)
211+ if err != nil {
212+ panic(err)
213+ }
214+
215+ return UserSSH{
216+ username: username,
217+ signer: userSigner,
218+ privateKey: b.Bytes,
219+ }
220+}
221+
222+func (u UserSSH) PublicKey() string {
223+ return utils.KeyForKeyText(u.signer.PublicKey())
224+}
225+
226+func (u UserSSH) NewClient() (*ssh.Client, error) {
227+ config := &ssh.ClientConfig{
228+ User: u.username,
229+ Auth: []ssh.AuthMethod{
230+ ssh.PublicKeys(u.signer),
231+ },
232+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
233+ }
234+
235+ return ssh.Dial("tcp", "localhost:2222", config)
236+}
237+
238+func (u UserSSH) RunCommand(client *ssh.Client, cmd string) (string, error) {
239+ session, err := client.NewSession()
240+ if err != nil {
241+ return "", err
242+ }
243+ defer func() { _ = session.Close() }()
244+
245+ stdoutPipe, err := session.StdoutPipe()
246+ if err != nil {
247+ return "", err
248+ }
249+
250+ stderrPipe, err := session.StderrPipe()
251+ if err != nil {
252+ return "", err
253+ }
254+
255+ if err := session.Start(cmd); err != nil {
256+ return "", err
257+ }
258+
259+ stdout := new(strings.Builder)
260+ stderr := new(strings.Builder)
261+ _, _ = io.Copy(stdout, stdoutPipe)
262+ _, _ = io.Copy(stderr, stderrPipe)
263+
264+ _ = session.Wait()
265+ return stdout.String() + stderr.String(), nil
266+}
267+
268+func (u UserSSH) RunCommandWithStdin(client *ssh.Client, cmd string, stdin string) (string, error) {
269+ session, err := client.NewSession()
270+ if err != nil {
271+ return "", err
272+ }
273+ defer func() { _ = session.Close() }()
274+
275+ stdinPipe, err := session.StdinPipe()
276+ if err != nil {
277+ return "", err
278+ }
279+
280+ stdoutPipe, err := session.StdoutPipe()
281+ if err != nil {
282+ return "", err
283+ }
284+
285+ if err := session.Start(cmd); err != nil {
286+ return "", err
287+ }
288+
289+ _, err = stdinPipe.Write([]byte(stdin))
290+ if err != nil {
291+ return "", err
292+ }
293+ _ = stdinPipe.Close()
294+
295+ buf := new(strings.Builder)
296+ _, err = io.Copy(buf, stdoutPipe)
297+ if err != nil {
298+ return "", err
299+ }
300+
301+ _ = session.Wait()
302+ return buf.String(), nil
303+}
304+
305+func RegisterUserWithServer(server *TestSSHServer, user UserSSH) {
306+ dbUser := &db.User{
307+ ID: user.username + "-id",
308+ Name: user.username,
309+ }
310+ server.DBPool.AddUser(dbUser)
311+ server.DBPool.AddPubkey(&db.PublicKey{
312+ ID: user.username + "-pubkey-id",
313+ UserID: dbUser.ID,
314+ Key: user.PublicKey(),
315+ })
316+}
317+
318+func TestLs_UnauthenticatedUserDenied(t *testing.T) {
319+ server := NewTestSSHServer(t)
320+ defer server.Shutdown()
321+
322+ user := GenerateUser("anonymous")
323+
324+ client, err := user.NewClient()
325+ if err != nil {
326+ t.Fatalf("failed to connect: %v", err)
327+ }
328+ defer func() { _ = client.Close() }()
329+
330+ output, err := user.RunCommand(client, "ls")
331+ if err != nil {
332+ t.Logf("command error (expected): %v", err)
333+ }
334+
335+ if !strings.Contains(output, "access denied") {
336+ t.Errorf("expected 'access denied', got: %s", output)
337+ }
338+}
339+
340+func TestLs_AuthenticatedUser(t *testing.T) {
341+ server := NewTestSSHServer(t)
342+ defer server.Shutdown()
343+
344+ user := GenerateUser("alice")
345+ RegisterUserWithServer(server, user)
346+
347+ client, err := user.NewClient()
348+ if err != nil {
349+ t.Fatalf("failed to connect: %v", err)
350+ }
351+ defer func() { _ = client.Close() }()
352+
353+ output, err := user.RunCommand(client, "ls")
354+ if err != nil {
355+ t.Logf("command completed with: %v", err)
356+ }
357+
358+ if strings.Contains(output, "access denied") {
359+ t.Errorf("authenticated user should not get access denied, got: %s", output)
360+ }
361+
362+ if !strings.Contains(output, "no pubsub channels found") {
363+ t.Errorf("expected 'no pubsub channels found' for empty state, got: %s", output)
364+ }
365+}
366+
367+func TestPubSub_BasicFlow(t *testing.T) {
368+ server := NewTestSSHServer(t)
369+ defer server.Shutdown()
370+
371+ user := GenerateUser("alice")
372+ RegisterUserWithServer(server, user)
373+
374+ subClient, err := user.NewClient()
375+ if err != nil {
376+ t.Fatalf("failed to connect subscriber: %v", err)
377+ }
378+ defer func() { _ = subClient.Close() }()
379+
380+ pubClient, err := user.NewClient()
381+ if err != nil {
382+ t.Fatalf("failed to connect publisher: %v", err)
383+ }
384+ defer func() { _ = pubClient.Close() }()
385+
386+ subSession, err := subClient.NewSession()
387+ if err != nil {
388+ t.Fatalf("failed to create sub session: %v", err)
389+ }
390+ defer func() { _ = subSession.Close() }()
391+
392+ subStdout, err := subSession.StdoutPipe()
393+ if err != nil {
394+ t.Fatalf("failed to get sub stdout: %v", err)
395+ }
396+
397+ if err := subSession.Start("sub testtopic -c"); err != nil {
398+ t.Fatalf("failed to start sub: %v", err)
399+ }
400+
401+ time.Sleep(100 * time.Millisecond)
402+
403+ testMessage := "hello from pub"
404+ _, err = user.RunCommandWithStdin(pubClient, "pub testtopic -c", testMessage)
405+ if err != nil {
406+ t.Logf("pub command completed: %v", err)
407+ }
408+
409+ received := make([]byte, len(testMessage)+10)
410+ n, err := subStdout.Read(received)
411+ if err != nil && err != io.EOF {
412+ t.Logf("read error: %v", err)
413+ }
414+
415+ receivedStr := string(received[:n])
416+ if !strings.Contains(receivedStr, testMessage) {
417+ t.Errorf("subscriber did not receive message, got: %q, want: %q", receivedStr, testMessage)
418+ }
419+}
420+
421+func TestPubSub_PublicTopic(t *testing.T) {
422+ server := NewTestSSHServer(t)
423+ defer server.Shutdown()
424+
425+ alice := GenerateUser("alice")
426+ bob := GenerateUser("bob")
427+ RegisterUserWithServer(server, alice)
428+ RegisterUserWithServer(server, bob)
429+
430+ subClient, err := bob.NewClient()
431+ if err != nil {
432+ t.Fatalf("failed to connect subscriber: %v", err)
433+ }
434+ defer func() { _ = subClient.Close() }()
435+
436+ pubClient, err := alice.NewClient()
437+ if err != nil {
438+ t.Fatalf("failed to connect publisher: %v", err)
439+ }
440+ defer func() { _ = pubClient.Close() }()
441+
442+ subSession, err := subClient.NewSession()
443+ if err != nil {
444+ t.Fatalf("failed to create sub session: %v", err)
445+ }
446+ defer func() { _ = subSession.Close() }()
447+
448+ subStdout, err := subSession.StdoutPipe()
449+ if err != nil {
450+ t.Fatalf("failed to get sub stdout: %v", err)
451+ }
452+
453+ if err := subSession.Start("sub publictopic -p -c"); err != nil {
454+ t.Fatalf("failed to start sub: %v", err)
455+ }
456+
457+ time.Sleep(100 * time.Millisecond)
458+
459+ testMessage := "public message"
460+ _, err = alice.RunCommandWithStdin(pubClient, "pub publictopic -p -c", testMessage)
461+ if err != nil {
462+ t.Logf("pub command completed: %v", err)
463+ }
464+
465+ received := make([]byte, len(testMessage)+10)
466+ n, err := subStdout.Read(received)
467+ if err != nil && err != io.EOF {
468+ t.Logf("read error: %v", err)
469+ }
470+
471+ receivedStr := string(received[:n])
472+ if !strings.Contains(receivedStr, testMessage) {
473+ t.Errorf("subscriber did not receive public message, got: %q, want: %q", receivedStr, testMessage)
474+ }
475+}
476+
477+func TestPipe_Bidirectional(t *testing.T) {
478+ server := NewTestSSHServer(t)
479+ defer server.Shutdown()
480+
481+ alice := GenerateUser("alice")
482+ bob := GenerateUser("bob")
483+ RegisterUserWithServer(server, alice)
484+ RegisterUserWithServer(server, bob)
485+
486+ aliceClient, err := alice.NewClient()
487+ if err != nil {
488+ t.Fatalf("failed to connect alice: %v", err)
489+ }
490+ defer func() { _ = aliceClient.Close() }()
491+
492+ bobClient, err := bob.NewClient()
493+ if err != nil {
494+ t.Fatalf("failed to connect bob: %v", err)
495+ }
496+ defer func() { _ = bobClient.Close() }()
497+
498+ aliceSession, err := aliceClient.NewSession()
499+ if err != nil {
500+ t.Fatalf("failed to create alice session: %v", err)
501+ }
502+ defer func() { _ = aliceSession.Close() }()
503+
504+ aliceStdin, err := aliceSession.StdinPipe()
505+ if err != nil {
506+ t.Fatalf("failed to get alice stdin: %v", err)
507+ }
508+
509+ aliceStdout, err := aliceSession.StdoutPipe()
510+ if err != nil {
511+ t.Fatalf("failed to get alice stdout: %v", err)
512+ }
513+
514+ if err := aliceSession.Start("pipe pipetopic -p -c"); err != nil {
515+ t.Fatalf("failed to start alice pipe: %v", err)
516+ }
517+
518+ time.Sleep(100 * time.Millisecond)
519+
520+ bobSession, err := bobClient.NewSession()
521+ if err != nil {
522+ t.Fatalf("failed to create bob session: %v", err)
523+ }
524+ defer func() { _ = bobSession.Close() }()
525+
526+ bobStdin, err := bobSession.StdinPipe()
527+ if err != nil {
528+ t.Fatalf("failed to get bob stdin: %v", err)
529+ }
530+
531+ bobStdout, err := bobSession.StdoutPipe()
532+ if err != nil {
533+ t.Fatalf("failed to get bob stdout: %v", err)
534+ }
535+
536+ if err := bobSession.Start("pipe pipetopic -p -c"); err != nil {
537+ t.Fatalf("failed to start bob pipe: %v", err)
538+ }
539+
540+ time.Sleep(100 * time.Millisecond)
541+
542+ aliceMsg := "hello from alice\n"
543+ _, err = aliceStdin.Write([]byte(aliceMsg))
544+ if err != nil {
545+ t.Fatalf("alice failed to write: %v", err)
546+ }
547+
548+ bobReceived := make([]byte, 100)
549+ n, err := bobStdout.Read(bobReceived)
550+ if err != nil && err != io.EOF {
551+ t.Logf("bob read error: %v", err)
552+ }
553+ if !strings.Contains(string(bobReceived[:n]), "hello from alice") {
554+ t.Errorf("bob did not receive alice's message, got: %q", string(bobReceived[:n]))
555+ }
556+
557+ bobMsg := "hello from bob\n"
558+ _, err = bobStdin.Write([]byte(bobMsg))
559+ if err != nil {
560+ t.Fatalf("bob failed to write: %v", err)
561+ }
562+
563+ aliceReceived := make([]byte, 100)
564+ n, err = aliceStdout.Read(aliceReceived)
565+ if err != nil && err != io.EOF {
566+ t.Logf("alice read error: %v", err)
567+ }
568+ if !strings.Contains(string(aliceReceived[:n]), "hello from bob") {
569+ t.Errorf("alice did not receive bob's message, got: %q", string(aliceReceived[:n]))
570+ }
571+}
572+
573+func TestPipe_AutoGeneratedTopic(t *testing.T) {
574+ server := NewTestSSHServer(t)
575+ defer server.Shutdown()
576+
577+ user := GenerateUser("alice")
578+ RegisterUserWithServer(server, user)
579+
580+ client, err := user.NewClient()
581+ if err != nil {
582+ t.Fatalf("failed to connect: %v", err)
583+ }
584+ defer func() { _ = client.Close() }()
585+
586+ session, err := client.NewSession()
587+ if err != nil {
588+ t.Fatalf("failed to create session: %v", err)
589+ }
590+ defer func() { _ = session.Close() }()
591+
592+ stdout, err := session.StdoutPipe()
593+ if err != nil {
594+ t.Fatalf("failed to get stdout: %v", err)
595+ }
596+
597+ if err := session.Start("pipe"); err != nil {
598+ t.Fatalf("failed to start pipe: %v", err)
599+ }
600+
601+ received := make([]byte, 200)
602+ n, err := stdout.Read(received)
603+ if err != nil && err != io.EOF {
604+ t.Logf("read error: %v", err)
605+ }
606+
607+ output := string(received[:n])
608+ if !strings.Contains(output, "subscribe to this topic") {
609+ t.Errorf("expected topic subscription instructions, got: %q", output)
610+ }
611+}
612+
613+func TestAccessControl_AllowedUserViaFullPath(t *testing.T) {
614+ server := NewTestSSHServer(t)
615+ defer server.Shutdown()
616+
617+ alice := GenerateUser("alice")
618+ bob := GenerateUser("bob")
619+ RegisterUserWithServer(server, alice)
620+ RegisterUserWithServer(server, bob)
621+
622+ aliceClient, err := alice.NewClient()
623+ if err != nil {
624+ t.Fatalf("failed to connect alice: %v", err)
625+ }
626+ defer func() { _ = aliceClient.Close() }()
627+
628+ aliceSession, err := aliceClient.NewSession()
629+ if err != nil {
630+ t.Fatalf("failed to create alice session: %v", err)
631+ }
632+ defer func() { _ = aliceSession.Close() }()
633+
634+ aliceStdout, err := aliceSession.StdoutPipe()
635+ if err != nil {
636+ t.Fatalf("failed to get alice stdout: %v", err)
637+ }
638+
639+ if err := aliceSession.Start("sub sharedtopic -a alice,bob -c"); err != nil {
640+ t.Fatalf("failed to start alice sub: %v", err)
641+ }
642+
643+ time.Sleep(100 * time.Millisecond)
644+
645+ bobClient, err := bob.NewClient()
646+ if err != nil {
647+ t.Fatalf("failed to connect bob: %v", err)
648+ }
649+ defer func() { _ = bobClient.Close() }()
650+
651+ _, err = bob.RunCommandWithStdin(bobClient, "pub alice/sharedtopic -c", "bob allowed")
652+ if err != nil {
653+ t.Logf("bob pub completed: %v", err)
654+ }
655+
656+ aliceReceived := make([]byte, 100)
657+ n, _ := aliceStdout.Read(aliceReceived)
658+
659+ if !strings.Contains(string(aliceReceived[:n]), "bob allowed") {
660+ t.Errorf("alice should receive bob's message on shared topic, got: %q", string(aliceReceived[:n]))
661+ }
662+}