repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / pssh
Eric Bower  ·  2026-02-26

server_test.go

  1package pssh_test
  2
  3import (
  4	"bytes"
  5	"context"
  6	"crypto/rand"
  7	"errors"
  8	"io"
  9	"log/slog"
 10	"net"
 11	"reflect"
 12	"slices"
 13	"strings"
 14	"testing"
 15	"time"
 16	"unsafe"
 17
 18	"github.com/picosh/pico/pkg/pssh"
 19	"github.com/picosh/pico/pkg/shared"
 20	"golang.org/x/crypto/ed25519"
 21	"golang.org/x/crypto/ssh"
 22)
 23
 24// MockChannel implements ssh.Channel for testing PTY line-ending normalization.
 25type MockChannel struct {
 26	data []byte
 27}
 28
 29func (m *MockChannel) Read(data []byte) (n int, err error) {
 30	return 0, io.EOF
 31}
 32
 33func (m *MockChannel) Write(data []byte) (n int, err error) {
 34	m.data = append(m.data, data...)
 35	return len(data), nil
 36}
 37
 38func (m *MockChannel) Close() error {
 39	return nil
 40}
 41
 42func (m *MockChannel) CloseWrite() error {
 43	return nil
 44}
 45
 46func (m *MockChannel) SendRequest(name string, wantReply bool, data []byte) (bool, error) {
 47	return false, nil
 48}
 49
 50func (m *MockChannel) Stderr() io.ReadWriter {
 51	return &bytes.Buffer{}
 52}
 53
 54func (m *MockChannel) Data() []byte {
 55	return m.data
 56}
 57
 58// setPtyField sets the private pty field on a session (test only).
 59func setPtyField(session *pssh.SSHServerConnSession, pty *pssh.Pty) {
 60	field := reflect.ValueOf(session).Elem().FieldByName("pty")
 61	reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(pty))
 62}
 63
 64// TestSSHServerConnSessionWritePtyLineEnding verifies that line-ending normalization works correctly.
 65func TestSSHServerConnSessionWritePtyLineEnding(t *testing.T) {
 66	ctx := context.Background()
 67	logger := slog.Default()
 68	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
 69
 70	// Create a mock SSH connection
 71	sshConn := &ssh.ServerConn{}
 72	serverConn := pssh.NewSSHServerConn(ctx, logger, sshConn, server)
 73
 74	// Create session with mock channel
 75	mockChannel := &MockChannel{}
 76
 77	createSession := func() *pssh.SSHServerConnSession {
 78		return &pssh.SSHServerConnSession{
 79			Channel:       mockChannel,
 80			SSHServerConn: serverConn,
 81			Ctx:           ctx,
 82		}
 83	}
 84
 85	t.Run("no PTY - write as-is", func(t *testing.T) {
 86		mockChannel.data = nil
 87		session := createSession()
 88		// No PTY is allocated, so behavior should write as-is
 89
 90		// Write text with just \n (no \r)
 91		input := []byte("line1\nline2\nline3")
 92		n, err := session.Write(input)
 93
 94		if err != nil {
 95			t.Errorf("unexpected error: %v", err)
 96		}
 97		if n != len(input) {
 98			t.Errorf("expected %d bytes written, got %d", len(input), n)
 99		}
100		if !slices.Equal(mockChannel.data, input) {
101			t.Errorf("expected %q, got %q", string(input), string(mockChannel.data))
102		}
103	})
104
105	t.Run("with PTY - normalize bare newlines to CRLF", func(t *testing.T) {
106		mockChannel.data = nil
107		session := createSession()
108		// Set PTY on the session
109		pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
110		setPtyField(session, pty)
111
112		// Write text with just \n (no \r)
113		input := []byte("line1\nline2\nline3")
114		expected := []byte("line1\r\nline2\r\nline3")
115
116		n, err := session.Write(input)
117
118		if err != nil {
119			t.Errorf("unexpected error: %v", err)
120		}
121		// Should return original length
122		if n != len(input) {
123			t.Errorf("expected %d bytes written, got %d", len(input), n)
124		}
125		// Should write normalized data
126		if !slices.Equal(mockChannel.data, expected) {
127			t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
128		}
129	})
130
131	t.Run("with PTY - preserve existing CRLF", func(t *testing.T) {
132		mockChannel.data = nil
133		session := createSession()
134		pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
135		setPtyField(session, pty)
136
137		// Write text that already has proper \r\n
138		input := []byte("line1\r\nline2\r\nline3")
139		expected := []byte("line1\r\nline2\r\nline3") // Should not duplicate
140
141		n, err := session.Write(input)
142
143		if err != nil {
144			t.Errorf("unexpected error: %v", err)
145		}
146		if n != len(input) {
147			t.Errorf("expected %d bytes written, got %d", len(input), n)
148		}
149		if !slices.Equal(mockChannel.data, expected) {
150			t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
151		}
152	})
153
154	t.Run("with PTY - mixed newlines normalized correctly", func(t *testing.T) {
155		mockChannel.data = nil
156		session := createSession()
157		pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
158		setPtyField(session, pty)
159
160		// Mix of \n and \r\n
161		input := []byte("line1\nline2\r\nline3\nline4")
162		expected := []byte("line1\r\nline2\r\nline3\r\nline4")
163
164		n, err := session.Write(input)
165
166		if err != nil {
167			t.Errorf("unexpected error: %v", err)
168		}
169		if n != len(input) {
170			t.Errorf("expected %d bytes written, got %d", len(input), n)
171		}
172		if !slices.Equal(mockChannel.data, expected) {
173			t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
174		}
175	})
176
177	t.Run("staircase bug regression - sequential writes maintain formatting", func(t *testing.T) {
178		// This test simulates the staircase bug where multiple writes
179		// without proper CRLF would cause progressive indentation
180		mockChannel.data = nil
181		session := createSession()
182		pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
183		setPtyField(session, pty)
184
185		// Simulate help text being written in multiple chunks
186		writes := []string{
187			"NAME:\n",
188			"\tssh - A tool\n",
189			"\n",
190			"USAGE:\n",
191			"\tssh [options]\n",
192		}
193
194		for _, w := range writes {
195			_, err := session.Write([]byte(w))
196			if err != nil {
197				t.Errorf("unexpected error: %v", err)
198			}
199		}
200
201		// Check that every \n is preceded by \r to prevent staircase
202		output := mockChannel.data
203		for i := 0; i < len(output); i++ {
204			if output[i] == '\n' {
205				if i == 0 {
206					t.Errorf("newline at position 0 not preceded by carriage return")
207				} else if output[i-1] != '\r' {
208					t.Errorf("newline at position %d not preceded by carriage return, got %q before it",
209						i, string(output[i-1]))
210				}
211			}
212		}
213	})
214}
215
216func TestNewSSHServer(t *testing.T) {
217	ctx := context.Background()
218	logger := slog.Default()
219	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
220
221	if server == nil { //nolint:all
222		t.Fatal("expected non-nil server")
223	}
224
225	if server.Ctx == nil { //nolint:all
226		t.Error("expected non-nil context")
227	}
228
229	if server.Logger == nil { //nolint:all
230		t.Error("expected non-nil logger")
231	}
232
233	if server.Config == nil { //nolint:all
234		t.Error("expected non-nil config")
235	}
236
237	if server.Conns == nil { //nolint:all
238		t.Error("expected non-nil connections map")
239	}
240}
241
242func TestNewSSHServerConn(t *testing.T) {
243	ctx := context.Background()
244	logger := slog.Default()
245	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
246	conn := &ssh.ServerConn{}
247
248	serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
249
250	if serverConn == nil { //nolint:all
251		t.Fatal("expected non-nil server connection")
252	}
253
254	if serverConn.Ctx == nil { //nolint:all
255		t.Error("expected non-nil context")
256	}
257
258	if serverConn.Logger == nil { //nolint:all
259		t.Error("expected non-nil logger")
260	}
261
262	if serverConn.Conn != conn { //nolint:all
263		t.Error("expected conn to match")
264	}
265
266	if serverConn.SSHServer != server { //nolint:all
267		t.Error("expected server to match")
268	}
269}
270
271func TestSSHServerConnClose(t *testing.T) {
272	ctx := context.Background()
273	logger := slog.Default()
274	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
275	conn := &ssh.ServerConn{}
276
277	serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
278	err := serverConn.Close()
279
280	if err != nil {
281		t.Errorf("unexpected error: %v", err)
282	}
283
284	// Should be canceled after close
285	select {
286	case <-serverConn.Ctx.Done():
287		// Context was canceled as expected
288	default:
289		t.Error("context was not canceled after Close()")
290	}
291}
292
293func TestSSHServerClose(t *testing.T) {
294	ctx := context.Background()
295	logger := slog.Default()
296	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
297
298	// Create a mock listener to test Close()
299	listener, err := net.Listen("tcp", "127.0.0.1:0")
300	if err != nil {
301		t.Fatalf("failed to create listener: %v", err)
302	}
303
304	server.Listener = listener
305	err = server.Close()
306
307	if err != nil {
308		t.Errorf("unexpected error: %v", err)
309	}
310
311	// Should be canceled after close
312	select {
313	case <-server.Ctx.Done():
314		// Context was canceled as expected
315	default:
316		t.Error("context was not canceled after Close()")
317	}
318}
319
320func TestSSHServerNilParams(t *testing.T) {
321	// Test with nil context and logger
322	//nolint:staticcheck // SA1012 ignores nil check
323	//lint:ignore SA1012 ignores nil check
324	server := pssh.NewSSHServer(nil, nil, nil)
325
326	if server == nil { //nolint:all
327		t.Fatal("expected non-nil server")
328	}
329
330	if server.Ctx == nil { //nolint:all
331		t.Error("expected non-nil context even when nil is passed")
332	}
333
334	if server.Logger == nil { //nolint:all
335		t.Error("expected non-nil logger even when nil is passed")
336	}
337
338	// Test with nil context and logger for connection
339	//nolint:staticcheck // SA1012 ignores nil check
340	//lint:ignore SA1012 ignores nil check
341	conn := pssh.NewSSHServerConn(nil, nil, &ssh.ServerConn{}, server)
342
343	if conn == nil { //nolint:all
344		t.Fatal("expected non-nil server connection")
345	}
346
347	if conn.Ctx == nil { //nolint:all
348		t.Error("expected non-nil context even when nil is passed")
349	}
350
351	if conn.Logger == nil { //nolint:all
352		t.Error("expected non-nil logger even when nil is passed")
353	}
354}
355
356func TestSSHServerHandleConn(t *testing.T) {
357	ctx, cancel := context.WithCancel(context.Background())
358	defer cancel()
359	logger := slog.Default()
360	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
361
362	// Setup a basic SSH server config
363	config := &ssh.ServerConfig{
364		NoClientAuth: true,
365	}
366
367	server.Config.ServerConfig = config
368
369	// Create a mock connection
370	client, server_conn := net.Pipe()
371	defer func() {
372		_ = client.Close()
373	}()
374
375	// Start HandleConn in a goroutine
376	errChan := make(chan error, 1)
377	go func() {
378		errChan <- server.HandleConn(server_conn)
379	}()
380
381	// Configure SSH client
382	clientConfig := &ssh.ClientConfig{
383		User:            "testuser",
384		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
385	}
386
387	// Try to establish SSH connection
388	_, _, _, err := ssh.NewClientConn(client, "", clientConfig)
389
390	// It should fail since we're using a pipe and not a proper SSH handshake
391	if err == nil {
392		t.Error("expected SSH handshake to fail with test pipe")
393	}
394
395	// Close connections to ensure HandleConn returns
396	_ = client.Close()
397	_ = server_conn.Close()
398
399	// Wait for HandleConn to return
400	select {
401	case <-errChan:
402		// Expected HandleConn to return
403	case <-time.After(2 * time.Second):
404		t.Error("HandleConn did not return after connection closed")
405	}
406}
407
408func TestSSHServerListenAndServe(t *testing.T) {
409	ctx, cancel := context.WithCancel(context.Background())
410	defer cancel()
411	logger := slog.Default()
412	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
413
414	config := &ssh.ServerConfig{
415		NoClientAuth: true,
416	}
417
418	// Set a random port
419	port := "127.0.0.1:0"
420	server.Config.ListenAddr = port
421	server.Config.ServerConfig = config
422
423	// Start server in a goroutine
424	errChan := make(chan error, 1)
425	go func() {
426		err := server.ListenAndServe()
427		errChan <- err
428	}()
429
430	// Wait a bit for the server to start
431	time.Sleep(100 * time.Millisecond)
432
433	// Trigger cancellation to stop the server
434	cancel()
435
436	// Wait for server to stop
437	select {
438	case err := <-errChan:
439		if err != nil && !errors.Is(err, net.ErrClosed) {
440			t.Errorf("unexpected error: %v", err)
441		}
442	case <-time.After(2 * time.Second):
443		t.Error("server did not shut down in time")
444	}
445}
446
447func TestSSHServerConnHandle(t *testing.T) {
448	ctx, cancel := context.WithCancel(context.Background())
449	defer cancel()
450	logger := slog.Default()
451	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
452	conn := &ssh.ServerConn{}
453
454	serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
455
456	// Create channels for testing
457	chans := make(chan ssh.NewChannel)
458	reqs := make(chan *ssh.Request)
459
460	// Start handle in a goroutine
461	errChan := make(chan error, 1)
462	go func() {
463		errChan <- serverConn.Handle(chans, reqs)
464	}()
465
466	// Ensure handle returns when context is canceled
467	cancel()
468
469	// Wait for handle to return
470	select {
471	case err := <-errChan:
472		if err != nil {
473			t.Errorf("unexpected error: %v", err)
474		}
475	case <-time.After(2 * time.Second):
476		t.Error("Handle did not return after context canceled")
477	}
478}
479
480func TestSSHServerCommandParsing(t *testing.T) {
481	ctx, cancel := context.WithCancel(context.Background())
482	defer cancel()
483
484	logger := slog.Default()
485	var capturedCommand []string
486
487	user := GenerateKey()
488
489	// Use dynamic port (0) to avoid port conflicts
490	server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{
491		ListenAddr: "127.0.0.1:0",
492		Middleware: []pssh.SSHServerMiddleware{
493			func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
494				return func(sesh *pssh.SSHServerConnSession) error {
495					capturedCommand = sesh.Command()
496					return next(sesh)
497				}
498			},
499		},
500		ServerConfig: &ssh.ServerConfig{
501			NoClientAuth: true,
502			NoClientAuthCallback: func(ssh.ConnMetadata) (*ssh.Permissions, error) {
503				return &ssh.Permissions{
504					Extensions: map[string]string{
505						"pubkey": shared.KeyForKeyText(user.signer.PublicKey()),
506					},
507				}, nil
508			},
509		},
510	})
511	server.Config.AddHostKey(user.signer)
512
513	// Start server in a goroutine
514	errChan := make(chan error, 1)
515	go func() {
516		err := server.ListenAndServe()
517		errChan <- err
518	}()
519
520	// Wait for server to be ready and get the actual listening address
521	var actualAddr string
522	for i := 0; i < 50; i++ {
523		server.Mu.Lock()
524		listener := server.Listener
525		server.Mu.Unlock()
526		if listener != nil {
527			actualAddr = listener.Addr().String()
528			break
529		}
530		time.Sleep(10 * time.Millisecond)
531	}
532
533	if actualAddr == "" {
534		t.Fatal("server listener not ready")
535	}
536
537	// Send command to server
538	_, _ = user.CmdAddr(nil, actualAddr, "accept --comment 'here we go' 101")
539
540	time.Sleep(100 * time.Millisecond)
541
542	expectedCommand := []string{"accept", "--comment", "here we go", "101"}
543	if !slices.Equal(expectedCommand, capturedCommand) {
544		t.Error("command not exected", capturedCommand, len(capturedCommand), expectedCommand, len(expectedCommand))
545	}
546
547	// Trigger cancellation to stop the server
548	cancel()
549
550	// Wait for server to stop
551	select {
552	case err := <-errChan:
553		if err != nil && !errors.Is(err, net.ErrClosed) {
554			t.Errorf("unexpected error: %v", err)
555		}
556	case <-time.After(2 * time.Second):
557		t.Error("server did not shut down in time")
558	}
559}
560
561type UserSSH struct {
562	username string
563	signer   ssh.Signer
564}
565
566func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
567	return &UserSSH{
568		username: username,
569		signer:   signer,
570	}
571}
572
573func (s UserSSH) Public() string {
574	pubkey := s.signer.PublicKey()
575	return string(ssh.MarshalAuthorizedKey(pubkey))
576}
577
578func (s UserSSH) CmdAddr(patch []byte, addr string, cmd string) (string, error) {
579	config := &ssh.ClientConfig{
580		User: s.username,
581		Auth: []ssh.AuthMethod{
582			ssh.PublicKeys(s.signer),
583		},
584		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
585	}
586
587	client, err := ssh.Dial("tcp", addr, config)
588	if err != nil {
589		return "", err
590	}
591	defer func() {
592		_ = client.Close()
593	}()
594
595	session, err := client.NewSession()
596	if err != nil {
597		return "", err
598	}
599	defer func() {
600		_ = session.Close()
601	}()
602
603	stdinPipe, err := session.StdinPipe()
604	if err != nil {
605		return "", err
606	}
607
608	stdoutPipe, err := session.StdoutPipe()
609	if err != nil {
610		return "", err
611	}
612
613	if err := session.Start(cmd); err != nil {
614		return "", err
615	}
616
617	if patch != nil {
618		_, err = stdinPipe.Write(patch)
619		if err != nil {
620			return "", err
621		}
622	}
623
624	_ = stdinPipe.Close()
625
626	if err := session.Wait(); err != nil {
627		return "", err
628	}
629
630	buf := new(strings.Builder)
631	_, err = io.Copy(buf, stdoutPipe)
632	if err != nil {
633		return "", err
634	}
635
636	return buf.String(), nil
637}
638
639func GenerateKey() UserSSH {
640	_, userKey, err := ed25519.GenerateKey(rand.Reader)
641	if err != nil {
642		panic(err)
643	}
644
645	userSigner, err := ssh.NewSignerFromKey(userKey)
646	if err != nil {
647		panic(err)
648	}
649
650	return UserSSH{
651		username: "user",
652		signer:   userSigner,
653	}
654}