- commit
- 740c00a
- parent
- da9a246
- author
- Eric Bower
- date
- 2026-02-26 09:04:53 -0500 EST
fix(pssh): normalize line endings for pty enabled ssh output
2 files changed,
+228,
-21
+15,
-5
1@@ -197,12 +197,22 @@ func (s *SSHServerConnSession) Write(p []byte) (n int, err error) {
2 return s.Channel.Write(p)
3 }
4
5- // When PTY is active, normalize line endings like a real terminal would.
6- // Replace \n with \r\n, but avoid double \r\n.
7- normalized := bytes.ReplaceAll(p, []byte{'\n'}, []byte{'\r', '\n'})
8- normalized = bytes.ReplaceAll(normalized, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'})
9+ // When PTY is active, ensure every \n is preceded by \r.
10+ // This ensures the cursor returns to column 0 before the newline.
11+ var buf bytes.Buffer
12+ for i := 0; i < len(p); i++ {
13+ if p[i] == '\n' {
14+ // Check if preceded by \r
15+ if i == 0 || p[i-1] != '\r' {
16+ buf.WriteByte('\r')
17+ }
18+ buf.WriteByte('\n')
19+ } else {
20+ buf.WriteByte(p[i])
21+ }
22+ }
23
24- // Write the normalized data
25+ normalized := buf.Bytes()
26 written, err := s.Channel.Write(normalized)
27
28 // Return the count based on original data length, not normalized
+213,
-16
1@@ -1,16 +1,19 @@
2 package pssh_test
3
4 import (
5+ "bytes"
6 "context"
7 "crypto/rand"
8 "errors"
9 "io"
10 "log/slog"
11 "net"
12+ "reflect"
13 "slices"
14 "strings"
15 "testing"
16 "time"
17+ "unsafe"
18
19 "github.com/picosh/pico/pkg/pssh"
20 "github.com/picosh/pico/pkg/shared"
21@@ -18,6 +21,198 @@ import (
22 "golang.org/x/crypto/ssh"
23 )
24
25+// MockChannel implements ssh.Channel for testing PTY line-ending normalization.
26+type MockChannel struct {
27+ data []byte
28+}
29+
30+func (m *MockChannel) Read(data []byte) (n int, err error) {
31+ return 0, io.EOF
32+}
33+
34+func (m *MockChannel) Write(data []byte) (n int, err error) {
35+ m.data = append(m.data, data...)
36+ return len(data), nil
37+}
38+
39+func (m *MockChannel) Close() error {
40+ return nil
41+}
42+
43+func (m *MockChannel) CloseWrite() error {
44+ return nil
45+}
46+
47+func (m *MockChannel) SendRequest(name string, wantReply bool, data []byte) (bool, error) {
48+ return false, nil
49+}
50+
51+func (m *MockChannel) Stderr() io.ReadWriter {
52+ return &bytes.Buffer{}
53+}
54+
55+func (m *MockChannel) Data() []byte {
56+ return m.data
57+}
58+
59+// setPtyField sets the private pty field on a session (test only).
60+func setPtyField(session *pssh.SSHServerConnSession, pty *pssh.Pty) {
61+ field := reflect.ValueOf(session).Elem().FieldByName("pty")
62+ reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(pty))
63+}
64+
65+// TestSSHServerConnSessionWritePtyLineEnding verifies that line-ending normalization works correctly.
66+func TestSSHServerConnSessionWritePtyLineEnding(t *testing.T) {
67+ ctx := context.Background()
68+ logger := slog.Default()
69+ server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{})
70+
71+ // Create a mock SSH connection
72+ sshConn := &ssh.ServerConn{}
73+ serverConn := pssh.NewSSHServerConn(ctx, logger, sshConn, server)
74+
75+ // Create session with mock channel
76+ mockChannel := &MockChannel{}
77+
78+ createSession := func() *pssh.SSHServerConnSession {
79+ return &pssh.SSHServerConnSession{
80+ Channel: mockChannel,
81+ SSHServerConn: serverConn,
82+ Ctx: ctx,
83+ }
84+ }
85+
86+ t.Run("no PTY - write as-is", func(t *testing.T) {
87+ mockChannel.data = nil
88+ session := createSession()
89+ // No PTY is allocated, so behavior should write as-is
90+
91+ // Write text with just \n (no \r)
92+ input := []byte("line1\nline2\nline3")
93+ n, err := session.Write(input)
94+
95+ if err != nil {
96+ t.Errorf("unexpected error: %v", err)
97+ }
98+ if n != len(input) {
99+ t.Errorf("expected %d bytes written, got %d", len(input), n)
100+ }
101+ if !slices.Equal(mockChannel.data, input) {
102+ t.Errorf("expected %q, got %q", string(input), string(mockChannel.data))
103+ }
104+ })
105+
106+ t.Run("with PTY - normalize bare newlines to CRLF", func(t *testing.T) {
107+ mockChannel.data = nil
108+ session := createSession()
109+ // Set PTY on the session
110+ pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
111+ setPtyField(session, pty)
112+
113+ // Write text with just \n (no \r)
114+ input := []byte("line1\nline2\nline3")
115+ expected := []byte("line1\r\nline2\r\nline3")
116+
117+ n, err := session.Write(input)
118+
119+ if err != nil {
120+ t.Errorf("unexpected error: %v", err)
121+ }
122+ // Should return original length
123+ if n != len(input) {
124+ t.Errorf("expected %d bytes written, got %d", len(input), n)
125+ }
126+ // Should write normalized data
127+ if !slices.Equal(mockChannel.data, expected) {
128+ t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
129+ }
130+ })
131+
132+ t.Run("with PTY - preserve existing CRLF", func(t *testing.T) {
133+ mockChannel.data = nil
134+ session := createSession()
135+ pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
136+ setPtyField(session, pty)
137+
138+ // Write text that already has proper \r\n
139+ input := []byte("line1\r\nline2\r\nline3")
140+ expected := []byte("line1\r\nline2\r\nline3") // Should not duplicate
141+
142+ n, err := session.Write(input)
143+
144+ if err != nil {
145+ t.Errorf("unexpected error: %v", err)
146+ }
147+ if n != len(input) {
148+ t.Errorf("expected %d bytes written, got %d", len(input), n)
149+ }
150+ if !slices.Equal(mockChannel.data, expected) {
151+ t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
152+ }
153+ })
154+
155+ t.Run("with PTY - mixed newlines normalized correctly", func(t *testing.T) {
156+ mockChannel.data = nil
157+ session := createSession()
158+ pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
159+ setPtyField(session, pty)
160+
161+ // Mix of \n and \r\n
162+ input := []byte("line1\nline2\r\nline3\nline4")
163+ expected := []byte("line1\r\nline2\r\nline3\r\nline4")
164+
165+ n, err := session.Write(input)
166+
167+ if err != nil {
168+ t.Errorf("unexpected error: %v", err)
169+ }
170+ if n != len(input) {
171+ t.Errorf("expected %d bytes written, got %d", len(input), n)
172+ }
173+ if !slices.Equal(mockChannel.data, expected) {
174+ t.Errorf("expected %q, got %q", string(expected), string(mockChannel.data))
175+ }
176+ })
177+
178+ t.Run("staircase bug regression - sequential writes maintain formatting", func(t *testing.T) {
179+ // This test simulates the staircase bug where multiple writes
180+ // without proper CRLF would cause progressive indentation
181+ mockChannel.data = nil
182+ session := createSession()
183+ pty := &pssh.Pty{Term: "xterm", Window: pssh.Window{Width: 80, Height: 24}}
184+ setPtyField(session, pty)
185+
186+ // Simulate help text being written in multiple chunks
187+ writes := []string{
188+ "NAME:\n",
189+ "\tssh - A tool\n",
190+ "\n",
191+ "USAGE:\n",
192+ "\tssh [options]\n",
193+ }
194+
195+ for _, w := range writes {
196+ _, err := session.Write([]byte(w))
197+ if err != nil {
198+ t.Errorf("unexpected error: %v", err)
199+ }
200+ }
201+
202+ // Check that every \n is preceded by \r to prevent staircase
203+ output := mockChannel.data
204+ for i := 0; i < len(output); i++ {
205+ if output[i] == '\n' {
206+ if i == 0 {
207+ t.Errorf("newline at position 0 not preceded by carriage return")
208+ } else if output[i-1] != '\r' {
209+ t.Errorf("newline at position %d not preceded by carriage return, got %q before it",
210+ i, string(output[i-1]))
211+ }
212+ }
213+ }
214+ })
215+}
216+
217 func TestNewSSHServer(t *testing.T) {
218 ctx := context.Background()
219 logger := slog.Default()
220@@ -291,8 +486,9 @@ func TestSSHServerCommandParsing(t *testing.T) {
221
222 user := GenerateKey()
223
224+ // Use dynamic port (0) to avoid port conflicts
225 server := pssh.NewSSHServer(ctx, logger, &pssh.SSHServerConfig{
226- ListenAddr: "localhost:2222",
227+ ListenAddr: "127.0.0.1:0",
228 Middleware: []pssh.SSHServerMiddleware{
229 func(next pssh.SSHServerHandler) pssh.SSHServerHandler {
230 return func(sesh *pssh.SSHServerConnSession) error {
231@@ -321,11 +517,22 @@ func TestSSHServerCommandParsing(t *testing.T) {
232 errChan <- err
233 }()
234
235- // Wait a bit for the server to start
236- time.Sleep(100 * time.Millisecond)
237+ // Wait for server to be ready and get the actual listening address
238+ var actualAddr string
239+ for i := 0; i < 50; i++ {
240+ if server.Listener != nil {
241+ actualAddr = server.Listener.Addr().String()
242+ break
243+ }
244+ time.Sleep(10 * time.Millisecond)
245+ }
246+
247+ if actualAddr == "" {
248+ t.Fatal("server listener not ready")
249+ }
250
251 // Send command to server
252- _, _ = user.Cmd(nil, "accept --comment 'here we go' 101")
253+ _, _ = user.CmdAddr(nil, actualAddr, "accept --comment 'here we go' 101")
254
255 time.Sleep(100 * time.Millisecond)
256
257@@ -365,17 +572,7 @@ func (s UserSSH) Public() string {
258 return string(ssh.MarshalAuthorizedKey(pubkey))
259 }
260
261-func (s UserSSH) MustCmd(patch []byte, cmd string) string {
262- res, err := s.Cmd(patch, cmd)
263- if err != nil {
264- panic(err)
265- }
266- return res
267-}
268-
269-func (s UserSSH) Cmd(patch []byte, cmd string) (string, error) {
270- host := "localhost:2222"
271-
272+func (s UserSSH) CmdAddr(patch []byte, addr string, cmd string) (string, error) {
273 config := &ssh.ClientConfig{
274 User: s.username,
275 Auth: []ssh.AuthMethod{
276@@ -384,7 +581,7 @@ func (s UserSSH) Cmd(patch []byte, cmd string) (string, error) {
277 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
278 }
279
280- client, err := ssh.Dial("tcp", host, config)
281+ client, err := ssh.Dial("tcp", addr, config)
282 if err != nil {
283 return "", err
284 }