repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

commit
b20dbf9
parent
59eee85
author
Antonio Mika
date
2025-03-03 15:16:49 -0500 EST
Started work on a crypto/ssh based ssh server
3 files changed,  +484, -0
A cmd/pgs/test/main.go
+29, -0
 1@@ -0,0 +1,29 @@
 2+package main
 3+
 4+import (
 5+	"context"
 6+	"log/slog"
 7+
 8+	"github.com/picosh/pico/shared"
 9+)
10+
11+func main() {
12+	// Initialize the logger
13+	logger := slog.Default()
14+
15+	ctx, cancel := context.WithCancel(context.Background())
16+	defer cancel()
17+
18+	// Create a new SSH server
19+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{
20+		ListenAddr: "localhost:2222",
21+	})
22+
23+	err := server.ListenAndServe()
24+	if err != nil {
25+		logger.Error("failed to start SSH server", "error", err)
26+		return
27+	}
28+
29+	logger.Info("SSH server started successfully")
30+}
A shared/sshServer.go
+180, -0
  1@@ -0,0 +1,180 @@
  2+package shared
  3+
  4+import (
  5+	"context"
  6+	"errors"
  7+	"log/slog"
  8+	"net"
  9+
 10+	"github.com/antoniomika/syncmap"
 11+	"golang.org/x/crypto/ssh"
 12+)
 13+
 14+type SSHServerConn struct {
 15+	Ctx        context.Context
 16+	CancelFunc context.CancelFunc
 17+	Logger     *slog.Logger
 18+	Conn       *ssh.ServerConn
 19+	SSHServer  *SSHServer
 20+}
 21+
 22+func (sc *SSHServerConn) Close() error {
 23+	sc.CancelFunc()
 24+	return nil
 25+}
 26+
 27+func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
 28+	defer sc.Close()
 29+
 30+	for {
 31+		select {
 32+		case <-sc.Ctx.Done():
 33+			return nil
 34+		case newChan := <-chans:
 35+			sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
 36+		case req := <-reqs:
 37+			sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
 38+		}
 39+	}
 40+}
 41+
 42+func NewSSHServerConn(
 43+	ctx context.Context,
 44+	logger *slog.Logger,
 45+	conn *ssh.ServerConn,
 46+	server *SSHServer,
 47+) *SSHServerConn {
 48+	if ctx == nil {
 49+		ctx = context.Background()
 50+	}
 51+
 52+	cancelCtx, cancelFunc := context.WithCancel(ctx)
 53+
 54+	if logger == nil {
 55+		logger = slog.Default()
 56+	}
 57+
 58+	return &SSHServerConn{
 59+		Ctx:        cancelCtx,
 60+		CancelFunc: cancelFunc,
 61+		Logger:     logger,
 62+		Conn:       conn,
 63+		SSHServer:  server,
 64+	}
 65+}
 66+
 67+type SSHServerMiddleware func(func(ssh.Session) error) func(ssh.Session) error
 68+
 69+type SSHServerConfig struct {
 70+	*ssh.ServerConfig
 71+	ListenAddr          string
 72+	SessionMiddleware   []SSHServerMiddleware
 73+	SubsystemMiddleware []SSHServerMiddleware
 74+}
 75+
 76+type SSHServer struct {
 77+	Ctx        context.Context
 78+	CancelFunc context.CancelFunc
 79+	Logger     *slog.Logger
 80+	Config     *SSHServerConfig
 81+	Listener   net.Listener
 82+	Conns      *syncmap.Map[string, *SSHServerConn]
 83+}
 84+
 85+func (s *SSHServer) ListenAndServe() error {
 86+	listen, err := net.Listen("tcp", s.Config.ListenAddr)
 87+	if err != nil {
 88+		return err
 89+	}
 90+
 91+	s.Listener = listen
 92+	defer s.Listener.Close()
 93+
 94+	go func() {
 95+		<-s.Ctx.Done()
 96+		s.Close()
 97+	}()
 98+
 99+	var retErr error
100+
101+	for {
102+		conn, err := s.Listener.Accept()
103+		if err != nil {
104+			s.Logger.Error("accept", "err", err)
105+			if errors.Is(err, net.ErrClosed) {
106+				retErr = err
107+				break
108+			}
109+			continue
110+		}
111+
112+		go func() {
113+			if err := s.HandleConn(conn); err != nil {
114+				s.Logger.Error("handle conn", "err", err)
115+			}
116+		}()
117+	}
118+
119+	return retErr
120+}
121+
122+func (s *SSHServer) HandleConn(conn net.Conn) error {
123+	defer conn.Close()
124+
125+	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
126+	if err != nil {
127+		return err
128+	}
129+
130+	newLogger := s.Logger.With(
131+		"remoteAddr", conn.RemoteAddr().String(),
132+		"user", sshConn.User(),
133+		"pubkey", sshConn.Permissions.Extensions["pubkey"],
134+	)
135+
136+	newConn := NewSSHServerConn(
137+		s.Ctx,
138+		newLogger,
139+		sshConn,
140+		s,
141+	)
142+
143+	s.Conns.Store(sshConn.RemoteAddr().String(), newConn)
144+
145+	err = newConn.Handle(chans, reqs)
146+
147+	s.Conns.Delete(sshConn.RemoteAddr().String())
148+
149+	return err
150+}
151+
152+func (s *SSHServer) Close() error {
153+	s.CancelFunc()
154+	return s.Listener.Close()
155+}
156+
157+func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
158+	if ctx == nil {
159+		ctx = context.Background()
160+	}
161+
162+	cancelCtx, cancelFunc := context.WithCancel(ctx)
163+
164+	if logger == nil {
165+		logger = slog.Default()
166+	}
167+
168+	if config == nil {
169+		config = &SSHServerConfig{}
170+	}
171+
172+	server := &SSHServer{
173+		Ctx:        cancelCtx,
174+		CancelFunc: cancelFunc,
175+		Logger:     logger,
176+		Config:     config,
177+		Conns:      syncmap.New[string, *SSHServerConn](),
178+	}
179+
180+	return server
181+}
A shared/sshServer_test.go
+275, -0
  1@@ -0,0 +1,275 @@
  2+package shared_test
  3+
  4+import (
  5+	"context"
  6+	"errors"
  7+	"log/slog"
  8+	"net"
  9+	"testing"
 10+	"time"
 11+
 12+	"github.com/picosh/pico/shared"
 13+	"golang.org/x/crypto/ssh"
 14+)
 15+
 16+func TestNewSSHServer(t *testing.T) {
 17+	ctx := context.Background()
 18+	logger := slog.Default()
 19+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
 20+
 21+	if server == nil {
 22+		t.Fatal("expected non-nil server")
 23+	}
 24+
 25+	if server.Ctx == nil {
 26+		t.Error("expected non-nil context")
 27+	}
 28+
 29+	if server.Logger == nil {
 30+		t.Error("expected non-nil logger")
 31+	}
 32+
 33+	if server.Config == nil {
 34+		t.Error("expected non-nil config")
 35+	}
 36+
 37+	if server.Conns == nil {
 38+		t.Error("expected non-nil connections map")
 39+	}
 40+}
 41+
 42+func TestNewSSHServerConn(t *testing.T) {
 43+	ctx := context.Background()
 44+	logger := slog.Default()
 45+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
 46+	conn := &ssh.ServerConn{}
 47+
 48+	serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
 49+
 50+	if serverConn == nil {
 51+		t.Fatal("expected non-nil server connection")
 52+	}
 53+
 54+	if serverConn.Ctx == nil {
 55+		t.Error("expected non-nil context")
 56+	}
 57+
 58+	if serverConn.Logger == nil {
 59+		t.Error("expected non-nil logger")
 60+	}
 61+
 62+	if serverConn.Conn != conn {
 63+		t.Error("expected conn to match")
 64+	}
 65+
 66+	if serverConn.SSHServer != server {
 67+		t.Error("expected server to match")
 68+	}
 69+}
 70+
 71+func TestSSHServerConnClose(t *testing.T) {
 72+	ctx := context.Background()
 73+	logger := slog.Default()
 74+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
 75+	conn := &ssh.ServerConn{}
 76+
 77+	serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
 78+	err := serverConn.Close()
 79+
 80+	if err != nil {
 81+		t.Errorf("unexpected error: %v", err)
 82+	}
 83+
 84+	// Should be canceled after close
 85+	select {
 86+	case <-serverConn.Ctx.Done():
 87+		// Context was canceled as expected
 88+	default:
 89+		t.Error("context was not canceled after Close()")
 90+	}
 91+}
 92+
 93+func TestSSHServerClose(t *testing.T) {
 94+	ctx := context.Background()
 95+	logger := slog.Default()
 96+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
 97+
 98+	// Create a mock listener to test Close()
 99+	listener, err := net.Listen("tcp", "127.0.0.1:0")
100+	if err != nil {
101+		t.Fatalf("failed to create listener: %v", err)
102+	}
103+
104+	server.Listener = listener
105+	err = server.Close()
106+
107+	if err != nil {
108+		t.Errorf("unexpected error: %v", err)
109+	}
110+
111+	// Should be canceled after close
112+	select {
113+	case <-server.Ctx.Done():
114+		// Context was canceled as expected
115+	default:
116+		t.Error("context was not canceled after Close()")
117+	}
118+}
119+
120+func TestSSHServerNilParams(t *testing.T) {
121+	// Test with nil context and logger
122+	//nolint:staticcheck // SA1012 ignores nil check
123+	//lint:ignore SA1012 ignores nil check
124+	server := shared.NewSSHServer(nil, nil, nil)
125+
126+	if server == nil {
127+		t.Fatal("expected non-nil server")
128+	}
129+
130+	if server.Ctx == nil {
131+		t.Error("expected non-nil context even when nil is passed")
132+	}
133+
134+	if server.Logger == nil {
135+		t.Error("expected non-nil logger even when nil is passed")
136+	}
137+
138+	// Test with nil context and logger for connection
139+	//nolint:staticcheck // SA1012 ignores nil check
140+	//lint:ignore SA1012 ignores nil check
141+	conn := shared.NewSSHServerConn(nil, nil, &ssh.ServerConn{}, server)
142+
143+	if conn == nil {
144+		t.Fatal("expected non-nil server connection")
145+	}
146+
147+	if conn.Ctx == nil {
148+		t.Error("expected non-nil context even when nil is passed")
149+	}
150+
151+	if conn.Logger == nil {
152+		t.Error("expected non-nil logger even when nil is passed")
153+	}
154+}
155+
156+func TestSSHServerHandleConn(t *testing.T) {
157+	ctx, cancel := context.WithCancel(context.Background())
158+	defer cancel()
159+	logger := slog.Default()
160+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
161+
162+	// Setup a basic SSH server config
163+	config := &ssh.ServerConfig{
164+		NoClientAuth: true,
165+	}
166+
167+	server.Config.ServerConfig = config
168+
169+	// Create a mock connection
170+	client, server_conn := net.Pipe()
171+	defer client.Close()
172+
173+	// Start HandleConn in a goroutine
174+	errChan := make(chan error, 1)
175+	go func() {
176+		errChan <- server.HandleConn(server_conn)
177+	}()
178+
179+	// Configure SSH client
180+	clientConfig := &ssh.ClientConfig{
181+		User:            "testuser",
182+		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
183+	}
184+
185+	// Try to establish SSH connection
186+	_, _, _, err := ssh.NewClientConn(client, "", clientConfig)
187+
188+	// It should fail since we're using a pipe and not a proper SSH handshake
189+	if err == nil {
190+		t.Error("expected SSH handshake to fail with test pipe")
191+	}
192+
193+	// Close connections to ensure HandleConn returns
194+	client.Close()
195+	server_conn.Close()
196+
197+	// Wait for HandleConn to return
198+	select {
199+	case <-errChan:
200+		// Expected HandleConn to return
201+	case <-time.After(2 * time.Second):
202+		t.Error("HandleConn did not return after connection closed")
203+	}
204+}
205+
206+func TestSSHServerListenAndServe(t *testing.T) {
207+	ctx, cancel := context.WithCancel(context.Background())
208+	defer cancel()
209+	logger := slog.Default()
210+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
211+
212+	config := &ssh.ServerConfig{
213+		NoClientAuth: true,
214+	}
215+
216+	// Set a random port
217+	port := "127.0.0.1:0"
218+	server.Config.ListenAddr = port
219+	server.Config.ServerConfig = config
220+
221+	// Start server in a goroutine
222+	errChan := make(chan error, 1)
223+	go func() {
224+		err := server.ListenAndServe()
225+		errChan <- err
226+	}()
227+
228+	// Wait a bit for the server to start
229+	time.Sleep(100 * time.Millisecond)
230+
231+	// Trigger cancellation to stop the server
232+	cancel()
233+
234+	// Wait for server to stop
235+	select {
236+	case err := <-errChan:
237+		if err != nil && !errors.Is(err, net.ErrClosed) {
238+			t.Errorf("unexpected error: %v", err)
239+		}
240+	case <-time.After(2 * time.Second):
241+		t.Error("server did not shut down in time")
242+	}
243+}
244+
245+func TestSSHServerConnHandle(t *testing.T) {
246+	ctx, cancel := context.WithCancel(context.Background())
247+	defer cancel()
248+	logger := slog.Default()
249+	server := shared.NewSSHServer(ctx, logger, &shared.SSHServerConfig{})
250+	conn := &ssh.ServerConn{}
251+
252+	serverConn := shared.NewSSHServerConn(ctx, logger, conn, server)
253+
254+	// Create channels for testing
255+	chans := make(chan ssh.NewChannel)
256+	reqs := make(chan *ssh.Request)
257+
258+	// Start handle in a goroutine
259+	errChan := make(chan error, 1)
260+	go func() {
261+		errChan <- serverConn.Handle(chans, reqs)
262+	}()
263+
264+	// Ensure handle returns when context is canceled
265+	cancel()
266+
267+	// Wait for handle to return
268+	select {
269+	case err := <-errChan:
270+		if err != nil {
271+			t.Errorf("unexpected error: %v", err)
272+		}
273+	case <-time.After(2 * time.Second):
274+		t.Error("Handle did not return after context canceled")
275+	}
276+}