repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / pssh
Eric Bower  ·  2025-07-21

server.go

  1package pssh
  2
  3import (
  4	"context"
  5	"crypto/ed25519"
  6	"crypto/rand"
  7	"crypto/subtle"
  8	"encoding/base64"
  9	"encoding/pem"
 10	"errors"
 11	"fmt"
 12	"io"
 13	"log/slog"
 14	"net"
 15	"net/http"
 16	"os"
 17	"path"
 18	"strings"
 19	"sync"
 20	"time"
 21	"unicode/utf8"
 22
 23	"github.com/antoniomika/syncmap"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	"github.com/prometheus/client_golang/prometheus/promhttp"
 27	"golang.org/x/crypto/ssh"
 28)
 29
// SSHServerConn is one SSH connection served by an SSHServer. It is
// normally constructed by NewSSHServerConn and driven by Handle.
type SSHServerConn struct {
	// Ctx is the connection-scoped context. Sessions opened on this
	// connection derive their contexts from it.
	Ctx        context.Context
	// CancelFunc cancels Ctx; Close invokes it.
	CancelFunc context.CancelFunc
	// Logger is used for connection- and session-level log output.
	Logger     *slog.Logger
	// Conn is the underlying x/crypto/ssh server connection.
	Conn       *ssh.ServerConn
	// SSHServer is the server that accepted this connection; its Config
	// supplies the channel and session middleware.
	SSHServer  *SSHServer
	// Start is set by NewSSHServerConn to the time the wrapper was created.
	Start      time.Time

	// mu is held by Context() while reading Ctx.
	mu sync.Mutex
}
 40
 41func (s *SSHServerConn) Context() context.Context {
 42	s.mu.Lock()
 43	defer s.mu.Unlock()
 44
 45	return s.Ctx
 46}
 47
 48func (sc *SSHServerConn) Close() error {
 49	sc.CancelFunc()
 50	return nil
 51}
 52
// SSHServerConnSession is a single "session" channel on an SSHServerConn.
//
// It embeds the ssh.Channel (for stdin/stdout/stderr I/O and channel
// requests) and the parent *SSHServerConn (for Logger, Conn, SSHServer, ...).
// Its own Ctx, CancelFunc and mu shadow the fields of the same name on the
// embedded *SSHServerConn. The Deadline/Done/Err/Value methods make it usable
// as a context.Context backed by Ctx.
type SSHServerConnSession struct {
	ssh.Channel
	*SSHServerConn

	// Ctx is the session context; the default "session" channel middleware
	// derives it from the connection context. SetValue replaces it, so it is
	// read and written under mu.
	Ctx        context.Context
	// CancelFunc cancels the session context; Close invokes it.
	CancelFunc context.CancelFunc

	// pty holds the client's "pty-req" parameters, or nil if none was sent.
	pty   *Pty
	// winch carries terminal window sizes; it is created together with pty.
	winch chan Window

	// mu guards Ctx, pty and winch.
	mu sync.Mutex
}
 65
 66// Deadline implements context.Context.
 67func (s *SSHServerConnSession) Deadline() (deadline time.Time, ok bool) {
 68	s.mu.Lock()
 69	defer s.mu.Unlock()
 70
 71	return s.Ctx.Deadline()
 72}
 73
 74// Done implements context.Context.
 75func (s *SSHServerConnSession) Done() <-chan struct{} {
 76	s.mu.Lock()
 77	defer s.mu.Unlock()
 78
 79	return s.Ctx.Done()
 80}
 81
 82// Err implements context.Context.
 83func (s *SSHServerConnSession) Err() error {
 84	s.mu.Lock()
 85	defer s.mu.Unlock()
 86
 87	return s.Ctx.Err()
 88}
 89
 90// Value implements context.Context.
 91func (s *SSHServerConnSession) Value(key any) any {
 92	s.mu.Lock()
 93	defer s.mu.Unlock()
 94
 95	return s.Ctx.Value(key)
 96}
 97
 98// SetValue implements context.Context.
 99func (s *SSHServerConnSession) SetValue(key any, data any) {
100	s.mu.Lock()
101	defer s.mu.Unlock()
102
103	s.Ctx = context.WithValue(s.Ctx, key, data)
104}
105
106func (s *SSHServerConnSession) Context() context.Context {
107	s.mu.Lock()
108	defer s.mu.Unlock()
109
110	return s.Ctx
111}
112
113func (s *SSHServerConnSession) Permissions() *ssh.Permissions {
114	return s.Conn.Permissions
115}
116
117func (s *SSHServerConnSession) User() string {
118	return s.Conn.User()
119}
120
121func (s *SSHServerConnSession) PublicKey() ssh.PublicKey {
122	key, ok := s.Conn.Permissions.Extensions["pubkey"]
123	if !ok {
124		return nil
125	}
126
127	pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
128	if err != nil {
129		return nil
130	}
131	return pk
132}
133
134func (s *SSHServerConnSession) RemoteAddr() net.Addr {
135	return s.Conn.RemoteAddr()
136}
137
138func (s *SSHServerConnSession) Command() []string {
139	cmd, _ := s.Value("command").([]string)
140	return cmd
141}
142
143func (s *SSHServerConnSession) Close() error {
144	s.CancelFunc()
145	return s.Channel.Close()
146}
147
148func (s *SSHServerConnSession) Exit(code int) error {
149	status := struct{ Status uint32 }{uint32(code)}
150	_, err := s.SendRequest("exit-status", false, ssh.Marshal(&status))
151	return err
152}
153
154func (s *SSHServerConnSession) Fatal(err error) {
155	_, _ = fmt.Fprintln(s.Stderr(), err)
156	_, _ = fmt.Fprintf(s.Stderr(), "\r")
157	_ = s.Exit(1)
158	_ = s.Close()
159}
160
161func (s *SSHServerConnSession) Pty() (*Pty, <-chan Window, bool) {
162	s.mu.Lock()
163	defer s.mu.Unlock()
164
165	if s.pty == nil {
166		return nil, nil, false
167	}
168
169	return s.pty, s.winch, true
170}
171
172var _ context.Context = &SSHServerConnSession{}
173
174func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
175	defer func() {
176		_ = sc.Close()
177	}()
178
179	for {
180		select {
181		case <-sc.Context().Done():
182			return nil
183		case newChan, ok := <-chans:
184			if !ok {
185				return nil
186			}
187
188			sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
189			chanFunc, ok := sc.SSHServer.Config.ChannelMiddleware[newChan.ChannelType()]
190			if !ok {
191				sc.Logger.Info("no channel middleware for type", "type", newChan.ChannelType())
192				continue
193			}
194
195			go func() {
196				err := chanFunc(newChan, sc)
197				if err != nil {
198					sc.Logger.Error("channel middleware", "err", err)
199				}
200			}()
201		case req, ok := <-reqs:
202			if !ok {
203				return nil
204			}
205			sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
206			switch req.Type {
207			case "keepalive@openssh.com":
208				sc.Logger.Info("keepalive reply")
209				err := req.Reply(true, nil)
210				if err != nil {
211					sc.Logger.Error("keepalive", "err", err)
212				}
213			}
214		}
215	}
216}
217
218func NewSSHServerConn(
219	ctx context.Context,
220	logger *slog.Logger,
221	conn *ssh.ServerConn,
222	server *SSHServer,
223) *SSHServerConn {
224	if ctx == nil {
225		ctx = context.Background()
226	}
227
228	cancelCtx, cancelFunc := context.WithCancel(ctx)
229
230	if logger == nil {
231		logger = slog.Default()
232	}
233
234	return &SSHServerConn{
235		Ctx:        cancelCtx,
236		CancelFunc: cancelFunc,
237		Logger:     logger,
238		Conn:       conn,
239		SSHServer:  server,
240		Start:      time.Now(),
241	}
242}
243
244type SSHServerHandler func(*SSHServerConnSession) error
245type SSHServerMiddleware func(SSHServerHandler) SSHServerHandler
246type SSHServerChannelMiddleware func(ssh.NewChannel, *SSHServerConn) error
247
// SSHServerConfig configures an SSHServer.
type SSHServerConfig struct {
	// ServerConfig is passed to ssh.NewServerConn for each accepted
	// connection (authentication callbacks, host keys, ...).
	*ssh.ServerConfig
	// App is attached as the constant "app" label on the session metrics.
	App                 string
	// ListenAddr is the TCP address ListenAndServe listens on.
	ListenAddr          string
	// PromListenAddr, when non-empty, enables the session metrics and
	// serves them over HTTP at /metrics on this address.
	PromListenAddr      string
	// Middleware wraps the handler run for "shell" and "exec" requests by
	// the default "session" channel middleware. Entries are applied in slice
	// order, so the last one is the outermost wrapper.
	Middleware          []SSHServerMiddleware
	// SubsystemMiddleware is the equivalent chain for "subsystem" requests.
	SubsystemMiddleware []SSHServerMiddleware
	// ChannelMiddleware maps an SSH channel type (e.g. "session") to its
	// handler; channels of unmapped types are not served. NewSSHServer
	// installs a default "session" entry when none is provided.
	ChannelMiddleware   map[string]SSHServerChannelMiddleware
}
257
// SSHServer accepts TCP connections, performs the SSH handshake and
// dispatches channels to Config.ChannelMiddleware. Create one with
// NewSSHServer or NewSSHServerWithConfig.
type SSHServer struct {
	// Ctx bounds the server's lifetime; each connection's context is
	// derived from it.
	Ctx        context.Context
	// CancelFunc cancels Ctx; Close invokes it.
	CancelFunc context.CancelFunc
	Logger     *slog.Logger
	Config     *SSHServerConfig
	// Listener is set by ListenAndServe.
	Listener   net.Listener
	// Conns holds live connections keyed by remote address ("host:port").
	Conns      *syncmap.Map[string, *SSHServerConn]

	// Session metrics, labelled by command. ListenAndServe creates them only
	// when Config.PromListenAddr is non-empty.
	SessionsCreated  *prometheus.CounterVec
	SessionsFinished *prometheus.CounterVec
	// SessionsDuration accumulates session durations in seconds.
	SessionsDuration *prometheus.CounterVec
}
270
271func (s *SSHServer) ListenAndServe() error {
272	if s.Config.PromListenAddr != "" {
273		s.SessionsCreated = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
274			Name: "pssh_sessions_created_total",
275			Help: "The total number of sessions created",
276			ConstLabels: prometheus.Labels{
277				"app": s.Config.App,
278			},
279		}, []string{"command"})
280
281		s.SessionsFinished = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
282			Name: "pssh_sessions_finished_total",
283			Help: "The total number of sessions created",
284			ConstLabels: prometheus.Labels{
285				"app": s.Config.App,
286			},
287		}, []string{"command"})
288
289		s.SessionsDuration = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
290			Name: "pssh_sessions_duration_seconds",
291			Help: "The total sessions duration in seconds",
292			ConstLabels: prometheus.Labels{
293				"app": s.Config.App,
294			},
295		}, []string{"command"})
296
297		go func() {
298			mux := http.NewServeMux()
299			mux.Handle("/metrics", promhttp.Handler())
300
301			srv := &http.Server{Addr: s.Config.PromListenAddr, Handler: mux}
302
303			go func() {
304				<-s.Ctx.Done()
305				s.Logger.Info("Prometheus server shutting down")
306				_ = srv.Close()
307			}()
308
309			s.Logger.Info("Starting Prometheus server", "addr", s.Config.PromListenAddr)
310
311			err := srv.ListenAndServe()
312			if err != nil {
313				if errors.Is(err, http.ErrServerClosed) {
314					s.Logger.Info("Prometheus server shut down")
315					return
316				}
317
318				s.Logger.Error("Prometheus serve error", "err", err)
319				panic(err)
320			}
321		}()
322	}
323
324	listen, err := net.Listen("tcp", s.Config.ListenAddr)
325	if err != nil {
326		return err
327	}
328
329	s.Listener = listen
330	defer func() {
331		_ = s.Listener.Close()
332	}()
333
334	go func() {
335		<-s.Ctx.Done()
336		_ = s.Close()
337	}()
338
339	var retErr error
340
341	for {
342		conn, err := s.Listener.Accept()
343		if err != nil {
344			s.Logger.Error("accept", "err", err)
345			if errors.Is(err, net.ErrClosed) {
346				retErr = err
347				break
348			}
349			continue
350		}
351
352		go func() {
353			if err := s.HandleConn(conn); err != nil && !errors.Is(err, io.EOF) {
354				s.Logger.Error("Error handling connection", "err", err, "remoteAddr", conn.RemoteAddr().String())
355			}
356		}()
357	}
358
359	if errors.Is(retErr, net.ErrClosed) {
360		return nil
361	}
362
363	return retErr
364}
365
366func (s *SSHServer) HandleConn(conn net.Conn) error {
367	defer func() {
368		_ = conn.Close()
369	}()
370
371	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
372	if err != nil {
373		return err
374	}
375
376	newLogger := s.Logger.With(
377		"remoteAddr", conn.RemoteAddr().String(),
378		"sshUser", sshConn.User(),
379	)
380
381	if pubKey, ok := sshConn.Permissions.Extensions["pubkey"]; ok {
382		newLogger = newLogger.With("pubkey", pubKey)
383	}
384
385	newConn := NewSSHServerConn(
386		s.Ctx,
387		newLogger,
388		sshConn,
389		s,
390	)
391
392	s.Conns.Store(sshConn.RemoteAddr().String(), newConn)
393
394	err = newConn.Handle(chans, reqs)
395
396	s.Conns.Delete(sshConn.RemoteAddr().String())
397
398	return err
399}
400
401func (s *SSHServer) Close() error {
402	s.CancelFunc()
403	return s.Listener.Close()
404}
405
406func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
407	if ctx == nil {
408		ctx = context.Background()
409	}
410
411	cancelCtx, cancelFunc := context.WithCancel(ctx)
412
413	if logger == nil {
414		logger = slog.Default()
415	}
416
417	if config == nil {
418		config = &SSHServerConfig{}
419	}
420
421	if config.ChannelMiddleware == nil {
422		config.ChannelMiddleware = map[string]SSHServerChannelMiddleware{}
423	}
424
425	if _, ok := config.ChannelMiddleware["session"]; !ok {
426		config.ChannelMiddleware["session"] = func(newChan ssh.NewChannel, sc *SSHServerConn) error {
427			channel, requests, err := newChan.Accept()
428			if err != nil {
429				sc.Logger.Error("accept session channel", "err", err)
430				return err
431			}
432
433			ctx, cancelFunc := context.WithCancel(sc.Ctx)
434
435			sesh := &SSHServerConnSession{
436				Channel:       channel,
437				SSHServerConn: sc,
438				Ctx:           ctx,
439				CancelFunc:    cancelFunc,
440			}
441
442			for {
443				select {
444				case <-sesh.Done():
445					return nil
446				case req, ok := <-requests:
447					if !ok {
448						return nil
449					}
450
451					go func() {
452						sc.Logger.Info("new session request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
453						switch req.Type {
454						case "subsystem":
455							if len(sc.SSHServer.Config.SubsystemMiddleware) == 0 {
456								err := req.Reply(false, nil)
457								if err != nil {
458									sc.Logger.Error("subsystem reply", "err", err)
459								}
460
461								err = sc.Close()
462								if err != nil {
463									sc.Logger.Error("subsystem close", "err", err)
464								}
465
466								sesh.Fatal(err)
467								return
468							}
469
470							h := func(*SSHServerConnSession) error { return nil }
471							for _, m := range sc.SSHServer.Config.SubsystemMiddleware {
472								h = m(h)
473							}
474
475							err := req.Reply(true, nil)
476							if err != nil {
477								sc.Logger.Error("subsystem reply", "err", err)
478								sesh.Fatal(err)
479								return
480							}
481
482							if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
483								sc.Logger.Error("subsystem middleware", "err", err)
484								sesh.Fatal(err)
485								return
486							}
487
488							err = sesh.Exit(0)
489							if err != nil {
490								sc.Logger.Error("subsystem exit", "err", err)
491							}
492
493							err = sesh.Close()
494							if err != nil {
495								sc.Logger.Error("subsystem close", "err", err)
496							}
497						case "shell", "exec":
498							if len(sc.SSHServer.Config.Middleware) == 0 {
499								err := req.Reply(false, nil)
500								if err != nil {
501									sc.Logger.Error("shell/exec reply", "err", err)
502								}
503								sesh.Fatal(err)
504								return
505							}
506
507							command := "shell"
508
509							if len(req.Payload) > 0 {
510								var payload = struct{ Value string }{}
511								err := ssh.Unmarshal(req.Payload, &payload)
512								if err != nil {
513									sc.Logger.Error("shell/exec unmarshal", "err", err)
514									sesh.Fatal(err)
515									return
516								}
517
518								command = payload.Value
519
520								sesh.SetValue("command", strings.Fields(payload.Value))
521							}
522
523							if !utf8.ValidString(command) {
524								command = base64.StdEncoding.EncodeToString([]byte(command))
525							}
526
527							if sc.SSHServer.Config.PromListenAddr != "" {
528								sc.SSHServer.SessionsCreated.WithLabelValues(command).Inc()
529								defer func() {
530									sc.SSHServer.SessionsFinished.WithLabelValues(command).Inc()
531									sc.SSHServer.SessionsDuration.WithLabelValues(command).Add(time.Since(sc.Start).Seconds())
532								}()
533							}
534
535							h := func(*SSHServerConnSession) error { return nil }
536							for _, m := range sc.SSHServer.Config.Middleware {
537								h = m(h)
538							}
539
540							err = req.Reply(true, nil)
541							if err != nil {
542								sc.Logger.Error("shell/exec reply", "err", err)
543								sesh.Fatal(err)
544								return
545							}
546
547							if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
548								sc.Logger.Error("exec middleware", "err", err)
549								sesh.Fatal(err)
550								return
551							}
552
553							err = sesh.Exit(0)
554							if err != nil {
555								sc.Logger.Error("subsystem exit", "err", err)
556							}
557
558							err = sesh.Close()
559							if err != nil {
560								sc.Logger.Error("subsystem close", "err", err)
561							}
562						case "pty-req":
563							sesh.mu.Lock()
564							found := sesh.pty != nil
565							sesh.mu.Unlock()
566							if found {
567								err := req.Reply(false, nil)
568								if err != nil {
569									sc.Logger.Error("pty-req reply", "err", err)
570								}
571								return
572							}
573
574							ptyReq, ok := parsePtyRequest(req.Payload)
575							if !ok {
576								err := req.Reply(false, nil)
577								if err != nil {
578									sc.Logger.Error("pty-req reply", "err", err)
579								}
580								return
581							}
582
583							sesh.mu.Lock()
584							sesh.pty = &ptyReq
585							sesh.winch = make(chan Window, 1)
586							sesh.mu.Unlock()
587
588							sesh.winch <- ptyReq.Window
589							err := req.Reply(ok, nil)
590							if err != nil {
591								sc.Logger.Error("pty-req reply", "err", err)
592							}
593						case "window-change":
594							sesh.mu.Lock()
595							found := sesh.pty != nil
596							sesh.mu.Unlock()
597
598							if !found {
599								err := req.Reply(false, nil)
600								if err != nil {
601									sc.Logger.Error("pty-req reply", "err", err)
602								}
603								return
604							}
605
606							win, ok := parseWinchRequest(req.Payload)
607							if ok {
608								sesh.mu.Lock()
609								sesh.pty.Window = win
610								sesh.winch <- win
611								sesh.mu.Unlock()
612							}
613
614							err := req.Reply(ok, nil)
615							if err != nil {
616								sc.Logger.Error("window-change reply", "err", err)
617							}
618						}
619					}()
620				}
621			}
622		}
623	}
624
625	server := &SSHServer{
626		Ctx:        cancelCtx,
627		CancelFunc: cancelFunc,
628		Logger:     logger,
629		Config:     config,
630		Conns:      syncmap.New[string, *SSHServerConn](),
631	}
632
633	return server
634}
635
636type PubKeyAuthHandler func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error)
637
638func NewSSHServerWithConfig(
639	ctx context.Context,
640	logger *slog.Logger,
641	app, host, port, promPort, hostKey string,
642	pubKeyAuthHandler PubKeyAuthHandler,
643	middleware, subsystemMiddleware []SSHServerMiddleware,
644	channelMiddleware map[string]SSHServerChannelMiddleware,
645) (*SSHServer, error) {
646	server := NewSSHServer(ctx, logger, &SSHServerConfig{
647		App:        app,
648		ListenAddr: fmt.Sprintf("%s:%s", host, port),
649		ServerConfig: &ssh.ServerConfig{
650			PublicKeyCallback: pubKeyAuthHandler,
651		},
652		Middleware:          middleware,
653		SubsystemMiddleware: subsystemMiddleware,
654		ChannelMiddleware:   channelMiddleware,
655	})
656
657	if promPort != "" {
658		server.Config.PromListenAddr = fmt.Sprintf("%s:%s", host, promPort)
659	}
660
661	if hostKey != "" {
662		pemBytes, err := os.ReadFile(hostKey)
663		if err != nil {
664			logger.Error("failed to read private key file", "error", err)
665			if !os.IsNotExist(err) {
666				return nil, err
667			}
668
669			logger.Info("generating new private key")
670
671			pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
672			if err != nil {
673				logger.Error("failed to generate private key", "error", err)
674				return nil, err
675			}
676
677			privb, err := ssh.MarshalPrivateKey(privKey, "")
678			if err != nil {
679				logger.Error("failed to marshal private key", "error", err)
680				return nil, err
681			}
682
683			block := &pem.Block{
684				Type:  "OPENSSH PRIVATE KEY",
685				Bytes: privb.Bytes,
686			}
687
688			if err = os.MkdirAll(path.Dir(hostKey), 0700); err != nil {
689				logger.Error("failed to create ssh_data directory", "error", err)
690				return nil, err
691			}
692
693			pemBytes = pem.EncodeToMemory(block)
694
695			if err = os.WriteFile(hostKey, pemBytes, 0600); err != nil {
696				logger.Error("failed to write private key", "error", err)
697				return nil, err
698			}
699
700			sshPubKey, err := ssh.NewPublicKey(pubKey)
701			if err != nil {
702				logger.Error("failed to create public key", "error", err)
703				return nil, err
704			}
705
706			pubb := ssh.MarshalAuthorizedKey(sshPubKey)
707			if err = os.WriteFile(fmt.Sprintf("%s.pub", hostKey), pubb, 0600); err != nil {
708				logger.Error("failed to write public key", "error", err)
709				return nil, err
710			}
711		}
712
713		signer, err := ssh.ParsePrivateKey(pemBytes)
714		if err != nil {
715			logger.Error("failed to parse private key", "error", err)
716			return nil, err
717		}
718
719		server.Config.AddHostKey(signer)
720	}
721
722	return server, nil
723}
724
725func KeysEqual(a, b ssh.PublicKey) bool {
726	if a == nil || b == nil {
727		return false
728	}
729
730	am := a.Marshal()
731	bm := b.Marshal()
732	return (len(am) == len(bm) && subtle.ConstantTimeCompare(am, bm) == 1)
733}