Eric Bower
·
2026-02-26
server_test.go
1package pssh_test
2
3import (
4 "bytes"
5 "context"
6 "crypto/rand"
7 "errors"
8 "io"
9 "log/slog"
10 "net"
11 "reflect"
12 "slices"
13 "strings"
14 "testing"
15 "time"
16 "unsafe"
17
18 "github.com/picosh/pico/pkg/pssh"
19 "github.com/picosh/pico/pkg/shared"
20 "golang.org/x/crypto/ed25519"
21 "golang.org/x/crypto/ssh"
22)
23
// MockChannel is an in-memory stand-in for ssh.Channel used to test PTY
// line-ending normalization. Everything written to it is appended to data;
// reads report end-of-stream immediately and every other method is a
// successful no-op.
type MockChannel struct {
	// data accumulates every byte passed to Write, in call order.
	data []byte
}

// Read never yields input: it reports io.EOF without touching p.
func (m *MockChannel) Read(p []byte) (int, error) {
	return 0, io.EOF
}

// Write appends p to the captured data and reports the whole slice as written.
func (m *MockChannel) Write(p []byte) (int, error) {
	m.data = append(m.data, p...)
	return len(p), nil
}

// Close is a no-op that always succeeds.
func (m *MockChannel) Close() error { return nil }

// CloseWrite is a no-op that always succeeds.
func (m *MockChannel) CloseWrite() error { return nil }

// SendRequest ignores the request and reports it as not accepted.
func (m *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	return false, nil
}

// Stderr hands back a fresh, empty buffer on every call, so anything written
// to it is not retained by the mock.
func (m *MockChannel) Stderr() io.ReadWriter {
	return new(bytes.Buffer)
}

// Data returns the bytes captured so far (the underlying slice, not a copy).
func (m *MockChannel) Data() []byte {
	return m.data
}
57
58// setPtyField sets the private pty field on a session (test only).
59func setPtyField(session *pssh.SSHServerConnSession, pty *pssh.Pty) {
60 field := reflect.ValueOf(session).Elem().FieldByName("pty")
61 reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(pty))
62}
63
64// TestSSHServerConnSessionWritePtyLineEnding verifies that line-ending normalization works correctly.
65func TestSSHServerConnSessionWritePtyLineEnding(t *testing.T) {
66 ctx := context.Background()
67 logger := slog.Default()
68 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
69
70 // Create a mock SSH connection
71 sshConn := &ssh.ServerConn{}
72 serverConn := pssh.NewSSHServerConn(ctx, logger, sshConn, server)
73
74 // Create session with mock channel
75 mockChannel := &MockChannel{}
76
77 createSession := func() *pssh.SSHServerConnSession {
78 return &pssh.SSHServerConnSession{
79 Channel: mockChannel,
80 SSHServerConn: serverConn,
81 Ctx: ctx,
82 }
83 }
84
85 t.Run("no PTY - write as-is", func(t *testing.T) {
86 mockChannel.data = nil
87 session := createSession()
88 // No PTY is allocated, so behavior should write as-is
89
90 // Write text with just \n (no \r)
91 input := []byte("line1\nline2\nline3")
92 n, err := session.Write(input)
93
94 if err != nil {
95 t.Errorf("unexpected error: %v", err)
96 }
97 if n != len(input) {
98 t.Errorf("expected %d bytes written, got %d", len(input), n)
99 }
100 if !slices.Equal(mockChannel.data, input) {
101 t.Errorf("expected %q, got %q", string(input), string(mockChannel.data))
102 }
103 })
104
105 t.Run("with PTY - normalize bare newlines to CRLF", func(t *testing.T) {
106 mockChannel.data = nil
107 session := createSession()
108 // Set PTY on the session
109 pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
110 setPtyField(session, pty)
111
112 // Write text with just \n (no \r)
113 input := []byte("line1\nline2\nline3")
114 expected := []byte("line1\r\nline2\r\nline3")
115
116 n, err := session.Write(input)
117
118 if err != nil {
119 t.Errorf("unexpected error: %v", err)
120 }
121 // Should return original length
122 if n != len(input) {
123 t.Errorf("expected %d bytes written, got %d", len(input), n)
124 }
125 // Should write normalized data
126 if !slices.Equal(mockChannel.data, expected) {
127 t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
128 }
129 })
130
131 t.Run("with PTY - preserve existing CRLF", func(t *testing.T) {
132 mockChannel.data = nil
133 session := createSession()
134 pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
135 setPtyField(session, pty)
136
137 // Write text that already has proper \r\n
138 input := []byte("line1\r\nline2\r\nline3")
139 expected := []byte("line1\r\nline2\r\nline3") // Should not duplicate
140
141 n, err := session.Write(input)
142
143 if err != nil {
144 t.Errorf("unexpected error: %v", err)
145 }
146 if n != len(input) {
147 t.Errorf("expected %d bytes written, got %d", len(input), n)
148 }
149 if !slices.Equal(mockChannel.data, expected) {
150 t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
151 }
152 })
153
154 t.Run("with PTY - mixed newlines normalized correctly", func(t *testing.T) {
155 mockChannel.data = nil
156 session := createSession()
157 pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
158 setPtyField(session, pty)
159
160 // Mix of \n and \r\n
161 input := []byte("line1\nline2\r\nline3\nline4")
162 expected := []byte("line1\r\nline2\r\nline3\r\nline4")
163
164 n, err := session.Write(input)
165
166 if err != nil {
167 t.Errorf("unexpected error: %v", err)
168 }
169 if n != len(input) {
170 t.Errorf("expected %d bytes written, got %d", len(input), n)
171 }
172 if !slices.Equal(mockChannel.data, expected) {
173 t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
174 }
175 })
176
177 t.Run("staircase bug regression - sequential writes maintain formatting", func(t *testing.T) {
178 // This test simulates the staircase bug where multiple writes
179 // without proper CRLF would cause progressive indentation
180 mockChannel.data = nil
181 session := createSession()
182 pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
183 setPtyField(session, pty)
184
185 // Simulate help text being written in multiple chunks
186 writes := []string{
187 "NAME:\n",
188 "\tssh - A tool\n",
189 "\n",
190 "USAGE:\n",
191 "\tssh [options]\n",
192 }
193
194 for _, w := range writes {
195 _, err := session.Write([]byte(w))
196 if err != nil {
197 t.Errorf("unexpected error: %v", err)
198 }
199 }
200
201 // Check that every \n is preceded by \r to prevent staircase
202 output := mockChannel.data
203 for i := 0; i < len(output); i++ {
204 if output[i] == '\n' {
205 if i == 0 {
206 t.Errorf("newline at position 0 not preceded by carriage return")
207 } else if output[i-1] != '\r' {
208 t.Errorf("newline at position %d not preceded by carriage return, got %q before it",
209 i, string(output[i-1]))
210 }
211 }
212 }
213 })
214}
215
216func TestNewSSHServer(t *testing.T) {
217 ctx := context.Background()
218 logger := slog.Default()
219 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
220
221 if server == nil { //nolint:all
222 t.Fatal("expected non-nil server")
223 }
224
225 if server.Ctx == nil { //nolint:all
226 t.Error("expected non-nil context")
227 }
228
229 if server.Logger == nil { //nolint:all
230 t.Error("expected non-nil logger")
231 }
232
233 if server.Config == nil { //nolint:all
234 t.Error("expected non-nil config")
235 }
236
237 if server.Conns == nil { //nolint:all
238 t.Error("expected non-nil connections map")
239 }
240}
241
242func TestNewSSHServerConn(t *testing.T) {
243 ctx := context.Background()
244 logger := slog.Default()
245 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
246 conn := &ssh.ServerConn{}
247
248 serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
249
250 if serverConn == nil { //nolint:all
251 t.Fatal("expected non-nil server connection")
252 }
253
254 if serverConn.Ctx == nil { //nolint:all
255 t.Error("expected non-nil context")
256 }
257
258 if serverConn.Logger == nil { //nolint:all
259 t.Error("expected non-nil logger")
260 }
261
262 if serverConn.Conn != conn { //nolint:all
263 t.Error("expected conn to match")
264 }
265
266 if serverConn.SSHServer != server { //nolint:all
267 t.Error("expected server to match")
268 }
269}
270
271func TestSSHServerConnClose(t *testing.T) {
272 ctx := context.Background()
273 logger := slog.Default()
274 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
275 conn := &ssh.ServerConn{}
276
277 serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
278 err := serverConn.Close()
279
280 if err != nil {
281 t.Errorf("unexpected error: %v", err)
282 }
283
284 // Should be canceled after close
285 select {
286 case <-serverConn.Ctx.Done():
287 // Context was canceled as expected
288 default:
289 t.Error("context was not canceled after Close()")
290 }
291}
292
293func TestSSHServerClose(t *testing.T) {
294 ctx := context.Background()
295 logger := slog.Default()
296 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
297
298 // Create a mock listener to test Close()
299 listener, err := net.Listen("tcp", "127.0.0.1:0")
300 if err != nil {
301 t.Fatalf("failed to create listener: %v", err)
302 }
303
304 server.Listener = listener
305 err = server.Close()
306
307 if err != nil {
308 t.Errorf("unexpected error: %v", err)
309 }
310
311 // Should be canceled after close
312 select {
313 case <-server.Ctx.Done():
314 // Context was canceled as expected
315 default:
316 t.Error("context was not canceled after Close()")
317 }
318}
319
320func TestSSHServerNilParams(t *testing.T) {
321 // Test with nil context and logger
322 //nolint:staticcheck // SA1012 ignores nil check
323 //lint:ignore SA1012 ignores nil check
324 server := pssh.NewSSHServer(nil, nil, nil)
325
326 if server == nil { //nolint:all
327 t.Fatal("expected non-nil server")
328 }
329
330 if server.Ctx == nil { //nolint:all
331 t.Error("expected non-nil context even when nil is passed")
332 }
333
334 if server.Logger == nil { //nolint:all
335 t.Error("expected non-nil logger even when nil is passed")
336 }
337
338 // Test with nil context and logger for connection
339 //nolint:staticcheck // SA1012 ignores nil check
340 //lint:ignore SA1012 ignores nil check
341 conn := pssh.NewSSHServerConn(nil, nil, &ssh.ServerConn{}, server)
342
343 if conn == nil { //nolint:all
344 t.Fatal("expected non-nil server connection")
345 }
346
347 if conn.Ctx == nil { //nolint:all
348 t.Error("expected non-nil context even when nil is passed")
349 }
350
351 if conn.Logger == nil { //nolint:all
352 t.Error("expected non-nil logger even when nil is passed")
353 }
354}
355
356func TestSSHServerHandleConn(t *testing.T) {
357 ctx, cancel := context.WithCancel(context.Background())
358 defer cancel()
359 logger := slog.Default()
360 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
361
362 // Setup a basic SSH server config
363 config := &ssh.ServerConfig{
364 NoClientAuth: true,
365 }
366
367 server.Config.ServerConfig = config
368
369 // Create a mock connection
370 client, server_conn := net.Pipe()
371 defer func() {
372 _ = client.Close()
373 }()
374
375 // Start HandleConn in a goroutine
376 errChan := make(chan error, 1)
377 go func() {
378 errChan <- server.HandleConn(server_conn)
379 }()
380
381 // Configure SSH client
382 clientConfig := &ssh.ClientConfig{
383 User: "testuser",
384 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
385 }
386
387 // Try to establish SSH connection
388 _, _, _, err := ssh.NewClientConn(client, "", clientConfig)
389
390 // It should fail since we're using a pipe and not a proper SSH handshake
391 if err == nil {
392 t.Error("expected SSH handshake to fail with test pipe")
393 }
394
395 // Close connections to ensure HandleConn returns
396 _ = client.Close()
397 _ = server_conn.Close()
398
399 // Wait for HandleConn to return
400 select {
401 case <-errChan:
402 // Expected HandleConn to return
403 case <-time.After(2 * time.Second):
404 t.Error("HandleConn did not return after connection closed")
405 }
406}
407
408func TestSSHServerListenAndServe(t *testing.T) {
409 ctx, cancel := context.WithCancel(context.Background())
410 defer cancel()
411 logger := slog.Default()
412 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
413
414 config := &ssh.ServerConfig{
415 NoClientAuth: true,
416 }
417
418 // Set a random port
419 port := "127.0.0.1:0"
420 server.Config.ListenAddr = port
421 server.Config.ServerConfig = config
422
423 // Start server in a goroutine
424 errChan := make(chan error, 1)
425 go func() {
426 err := server.ListenAndServe()
427 errChan <- err
428 }()
429
430 // Wait a bit for the server to start
431 time.Sleep(100 * time.Millisecond)
432
433 // Trigger cancellation to stop the server
434 cancel()
435
436 // Wait for server to stop
437 select {
438 case err := <-errChan:
439 if err != nil && !errors.Is(err, net.ErrClosed) {
440 t.Errorf("unexpected error: %v", err)
441 }
442 case <-time.After(2 * time.Second):
443 t.Error("server did not shut down in time")
444 }
445}
446
447func TestSSHServerConnHandle(t *testing.T) {
448 ctx, cancel := context.WithCancel(context.Background())
449 defer cancel()
450 logger := slog.Default()
451 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
452 conn := &ssh.ServerConn{}
453
454 serverConn := pssh.NewSSHServerConn(ctx, logger, conn, server)
455
456 // Create channels for testing
457 chans := make(chan ssh.NewChannel)
458 reqs := make(chan *ssh.Request)
459
460 // Start handle in a goroutine
461 errChan := make(chan error, 1)
462 go func() {
463 errChan <- serverConn.Handle(chans, reqs)
464 }()
465
466 // Ensure handle returns when context is canceled
467 cancel()
468
469 // Wait for handle to return
470 select {
471 case err := <-errChan:
472 if err != nil {
473 t.Errorf("unexpected error: %v", err)
474 }
475 case <-time.After(2 * time.Second):
476 t.Error("Handle did not return after context canceled")
477 }
478}
479
480func TestSSHServerCommandParsing(t *testing.T) {
481 ctx, cancel := context.WithCancel(context.Background())
482 defer cancel()
483
484 logger := slog.Default()
485 var capturedCommand []string
486
487 user := GenerateKey()
488
489 // Use dynamic port (0) to avoid port conflicts
490 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{
491 ListenAddr: "127.0.0.1:0",
492 Middleware: []pssh.SSHServerMiddleware{
493 func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
494 return func(sesh *pssh.SSHServerConnSession) error {
495 capturedCommand = sesh.Command()
496 return next(sesh)
497 }
498 },
499 },
500 ServerConfig: &ssh.ServerConfig{
501 NoClientAuth: true,
502 NoClientAuthCallback: func(ssh.ConnMetadata) (*ssh.Permissions, error) {
503 return &ssh.Permissions{
504 Extensions: map[string]string{
505 "pubkey": shared.KeyForKeyText(user.signer.PublicKey()),
506 },
507 }, nil
508 },
509 },
510 })
511 server.Config.AddHostKey(user.signer)
512
513 // Start server in a goroutine
514 errChan := make(chan error, 1)
515 go func() {
516 err := server.ListenAndServe()
517 errChan <- err
518 }()
519
520 // Wait for server to be ready and get the actual listening address
521 var actualAddr string
522 for i := 0; i < 50; i++ {
523 server.Mu.Lock()
524 listener := server.Listener
525 server.Mu.Unlock()
526 if listener != nil {
527 actualAddr = listener.Addr().String()
528 break
529 }
530 time.Sleep(10 * time.Millisecond)
531 }
532
533 if actualAddr == "" {
534 t.Fatal("server listener not ready")
535 }
536
537 // Send command to server
538 _, _ = user.CmdAddr(nil, actualAddr, "accept --comment 'here we go' 101")
539
540 time.Sleep(100 * time.Millisecond)
541
542 expectedCommand := []string{"accept", "--comment", "here we go", "101"}
543 if !slices.Equal(expectedCommand, capturedCommand) {
544 t.Error("command not exected", capturedCommand, len(capturedCommand), expectedCommand, len(expectedCommand))
545 }
546
547 // Trigger cancellation to stop the server
548 cancel()
549
550 // Wait for server to stop
551 select {
552 case err := <-errChan:
553 if err != nil && !errors.Is(err, net.ErrClosed) {
554 t.Errorf("unexpected error: %v", err)
555 }
556 case <-time.After(2 * time.Second):
557 t.Error("server did not shut down in time")
558 }
559}
560
561type UserSSH struct {
562 username string
563 signer ssh.Signer
564}
565
566func NewUserSSH(username string, signer ssh.Signer) *UserSSH {
567 return &UserSSH{
568 username: username,
569 signer: signer,
570 }
571}
572
573func (s UserSSH) Public() string {
574 pubkey := s.signer.PublicKey()
575 return string(ssh.MarshalAuthorizedKey(pubkey))
576}
577
578func (s UserSSH) CmdAddr(patch []byte, addr string, cmd string) (string, error) {
579 config := &ssh.ClientConfig{
580 User: s.username,
581 Auth: []ssh.AuthMethod{
582 ssh.PublicKeys(s.signer),
583 },
584 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
585 }
586
587 client, err := ssh.Dial("tcp", addr, config)
588 if err != nil {
589 return "", err
590 }
591 defer func() {
592 _ = client.Close()
593 }()
594
595 session, err := client.NewSession()
596 if err != nil {
597 return "", err
598 }
599 defer func() {
600 _ = session.Close()
601 }()
602
603 stdinPipe, err := session.StdinPipe()
604 if err != nil {
605 return "", err
606 }
607
608 stdoutPipe, err := session.StdoutPipe()
609 if err != nil {
610 return "", err
611 }
612
613 if err := session.Start(cmd); err != nil {
614 return "", err
615 }
616
617 if patch != nil {
618 _, err = stdinPipe.Write(patch)
619 if err != nil {
620 return "", err
621 }
622 }
623
624 _ = stdinPipe.Close()
625
626 if err := session.Wait(); err != nil {
627 return "", err
628 }
629
630 buf := new(strings.Builder)
631 _, err = io.Copy(buf, stdoutPipe)
632 if err != nil {
633 return "", err
634 }
635
636 return buf.String(), nil
637}
638
639func GenerateKey() UserSSH {
640 _, userKey, err := ed25519.GenerateKey(rand.Reader)
641 if err != nil {
642 panic(err)
643 }
644
645 userSigner, err := ssh.NewSignerFromKey(userKey)
646 if err != nil {
647 panic(err)
648 }
649
650 return UserSSH{
651 username: "user",
652 signer: userSigner,
653 }
654}