repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / pssh
Antonio Mika · 2025-04-10

server.go

  1package pssh
  2
  3import (
  4	"context"
  5	"crypto/ed25519"
  6	"crypto/rand"
  7	"crypto/subtle"
  8	"encoding/base64"
  9	"encoding/pem"
 10	"errors"
 11	"fmt"
 12	"io"
 13	"log/slog"
 14	"net"
 15	"net/http"
 16	"os"
 17	"path"
 18	"strings"
 19	"sync"
 20	"time"
 21	"unicode/utf8"
 22
 23	"github.com/antoniomika/syncmap"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	"github.com/prometheus/client_golang/prometheus/promhttp"
 27	"golang.org/x/crypto/ssh"
 28)
 29
 30type SSHServerConn struct {
 31	Ctx        context.Context
 32	CancelFunc context.CancelFunc
 33	Logger     *slog.Logger
 34	Conn       *ssh.ServerConn
 35	SSHServer  *SSHServer
 36	Start      time.Time
 37
 38	mu sync.Mutex
 39}
 40
 41func (s *SSHServerConn) Context() context.Context {
 42	s.mu.Lock()
 43	defer s.mu.Unlock()
 44
 45	return s.Ctx
 46}
 47
 48func (sc *SSHServerConn) Close() error {
 49	sc.CancelFunc()
 50	return nil
 51}
 52
 53type SSHServerConnSession struct {
 54	ssh.Channel
 55	*SSHServerConn
 56
 57	Ctx        context.Context
 58	CancelFunc context.CancelFunc
 59
 60	pty   *Pty
 61	winch chan Window
 62
 63	mu sync.Mutex
 64}
 65
 66// Deadline implements context.Context.
 67func (s *SSHServerConnSession) Deadline() (deadline time.Time, ok bool) {
 68	s.mu.Lock()
 69	defer s.mu.Unlock()
 70
 71	return s.Ctx.Deadline()
 72}
 73
 74// Done implements context.Context.
 75func (s *SSHServerConnSession) Done() <-chan struct{} {
 76	s.mu.Lock()
 77	defer s.mu.Unlock()
 78
 79	return s.Ctx.Done()
 80}
 81
 82// Err implements context.Context.
 83func (s *SSHServerConnSession) Err() error {
 84	s.mu.Lock()
 85	defer s.mu.Unlock()
 86
 87	return s.Ctx.Err()
 88}
 89
 90// Value implements context.Context.
 91func (s *SSHServerConnSession) Value(key any) any {
 92	s.mu.Lock()
 93	defer s.mu.Unlock()
 94
 95	return s.Ctx.Value(key)
 96}
 97
 98// SetValue implements context.Context.
 99func (s *SSHServerConnSession) SetValue(key any, data any) {
100	s.mu.Lock()
101	defer s.mu.Unlock()
102
103	s.Ctx = context.WithValue(s.Ctx, key, data)
104}
105
106func (s *SSHServerConnSession) Context() context.Context {
107	s.mu.Lock()
108	defer s.mu.Unlock()
109
110	return s.Ctx
111}
112
113func (s *SSHServerConnSession) Permissions() *ssh.Permissions {
114	return s.Conn.Permissions
115}
116
117func (s *SSHServerConnSession) User() string {
118	return s.Conn.User()
119}
120
121func (s *SSHServerConnSession) PublicKey() ssh.PublicKey {
122	key, ok := s.Conn.Permissions.Extensions["pubkey"]
123	if !ok {
124		return nil
125	}
126
127	pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
128	if err != nil {
129		return nil
130	}
131	return pk
132}
133
134func (s *SSHServerConnSession) RemoteAddr() net.Addr {
135	return s.Conn.RemoteAddr()
136}
137
138func (s *SSHServerConnSession) Command() []string {
139	cmd, _ := s.Value("command").([]string)
140	return cmd
141}
142
143func (s *SSHServerConnSession) Close() error {
144	s.CancelFunc()
145	return s.Channel.Close()
146}
147
148func (s *SSHServerConnSession) Exit(code int) error {
149	status := struct{ Status uint32 }{uint32(code)}
150	_, err := s.Channel.SendRequest("exit-status", false, ssh.Marshal(&status))
151	return err
152}
153
154func (s *SSHServerConnSession) Fatal(err error) {
155	fmt.Fprintln(s.Stderr(), err)
156	fmt.Fprintf(s.Stderr(), "\r")
157	_ = s.Exit(1)
158	_ = s.Close()
159}
160
161func (s *SSHServerConnSession) Pty() (*Pty, <-chan Window, bool) {
162	s.mu.Lock()
163	defer s.mu.Unlock()
164
165	if s.pty == nil {
166		return nil, nil, false
167	}
168
169	return s.pty, s.winch, true
170}
171
172var _ context.Context = &SSHServerConnSession{}
173
174func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
175	defer sc.Close()
176
177	for {
178		select {
179		case <-sc.Context().Done():
180			return nil
181		case newChan, ok := <-chans:
182			if !ok {
183				return nil
184			}
185
186			sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
187			chanFunc, ok := sc.SSHServer.Config.ChannelMiddleware[newChan.ChannelType()]
188			if !ok {
189				sc.Logger.Info("no channel middleware for type", "type", newChan.ChannelType())
190				continue
191			}
192
193			go func() {
194				err := chanFunc(newChan, sc)
195				if err != nil {
196					sc.Logger.Error("channel middleware", "err", err)
197				}
198			}()
199		case req, ok := <-reqs:
200			if !ok {
201				return nil
202			}
203			sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
204		}
205	}
206}
207
208func NewSSHServerConn(
209	ctx context.Context,
210	logger *slog.Logger,
211	conn *ssh.ServerConn,
212	server *SSHServer,
213) *SSHServerConn {
214	if ctx == nil {
215		ctx = context.Background()
216	}
217
218	cancelCtx, cancelFunc := context.WithCancel(ctx)
219
220	if logger == nil {
221		logger = slog.Default()
222	}
223
224	return &SSHServerConn{
225		Ctx:        cancelCtx,
226		CancelFunc: cancelFunc,
227		Logger:     logger,
228		Conn:       conn,
229		SSHServer:  server,
230		Start:      time.Now(),
231	}
232}
233
234type SSHServerHandler func(*SSHServerConnSession) error
235type SSHServerMiddleware func(SSHServerHandler) SSHServerHandler
236type SSHServerChannelMiddleware func(ssh.NewChannel, *SSHServerConn) error
237
238type SSHServerConfig struct {
239	*ssh.ServerConfig
240	App                 string
241	ListenAddr          string
242	PromListenAddr      string
243	Middleware          []SSHServerMiddleware
244	SubsystemMiddleware []SSHServerMiddleware
245	ChannelMiddleware   map[string]SSHServerChannelMiddleware
246}
247
248type SSHServer struct {
249	Ctx        context.Context
250	CancelFunc context.CancelFunc
251	Logger     *slog.Logger
252	Config     *SSHServerConfig
253	Listener   net.Listener
254	Conns      *syncmap.Map[string, *SSHServerConn]
255
256	SessionsCreated  *prometheus.CounterVec
257	SessionsFinished *prometheus.CounterVec
258	SessionsDuration *prometheus.CounterVec
259}
260
261func (s *SSHServer) ListenAndServe() error {
262	if s.Config.PromListenAddr != "" {
263		s.SessionsCreated = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
264			Name: "pssh_sessions_created_total",
265			Help: "The total number of sessions created",
266			ConstLabels: prometheus.Labels{
267				"app": s.Config.App,
268			},
269		}, []string{"command"})
270
271		s.SessionsFinished = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
272			Name: "pssh_sessions_finished_total",
273			Help: "The total number of sessions created",
274			ConstLabels: prometheus.Labels{
275				"app": s.Config.App,
276			},
277		}, []string{"command"})
278
279		s.SessionsDuration = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
280			Name: "pssh_sessions_duration_seconds",
281			Help: "The total sessions duration in seconds",
282			ConstLabels: prometheus.Labels{
283				"app": s.Config.App,
284			},
285		}, []string{"command"})
286
287		go func() {
288			mux := http.NewServeMux()
289			mux.Handle("/metrics", promhttp.Handler())
290
291			srv := &http.Server{Addr: s.Config.PromListenAddr, Handler: mux}
292
293			go func() {
294				<-s.Ctx.Done()
295				s.Logger.Info("Prometheus server shutting down")
296				srv.Close()
297			}()
298
299			s.Logger.Info("Starting Prometheus server", "addr", s.Config.PromListenAddr)
300
301			err := srv.ListenAndServe()
302			if err != nil {
303				if errors.Is(err, http.ErrServerClosed) {
304					s.Logger.Info("Prometheus server shut down")
305					return
306				}
307
308				s.Logger.Error("Prometheus serve error", "err", err)
309				panic(err)
310			}
311		}()
312	}
313
314	listen, err := net.Listen("tcp", s.Config.ListenAddr)
315	if err != nil {
316		return err
317	}
318
319	s.Listener = listen
320	defer s.Listener.Close()
321
322	go func() {
323		<-s.Ctx.Done()
324		s.Close()
325	}()
326
327	var retErr error
328
329	for {
330		conn, err := s.Listener.Accept()
331		if err != nil {
332			s.Logger.Error("accept", "err", err)
333			if errors.Is(err, net.ErrClosed) {
334				retErr = err
335				break
336			}
337			continue
338		}
339
340		go func() {
341			if err := s.HandleConn(conn); err != nil && !errors.Is(err, io.EOF) {
342				s.Logger.Error("Error handling connection", "err", err, "remoteAddr", conn.RemoteAddr().String())
343			}
344		}()
345	}
346
347	if errors.Is(retErr, net.ErrClosed) {
348		return nil
349	}
350
351	return retErr
352}
353
354func (s *SSHServer) HandleConn(conn net.Conn) error {
355	defer conn.Close()
356
357	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
358	if err != nil {
359		return err
360	}
361
362	newLogger := s.Logger.With(
363		"remoteAddr", conn.RemoteAddr().String(),
364		"user", sshConn.User(),
365		"pubkey", sshConn.Permissions.Extensions["pubkey"],
366	)
367
368	newConn := NewSSHServerConn(
369		s.Ctx,
370		newLogger,
371		sshConn,
372		s,
373	)
374
375	s.Conns.Store(sshConn.RemoteAddr().String(), newConn)
376
377	err = newConn.Handle(chans, reqs)
378
379	s.Conns.Delete(sshConn.RemoteAddr().String())
380
381	return err
382}
383
384func (s *SSHServer) Close() error {
385	s.CancelFunc()
386	return s.Listener.Close()
387}
388
389func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
390	if ctx == nil {
391		ctx = context.Background()
392	}
393
394	cancelCtx, cancelFunc := context.WithCancel(ctx)
395
396	if logger == nil {
397		logger = slog.Default()
398	}
399
400	if config == nil {
401		config = &SSHServerConfig{}
402	}
403
404	if config.ChannelMiddleware == nil {
405		config.ChannelMiddleware = map[string]SSHServerChannelMiddleware{}
406	}
407
408	if _, ok := config.ChannelMiddleware["session"]; !ok {
409		config.ChannelMiddleware["session"] = func(newChan ssh.NewChannel, sc *SSHServerConn) error {
410			channel, requests, err := newChan.Accept()
411			if err != nil {
412				sc.Logger.Error("accept session channel", "err", err)
413				return err
414			}
415
416			ctx, cancelFunc := context.WithCancel(sc.Ctx)
417
418			sesh := &SSHServerConnSession{
419				Channel:       channel,
420				SSHServerConn: sc,
421				Ctx:           ctx,
422				CancelFunc:    cancelFunc,
423			}
424
425			for {
426				select {
427				case <-sesh.Done():
428					return nil
429				case req, ok := <-requests:
430					if !ok {
431						return nil
432					}
433
434					go func() {
435						sc.Logger.Info("new session request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
436						switch req.Type {
437						case "subsystem":
438							if len(sc.SSHServer.Config.SubsystemMiddleware) == 0 {
439								err := req.Reply(false, nil)
440								if err != nil {
441									sc.Logger.Error("subsystem reply", "err", err)
442								}
443
444								err = sc.Close()
445								if err != nil {
446									sc.Logger.Error("subsystem close", "err", err)
447								}
448
449								sesh.Fatal(err)
450								return
451							}
452
453							h := func(*SSHServerConnSession) error { return nil }
454							for _, m := range sc.SSHServer.Config.SubsystemMiddleware {
455								h = m(h)
456							}
457
458							err := req.Reply(true, nil)
459							if err != nil {
460								sc.Logger.Error("subsystem reply", "err", err)
461								sesh.Fatal(err)
462								return
463							}
464
465							if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
466								sc.Logger.Error("subsystem middleware", "err", err)
467								sesh.Fatal(err)
468								return
469							}
470
471							err = sesh.Exit(0)
472							if err != nil {
473								sc.Logger.Error("subsystem exit", "err", err)
474							}
475
476							err = sesh.Close()
477							if err != nil {
478								sc.Logger.Error("subsystem close", "err", err)
479							}
480						case "shell", "exec":
481							if len(sc.SSHServer.Config.Middleware) == 0 {
482								err := req.Reply(false, nil)
483								if err != nil {
484									sc.Logger.Error("shell/exec reply", "err", err)
485								}
486								sesh.Fatal(err)
487								return
488							}
489
490							command := "shell"
491
492							if len(req.Payload) > 0 {
493								var payload = struct{ Value string }{}
494								err := ssh.Unmarshal(req.Payload, &payload)
495								if err != nil {
496									sc.Logger.Error("shell/exec unmarshal", "err", err)
497									sesh.Fatal(err)
498									return
499								}
500
501								command = payload.Value
502
503								sesh.SetValue("command", strings.Fields(payload.Value))
504							}
505
506							if !utf8.ValidString(command) {
507								command = base64.StdEncoding.EncodeToString([]byte(command))
508							}
509
510							if sc.SSHServer.Config.PromListenAddr != "" {
511								sc.SSHServer.SessionsCreated.WithLabelValues(command).Inc()
512								defer func() {
513									sc.SSHServer.SessionsFinished.WithLabelValues(command).Inc()
514									sc.SSHServer.SessionsDuration.WithLabelValues(command).Add(time.Since(sc.Start).Seconds())
515								}()
516							}
517
518							h := func(*SSHServerConnSession) error { return nil }
519							for _, m := range sc.SSHServer.Config.Middleware {
520								h = m(h)
521							}
522
523							err = req.Reply(true, nil)
524							if err != nil {
525								sc.Logger.Error("shell/exec reply", "err", err)
526								sesh.Fatal(err)
527								return
528							}
529
530							if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
531								sc.Logger.Error("exec middleware", "err", err)
532								sesh.Fatal(err)
533								return
534							}
535
536							err = sesh.Exit(0)
537							if err != nil {
538								sc.Logger.Error("subsystem exit", "err", err)
539							}
540
541							err = sesh.Close()
542							if err != nil {
543								sc.Logger.Error("subsystem close", "err", err)
544							}
545						case "pty-req":
546							sesh.mu.Lock()
547							found := sesh.pty != nil
548							sesh.mu.Unlock()
549							if found {
550								err := req.Reply(false, nil)
551								if err != nil {
552									sc.Logger.Error("pty-req reply", "err", err)
553								}
554								return
555							}
556
557							ptyReq, ok := parsePtyRequest(req.Payload)
558							if !ok {
559								err := req.Reply(false, nil)
560								if err != nil {
561									sc.Logger.Error("pty-req reply", "err", err)
562								}
563								return
564							}
565
566							sesh.mu.Lock()
567							sesh.pty = &ptyReq
568							sesh.winch = make(chan Window, 1)
569							sesh.mu.Unlock()
570
571							sesh.winch <- ptyReq.Window
572							err := req.Reply(ok, nil)
573							if err != nil {
574								sc.Logger.Error("pty-req reply", "err", err)
575							}
576						case "window-change":
577							sesh.mu.Lock()
578							found := sesh.pty != nil
579							sesh.mu.Unlock()
580
581							if !found {
582								err := req.Reply(false, nil)
583								if err != nil {
584									sc.Logger.Error("pty-req reply", "err", err)
585								}
586								return
587							}
588
589							win, ok := parseWinchRequest(req.Payload)
590							if ok {
591								sesh.mu.Lock()
592								sesh.pty.Window = win
593								sesh.winch <- win
594								sesh.mu.Unlock()
595							}
596
597							err := req.Reply(ok, nil)
598							if err != nil {
599								sc.Logger.Error("window-change reply", "err", err)
600							}
601						}
602					}()
603				}
604			}
605		}
606	}
607
608	server := &SSHServer{
609		Ctx:        cancelCtx,
610		CancelFunc: cancelFunc,
611		Logger:     logger,
612		Config:     config,
613		Conns:      syncmap.New[string, *SSHServerConn](),
614	}
615
616	return server
617}
618
619type PubKeyAuthHandler func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error)
620
621func NewSSHServerWithConfig(
622	ctx context.Context,
623	logger *slog.Logger,
624	app, host, port, promPort, hostKey string,
625	pubKeyAuthHandler PubKeyAuthHandler,
626	middleware, subsystemMiddleware []SSHServerMiddleware,
627	channelMiddleware map[string]SSHServerChannelMiddleware,
628) (*SSHServer, error) {
629	server := NewSSHServer(ctx, logger, &SSHServerConfig{
630		App:        app,
631		ListenAddr: fmt.Sprintf("%s:%s", host, port),
632		ServerConfig: &ssh.ServerConfig{
633			PublicKeyCallback: pubKeyAuthHandler,
634		},
635		Middleware:          middleware,
636		SubsystemMiddleware: subsystemMiddleware,
637		ChannelMiddleware:   channelMiddleware,
638	})
639
640	if promPort != "" {
641		server.Config.PromListenAddr = fmt.Sprintf("%s:%s", host, promPort)
642	}
643
644	if hostKey != "" {
645		pemBytes, err := os.ReadFile(hostKey)
646		if err != nil {
647			logger.Error("failed to read private key file", "error", err)
648			if !os.IsNotExist(err) {
649				return nil, err
650			}
651
652			logger.Info("generating new private key")
653
654			pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
655			if err != nil {
656				logger.Error("failed to generate private key", "error", err)
657				return nil, err
658			}
659
660			privb, err := ssh.MarshalPrivateKey(privKey, "")
661			if err != nil {
662				logger.Error("failed to marshal private key", "error", err)
663				return nil, err
664			}
665
666			block := &pem.Block{
667				Type:  "OPENSSH PRIVATE KEY",
668				Bytes: privb.Bytes,
669			}
670
671			if err = os.MkdirAll(path.Dir(hostKey), 0700); err != nil {
672				logger.Error("failed to create ssh_data directory", "error", err)
673				return nil, err
674			}
675
676			pemBytes = pem.EncodeToMemory(block)
677
678			if err = os.WriteFile(hostKey, pemBytes, 0600); err != nil {
679				logger.Error("failed to write private key", "error", err)
680				return nil, err
681			}
682
683			sshPubKey, err := ssh.NewPublicKey(pubKey)
684			if err != nil {
685				logger.Error("failed to create public key", "error", err)
686				return nil, err
687			}
688
689			pubb := ssh.MarshalAuthorizedKey(sshPubKey)
690			if err = os.WriteFile(fmt.Sprintf("%s.pub", hostKey), pubb, 0600); err != nil {
691				logger.Error("failed to write public key", "error", err)
692				return nil, err
693			}
694		}
695
696		signer, err := ssh.ParsePrivateKey(pemBytes)
697		if err != nil {
698			logger.Error("failed to parse private key", "error", err)
699			return nil, err
700		}
701
702		server.Config.AddHostKey(signer)
703	}
704
705	return server, nil
706}
707
708func KeysEqual(a, b ssh.PublicKey) bool {
709	if a == nil || b == nil {
710		return false
711	}
712
713	am := a.Marshal()
714	bm := b.Marshal()
715	return (len(am) == len(bm) && subtle.ConstantTimeCompare(am, bm) == 1)
716}