- commit
- b20dbf9
- parent
- 59eee85
- author
- Antonio Mika
- date
- 2025-03-03 15:16:49 -0500 EST
Started work on a crypto/ssh based ssh server
3 files changed,
+484,
-0
+29,
-0
1@@ -0,0 +1,29 @@
2+package main
3+
4+import (
5+ "context"
6+ "log/slog"
7+
8+ "github.com/picosh/pico/shared"
9+)
10+
11+func main() {
12+ // Initialize the logger
13+ logger := slog.Default()
14+
15+ ctx, cancel := context.WithCancel(context.Background())
16+ defer cancel()
17+
18+ // Create a new SSH server
19+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{
20+ ListenAddr: "localhost:2222",
21+ })
22+
23+ err := server.ListenAndServe()
24+ if err != nil {
25+ logger.Error("failed to start SSH server", "error", err)
26+ return
27+ }
28+
29+ logger.Info("SSH server started successfully")
30+}
1@@ -0,0 +1,180 @@
2+package shared
3+
4+import (
5+ "context"
6+ "errors"
7+ "log/slog"
8+ "net"
9+
10+ "github.com/antoniomika/syncmap"
11+ "golang.org/x/crypto/ssh"
12+)
13+
14+type SSHServerConn struct {
15+ Ctx context.Context
16+ CancelFunc context.CancelFunc
17+ Logger *slog.Logger
18+ Conn *ssh.ServerConn
19+ SSHServer *SSHServer
20+}
21+
22+func (sc *SSHServerConn) Close() error {
23+ sc.CancelFunc()
24+ return nil
25+}
26+
27+func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
28+ defer sc.Close()
29+
30+ for {
31+ select {
32+ case <-sc.Ctx.Done():
33+ return nil
34+ case newChan := <-chans:
35+ sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
36+ case req := <-reqs:
37+ sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
38+ }
39+ }
40+}
41+
42+func NewSSHServerConn(
43+ ctx context.Context,
44+ logger *slog.Logger,
45+ conn *ssh.ServerConn,
46+ server *SSHServer,
47+) *SSHServerConn {
48+ if ctx == nil {
49+ ctx = context.Background()
50+ }
51+
52+ cancelCtx, cancelFunc := context.WithCancel(ctx)
53+
54+ if logger == nil {
55+ logger = slog.Default()
56+ }
57+
58+ return &SSHServerConn{
59+ Ctx: cancelCtx,
60+ CancelFunc: cancelFunc,
61+ Logger: logger,
62+ Conn: conn,
63+ SSHServer: server,
64+ }
65+}
66+
67+type SSHServerMiddleware func(func(ssh.Session) error) func(ssh.Session) error
68+
69+type SSHServerConfig struct {
70+ *ssh.ServerConfig
71+ ListenAddr string
72+ SessionMiddleware []SSHServerMiddleware
73+ SubsystemMiddleware []SSHServerMiddleware
74+}
75+
76+type SSHServer struct {
77+ Ctx context.Context
78+ CancelFunc context.CancelFunc
79+ Logger *slog.Logger
80+ Config *SSHServerConfig
81+ Listener net.Listener
82+ Conns *syncmap.Map[string, *SSHServerConn]
83+}
84+
85+func (s *SSHServer) ListenAndServe() error {
86+ listen, err := net.Listen("tcp", s.Config.ListenAddr)
87+ if err != nil {
88+ return err
89+ }
90+
91+ s.Listener = listen
92+ defer s.Listener.Close()
93+
94+ go func() {
95+ <-s.Ctx.Done()
96+ s.Close()
97+ }()
98+
99+ var retErr error
100+
101+ for {
102+ conn, err := s.Listener.Accept()
103+ if err != nil {
104+ s.Logger.Error("accept", "err", err)
105+ if errors.Is(err, net.ErrClosed) {
106+ retErr = err
107+ break
108+ }
109+ continue
110+ }
111+
112+ go func() {
113+ if err := s.HandleConn(conn); err != nil {
114+ s.Logger.Error("handle conn", "err", err)
115+ }
116+ }()
117+ }
118+
119+ return retErr
120+}
121+
122+func (s *SSHServer) HandleConn(conn net.Conn) error {
123+ defer conn.Close()
124+
125+ sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
126+ if err != nil {
127+ return err
128+ }
129+
130+ newLogger := s.Logger.With(
131+ "remoteAddr", conn.RemoteAddr().String(),
132+ "user", sshConn.User(),
133+ "pubkey", sshConn.Permissions.Extensions["pubkey"],
134+ )
135+
136+ newConn := NewSSHServerConn(
137+ s.Ctx,
138+ newLogger,
139+ sshConn,
140+ s,
141+ )
142+
143+ s.Conns.Store(sshConn.RemoteAddr().String(), newConn)
144+
145+ err = newConn.Handle(chans, reqs)
146+
147+ s.Conns.Delete(sshConn.RemoteAddr().String())
148+
149+ return err
150+}
151+
152+func (s *SSHServer) Close() error {
153+ s.CancelFunc()
154+ return s.Listener.Close()
155+}
156+
157+func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
158+ if ctx == nil {
159+ ctx = context.Background()
160+ }
161+
162+ cancelCtx, cancelFunc := context.WithCancel(ctx)
163+
164+ if logger == nil {
165+ logger = slog.Default()
166+ }
167+
168+ if config == nil {
169+ config = &SSHServerConfig{}
170+ }
171+
172+ server := &SSHServer{
173+ Ctx: cancelCtx,
174+ CancelFunc: cancelFunc,
175+ Logger: logger,
176+ Config: config,
177+ Conns: syncmap.New[string, *SSHServerConn](),
178+ }
179+
180+ return server
181+}
1@@ -0,0 +1,275 @@
2+package shared_test
3+
4+import (
5+ "context"
6+ "errors"
7+ "log/slog"
8+ "net"
9+ "testing"
10+ "time"
11+
12+ "github.com/picosh/pico/shared"
13+ "golang.org/x/crypto/ssh"
14+)
15+
16+func TestNewSSHServer(t *testing.T) {
17+ ctx := context.Background()
18+ logger := slog.Default()
19+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
20+
21+ if server == nil {
22+ t.Fatal("expected non-nil server")
23+ }
24+
25+ if server.Ctx == nil {
26+ t.Error("expected non-nil context")
27+ }
28+
29+ if server.Logger == nil {
30+ t.Error("expected non-nil logger")
31+ }
32+
33+ if server.Config == nil {
34+ t.Error("expected non-nil config")
35+ }
36+
37+ if server.Conns == nil {
38+ t.Error("expected non-nil connections map")
39+ }
40+}
41+
42+func TestNewSSHServerConn(t *testing.T) {
43+ ctx := context.Background()
44+ logger := slog.Default()
45+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
46+ conn := &ssh.ServerConn{}
47+
48+ serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
49+
50+ if serverConn == nil {
51+ t.Fatal("expected non-nil server connection")
52+ }
53+
54+ if serverConn.Ctx == nil {
55+ t.Error("expected non-nil context")
56+ }
57+
58+ if serverConn.Logger == nil {
59+ t.Error("expected non-nil logger")
60+ }
61+
62+ if serverConn.Conn != conn {
63+ t.Error("expected conn to match")
64+ }
65+
66+ if serverConn.SSHServer != server {
67+ t.Error("expected server to match")
68+ }
69+}
70+
71+func TestSSHServerConnClose(t *testing.T) {
72+ ctx := context.Background()
73+ logger := slog.Default()
74+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
75+ conn := &ssh.ServerConn{}
76+
77+ serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
78+ err := serverConn.Close()
79+
80+ if err != nil {
81+ t.Errorf("unexpected error: %v", err)
82+ }
83+
84+ // Should be canceled after close
85+ select {
86+ case <-serverConn.Ctx.Done():
87+ // Context was canceled as expected
88+ default:
89+ t.Error("context was not canceled after Close()")
90+ }
91+}
92+
93+func TestSSHServerClose(t *testing.T) {
94+ ctx := context.Background()
95+ logger := slog.Default()
96+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
97+
98+ // Create a mock listener to test Close()
99+ listener, err := net.Listen("tcp", "127.0.0.1:0")
100+ if err != nil {
101+ t.Fatalf("failed to create listener: %v", err)
102+ }
103+
104+ server.Listener = listener
105+ err = server.Close()
106+
107+ if err != nil {
108+ t.Errorf("unexpected error: %v", err)
109+ }
110+
111+ // Should be canceled after close
112+ select {
113+ case <-server.Ctx.Done():
114+ // Context was canceled as expected
115+ default:
116+ t.Error("context was not canceled after Close()")
117+ }
118+}
119+
120+func TestSSHServerNilParams(t *testing.T) {
121+ // Test with nil context and logger
122+ //nolint:staticcheck // SA1012 ignores nil check
123+ //lint:ignore SA1012 ignores nil check
124+ server := shared.NewSSHServer(nil, nil, nil)
125+
126+ if server == nil {
127+ t.Fatal("expected non-nil server")
128+ }
129+
130+ if server.Ctx == nil {
131+ t.Error("expected non-nil context even when nil is passed")
132+ }
133+
134+ if server.Logger == nil {
135+ t.Error("expected non-nil logger even when nil is passed")
136+ }
137+
138+ // Test with nil context and logger for connection
139+ //nolint:staticcheck // SA1012 ignores nil check
140+ //lint:ignore SA1012 ignores nil check
141+ conn := shared.NewSSHServerConn(nil, nil, &ssh.ServerConn{}, server)
142+
143+ if conn == nil {
144+ t.Fatal("expected non-nil server connection")
145+ }
146+
147+ if conn.Ctx == nil {
148+ t.Error("expected non-nil context even when nil is passed")
149+ }
150+
151+ if conn.Logger == nil {
152+ t.Error("expected non-nil logger even when nil is passed")
153+ }
154+}
155+
156+func TestSSHServerHandleConn(t *testing.T) {
157+ ctx, cancel := context.WithCancel(context.Background())
158+ defer cancel()
159+ logger := slog.Default()
160+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
161+
162+ // Setup a basic SSH server config
163+ config := &ssh.ServerConfig{
164+ NoClientAuth: true,
165+ }
166+
167+ server.Config.ServerConfig = config
168+
169+ // Create a mock connection
170+ client, server_conn := net.Pipe()
171+ defer client.Close()
172+
173+ // Start HandleConn in a goroutine
174+ errChan := make(chan error, 1)
175+ go func() {
176+ errChan <- server.HandleConn(server_conn)
177+ }()
178+
179+ // Configure SSH client
180+ clientConfig := &ssh.ClientConfig{
181+ User: "testuser",
182+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
183+ }
184+
185+ // Try to establish SSH connection
186+ _, _, _, err := ssh.NewClientConn(client, "", clientConfig)
187+
188+ // It should fail since we're using a pipe and not a proper SSH handshake
189+ if err == nil {
190+ t.Error("expected SSH handshake to fail with test pipe")
191+ }
192+
193+ // Close connections to ensure HandleConn returns
194+ client.Close()
195+ server_conn.Close()
196+
197+ // Wait for HandleConn to return
198+ select {
199+ case <-errChan:
200+ // Expected HandleConn to return
201+ case <-time.After(2 * time.Second):
202+ t.Error("HandleConn did not return after connection closed")
203+ }
204+}
205+
206+func TestSSHServerListenAndServe(t *testing.T) {
207+ ctx, cancel := context.WithCancel(context.Background())
208+ defer cancel()
209+ logger := slog.Default()
210+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
211+
212+ config := &ssh.ServerConfig{
213+ NoClientAuth: true,
214+ }
215+
216+ // Set a random port
217+ port := "127.0.0.1:0"
218+ server.Config.ListenAddr = port
219+ server.Config.ServerConfig = config
220+
221+ // Start server in a goroutine
222+ errChan := make(chan error, 1)
223+ go func() {
224+ err := server.ListenAndServe()
225+ errChan <- err
226+ }()
227+
228+ // Wait a bit for the server to start
229+ time.Sleep(100 * time.Millisecond)
230+
231+ // Trigger cancellation to stop the server
232+ cancel()
233+
234+ // Wait for server to stop
235+ select {
236+ case err := <-errChan:
237+ if err != nil && !errors.Is(err, net.ErrClosed) {
238+ t.Errorf("unexpected error: %v", err)
239+ }
240+ case <-time.After(2 * time.Second):
241+ t.Error("server did not shut down in time")
242+ }
243+}
244+
245+func TestSSHServerConnHandle(t *testing.T) {
246+ ctx, cancel := context.WithCancel(context.Background())
247+ defer cancel()
248+ logger := slog.Default()
249+ server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
250+ conn := &ssh.ServerConn{}
251+
252+ serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
253+
254+ // Create channels for testing
255+ chans := make(chan ssh.NewChannel)
256+ reqs := make(chan *ssh.Request)
257+
258+ // Start handle in a goroutine
259+ errChan := make(chan error, 1)
260+ go func() {
261+ errChan <- serverConn.Handle(chans, reqs)
262+ }()
263+
264+ // Ensure handle returns when context is canceled
265+ cancel()
266+
267+ // Wait for handle to return
268+ select {
269+ case err := <-errChan:
270+ if err != nil {
271+ t.Errorf("unexpected error: %v", err)
272+ }
273+ case <-time.After(2 * time.Second):
274+ t.Error("Handle did not return after context canceled")
275+ }
276+}