repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / pssh
Eric Bower  ·  2026-02-26

server.go

  1package pssh
  2
  3import (
  4	"bytes"
  5	"context"
  6	"crypto/ed25519"
  7	"crypto/rand"
  8	"crypto/subtle"
  9	"encoding/base64"
 10	"encoding/pem"
 11	"errors"
 12	"fmt"
 13	"io"
 14	"log/slog"
 15	"net"
 16	"net/http"
 17	"os"
 18	"path"
 19	"sync"
 20	"time"
 21	"unicode/utf8"
 22
 23	"github.com/antoniomika/syncmap"
 24	"github.com/go-andiamo/splitter"
 25	"github.com/prometheus/client_golang/prometheus"
 26	"github.com/prometheus/client_golang/prometheus/promauto"
 27	"github.com/prometheus/client_golang/prometheus/promhttp"
 28	"golang.org/x/crypto/ssh"
 29)
 30
 31type SSHServerConn struct {
 32	Ctx        context.Context
 33	CancelFunc context.CancelFunc
 34	Logger     *slog.Logger
 35	Conn       *ssh.ServerConn
 36	SSHServer  *SSHServer
 37	Start      time.Time
 38
 39	mu sync.Mutex
 40}
 41
 42func (s *SSHServerConn) Context() context.Context {
 43	s.mu.Lock()
 44	defer s.mu.Unlock()
 45
 46	return s.Ctx
 47}
 48
 49func (sc *SSHServerConn) Close() error {
 50	sc.CancelFunc()
 51	return nil
 52}
 53
 54type SSHServerConnSession struct {
 55	ssh.Channel
 56	*SSHServerConn
 57
 58	Ctx        context.Context
 59	CancelFunc context.CancelFunc
 60
 61	pty   *Pty
 62	winch chan Window
 63
 64	mu sync.Mutex
 65}
 66
 67// Deadline implements context.Context.
 68func (s *SSHServerConnSession) Deadline() (deadline time.Time, ok bool) {
 69	s.mu.Lock()
 70	defer s.mu.Unlock()
 71
 72	return s.Ctx.Deadline()
 73}
 74
 75// Done implements context.Context.
 76func (s *SSHServerConnSession) Done() <-chan struct{} {
 77	s.mu.Lock()
 78	defer s.mu.Unlock()
 79
 80	return s.Ctx.Done()
 81}
 82
 83// Err implements context.Context.
 84func (s *SSHServerConnSession) Err() error {
 85	s.mu.Lock()
 86	defer s.mu.Unlock()
 87
 88	return s.Ctx.Err()
 89}
 90
 91// Value implements context.Context.
 92func (s *SSHServerConnSession) Value(key any) any {
 93	s.mu.Lock()
 94	defer s.mu.Unlock()
 95
 96	return s.Ctx.Value(key)
 97}
 98
 99// SetValue implements context.Context.
100func (s *SSHServerConnSession) SetValue(key any, data any) {
101	s.mu.Lock()
102	defer s.mu.Unlock()
103
104	s.Ctx = context.WithValue(s.Ctx, key, data)
105}
106
107func (s *SSHServerConnSession) Context() context.Context {
108	s.mu.Lock()
109	defer s.mu.Unlock()
110
111	return s.Ctx
112}
113
114func (s *SSHServerConnSession) Permissions() *ssh.Permissions {
115	return s.Conn.Permissions
116}
117
118func (s *SSHServerConnSession) User() string {
119	return s.Conn.User()
120}
121
122func (s *SSHServerConnSession) PublicKey() ssh.PublicKey {
123	key, ok := s.Conn.Permissions.Extensions["pubkey"]
124	if !ok {
125		return nil
126	}
127
128	pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
129	if err != nil {
130		return nil
131	}
132	return pk
133}
134
135func (s *SSHServerConnSession) RemoteAddr() net.Addr {
136	return s.Conn.RemoteAddr()
137}
138
139func (s *SSHServerConnSession) Command() []string {
140	cmd, _ := s.Value("command").([]string)
141	return cmd
142}
143
144func (s *SSHServerConnSession) Close() error {
145	s.CancelFunc()
146	return s.Channel.Close()
147}
148
149func (s *SSHServerConnSession) Exit(code int) error {
150	status := struct{ Status uint32 }{uint32(code)}
151	_, err := s.SendRequest("exit-status", false, ssh.Marshal(&status))
152	return err
153}
154
155func (sesh *SSHServerConnSession) Errorf(f string, v ...interface{}) {
156	_, _ = fmt.Fprintf(sesh.Stderr(), f, v...)
157}
158
159func (sesh *SSHServerConnSession) Errorln(v ...interface{}) {
160	_, _ = fmt.Fprintln(sesh.Stderr(), v...)
161}
162
163func (sesh *SSHServerConnSession) Printf(f string, v ...interface{}) {
164	_, _ = fmt.Fprintf(sesh, f, v...)
165}
166
167func (sesh *SSHServerConnSession) Println(v ...interface{}) {
168	_, _ = fmt.Fprintln(sesh, v...)
169}
170
171func (s *SSHServerConnSession) Fatal(err error) {
172	_, _ = fmt.Fprintln(s.Stderr(), err)
173	_, _ = fmt.Fprintf(s.Stderr(), "\r")
174	_ = s.Exit(1)
175	_ = s.Close()
176}
177
178func (s *SSHServerConnSession) Pty() (*Pty, <-chan Window, bool) {
179	s.mu.Lock()
180	defer s.mu.Unlock()
181
182	if s.pty == nil {
183		return nil, nil, false
184	}
185
186	return s.pty, s.winch, true
187}
188
189// Write overrides the embedded Channel's Write to normalize line endings when PTY is allocated.
190func (s *SSHServerConnSession) Write(p []byte) (n int, err error) {
191	s.mu.Lock()
192	hasPty := s.pty != nil
193	s.mu.Unlock()
194
195	if !hasPty {
196		// No PTY, write as-is
197		return s.Channel.Write(p)
198	}
199
200	// When PTY is active, ensure every \n is preceded by \r.
201	// This ensures the cursor returns to column 0 before the newline.
202	var buf bytes.Buffer
203	for i := 0; i < len(p); i++ {
204		if p[i] == '\n' {
205			// Check if preceded by \r
206			if i == 0 || p[i-1] != '\r' {
207				buf.WriteByte('\r')
208			}
209			buf.WriteByte('\n')
210		} else {
211			buf.WriteByte(p[i])
212		}
213	}
214
215	normalized := buf.Bytes()
216	written, err := s.Channel.Write(normalized)
217
218	// Return the count based on original data length, not normalized
219	if written > len(p) {
220		written = len(p)
221	}
222	return written, err
223}
224
225var _ context.Context = &SSHServerConnSession{}
226
227func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
228	defer func() {
229		_ = sc.Close()
230	}()
231
232	for {
233		select {
234		case <-sc.Context().Done():
235			return nil
236		case newChan, ok := <-chans:
237			if !ok {
238				return nil
239			}
240
241			sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
242			chanFunc, ok := sc.SSHServer.Config.ChannelMiddleware[newChan.ChannelType()]
243			if !ok {
244				sc.Logger.Info("no channel middleware for type", "type", newChan.ChannelType())
245				continue
246			}
247
248			go func() {
249				err := chanFunc(newChan, sc)
250				if err != nil {
251					sc.Logger.Error("channel middleware", "err", err)
252				}
253			}()
254		case req, ok := <-reqs:
255			if !ok {
256				return nil
257			}
258			sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
259			switch req.Type {
260			case "keepalive@openssh.com":
261				sc.Logger.Info("keepalive reply")
262				err := req.Reply(true, nil)
263				if err != nil {
264					sc.Logger.Error("keepalive", "err", err)
265				}
266			}
267		}
268	}
269}
270
271func NewSSHServerConn(
272	ctx context.Context,
273	logger *slog.Logger,
274	conn *ssh.ServerConn,
275	server *SSHServer,
276) *SSHServerConn {
277	if ctx == nil {
278		ctx = context.Background()
279	}
280
281	cancelCtx, cancelFunc := context.WithCancel(ctx)
282
283	if logger == nil {
284		logger = slog.Default()
285	}
286
287	return &SSHServerConn{
288		Ctx:        cancelCtx,
289		CancelFunc: cancelFunc,
290		Logger:     logger,
291		Conn:       conn,
292		SSHServer:  server,
293		Start:      time.Now(),
294	}
295}
296
297type SSHServerHandler func(*SSHServerConnSession) error
298type SSHServerMiddleware func(SSHServerHandler) SSHServerHandler
299type SSHServerChannelMiddleware func(ssh.NewChannel, *SSHServerConn) error
300
301type SSHServerConfig struct {
302	*ssh.ServerConfig
303	App                 string
304	ListenAddr          string
305	PromListenAddr      string
306	Middleware          []SSHServerMiddleware
307	SubsystemMiddleware []SSHServerMiddleware
308	ChannelMiddleware   map[string]SSHServerChannelMiddleware
309}
310
311type SSHServer struct {
312	Ctx        context.Context
313	CancelFunc context.CancelFunc
314	Logger     *slog.Logger
315	Config     *SSHServerConfig
316	Listener   net.Listener
317	Conns      *syncmap.Map[string, *SSHServerConn]
318
319	SessionsCreated  *prometheus.CounterVec
320	SessionsFinished *prometheus.CounterVec
321	SessionsDuration *prometheus.CounterVec
322
323	Mu sync.Mutex
324}
325
326func (s *SSHServer) ListenAndServe() error {
327	if s.Config.PromListenAddr != "" {
328		s.SessionsCreated = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
329			Name: "pssh_sessions_created_total",
330			Help: "The total number of sessions created",
331			ConstLabels: prometheus.Labels{
332				"app": s.Config.App,
333			},
334		}, []string{"command"})
335
336		s.SessionsFinished = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
337			Name: "pssh_sessions_finished_total",
338			Help: "The total number of sessions created",
339			ConstLabels: prometheus.Labels{
340				"app": s.Config.App,
341			},
342		}, []string{"command"})
343
344		s.SessionsDuration = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
345			Name: "pssh_sessions_duration_seconds",
346			Help: "The total sessions duration in seconds",
347			ConstLabels: prometheus.Labels{
348				"app": s.Config.App,
349			},
350		}, []string{"command"})
351
352		go func() {
353			mux := http.NewServeMux()
354			mux.Handle("/metrics", promhttp.Handler())
355
356			srv := &http.Server{Addr: s.Config.PromListenAddr, Handler: mux}
357
358			go func() {
359				<-s.Ctx.Done()
360				s.Logger.Info("Prometheus server shutting down")
361				_ = srv.Close()
362			}()
363
364			s.Logger.Info("Starting Prometheus server", "addr", s.Config.PromListenAddr)
365
366			err := srv.ListenAndServe()
367			if err != nil {
368				if errors.Is(err, http.ErrServerClosed) {
369					s.Logger.Info("Prometheus server shut down")
370					return
371				}
372
373				s.Logger.Error("Prometheus serve error", "err", err)
374				panic(err)
375			}
376		}()
377	}
378
379	listen, err := net.Listen("tcp", s.Config.ListenAddr)
380	if err != nil {
381		return err
382	}
383
384	s.Mu.Lock()
385	s.Listener = listen
386	s.Mu.Unlock()
387	defer func() {
388		_ = s.Listener.Close()
389	}()
390
391	go func() {
392		<-s.Ctx.Done()
393		_ = s.Close()
394	}()
395
396	var retErr error
397
398	for {
399		conn, err := s.Listener.Accept()
400		if err != nil {
401			s.Logger.Error("accept", "err", err)
402			if errors.Is(err, net.ErrClosed) {
403				retErr = err
404				break
405			}
406			continue
407		}
408
409		go func() {
410			if err := s.HandleConn(conn); err != nil && !errors.Is(err, io.EOF) {
411				s.Logger.Error("Error handling connection", "err", err, "remoteAddr", conn.RemoteAddr().String())
412			}
413		}()
414	}
415
416	if errors.Is(retErr, net.ErrClosed) {
417		return nil
418	}
419
420	return retErr
421}
422
423func (s *SSHServer) HandleConn(conn net.Conn) error {
424	defer func() {
425		_ = conn.Close()
426	}()
427
428	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
429	if err != nil {
430		return err
431	}
432
433	newLogger := s.Logger.With(
434		"remoteAddr", conn.RemoteAddr().String(),
435		"sshUser", sshConn.User(),
436	)
437
438	if pubKey, ok := sshConn.Permissions.Extensions["pubkey"]; ok {
439		newLogger = newLogger.With("pubkey", pubKey)
440	}
441
442	newConn := NewSSHServerConn(
443		s.Ctx,
444		newLogger,
445		sshConn,
446		s,
447	)
448
449	s.Conns.Store(sshConn.RemoteAddr().String(), newConn)
450
451	err = newConn.Handle(chans, reqs)
452
453	s.Conns.Delete(sshConn.RemoteAddr().String())
454
455	return err
456}
457
458func (s *SSHServer) Close() error {
459	s.CancelFunc()
460	return s.Listener.Close()
461}
462
463func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
464	if ctx == nil {
465		ctx = context.Background()
466	}
467
468	cancelCtx, cancelFunc := context.WithCancel(ctx)
469
470	if logger == nil {
471		logger = slog.Default()
472	}
473
474	if config == nil {
475		config = &SSHServerConfig{}
476	}
477
478	if config.ChannelMiddleware == nil {
479		config.ChannelMiddleware = map[string]SSHServerChannelMiddleware{}
480	}
481
482	if _, ok := config.ChannelMiddleware["session"]; !ok {
483		config.ChannelMiddleware["session"] = func(newChan ssh.NewChannel, sc *SSHServerConn) error {
484			channel, requests, err := newChan.Accept()
485			if err != nil {
486				sc.Logger.Error("accept session channel", "err", err)
487				return err
488			}
489
490			ctx, cancelFunc := context.WithCancel(sc.Ctx)
491
492			sesh := &SSHServerConnSession{
493				Channel:       channel,
494				SSHServerConn: sc,
495				Ctx:           ctx,
496				CancelFunc:    cancelFunc,
497			}
498
499			for {
500				select {
501				case <-sesh.Done():
502					return nil
503				case req, ok := <-requests:
504					if !ok {
505						return nil
506					}
507
508					go func() {
509						sc.Logger.Info("new session request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
510						switch req.Type {
511						case "subsystem":
512							if len(sc.SSHServer.Config.SubsystemMiddleware) == 0 {
513								err := req.Reply(false, nil)
514								if err != nil {
515									sc.Logger.Error("subsystem reply", "err", err)
516								}
517
518								err = sc.Close()
519								if err != nil {
520									sc.Logger.Error("subsystem close", "err", err)
521								}
522
523								sesh.Fatal(err)
524								return
525							}
526
527							h := func(*SSHServerConnSession) error { return nil }
528							for _, m := range sc.SSHServer.Config.SubsystemMiddleware {
529								h = m(h)
530							}
531
532							err := req.Reply(true, nil)
533							if err != nil {
534								sc.Logger.Error("subsystem reply", "err", err)
535								sesh.Fatal(err)
536								return
537							}
538
539							if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
540								sc.Logger.Error("subsystem middleware", "err", err)
541								sesh.Fatal(err)
542								return
543							}
544
545							err = sesh.Exit(0)
546							if err != nil {
547								sc.Logger.Error("subsystem exit", "err", err)
548							}
549
550							err = sesh.Close()
551							if err != nil {
552								sc.Logger.Error("subsystem close", "err", err)
553							}
554						case "shell", "exec":
555							if len(sc.SSHServer.Config.Middleware) == 0 {
556								err := req.Reply(false, nil)
557								if err != nil {
558									sc.Logger.Error("shell/exec reply", "err", err)
559								}
560								sesh.Fatal(err)
561								return
562							}
563
564							command := "shell"
565
566							if len(req.Payload) > 0 {
567								var payload = struct{ Value string }{}
568								err := ssh.Unmarshal(req.Payload, &payload)
569								if err != nil {
570									sc.Logger.Error("shell/exec unmarshal", "err", err)
571									sesh.Fatal(err)
572									return
573								}
574
575								commaSplitter, _ := splitter.NewSplitter(
576									' ',
577									splitter.DoubleQuotes,
578									splitter.SingleQuotes,
579								)
580								command = payload.Value
581								cmdSlice, _ := commaSplitter.Split(command, splitter.StripQuotes)
582								sesh.SetValue("command", cmdSlice)
583							}
584
585							if !utf8.ValidString(command) {
586								command = base64.StdEncoding.EncodeToString([]byte(command))
587							}
588
589							if sc.SSHServer.Config.PromListenAddr != "" {
590								sc.SSHServer.SessionsCreated.WithLabelValues(command).Inc()
591								defer func() {
592									sc.SSHServer.SessionsFinished.WithLabelValues(command).Inc()
593									sc.SSHServer.SessionsDuration.WithLabelValues(command).Add(time.Since(sc.Start).Seconds())
594								}()
595							}
596
597							h := func(*SSHServerConnSession) error { return nil }
598							for _, m := range sc.SSHServer.Config.Middleware {
599								h = m(h)
600							}
601
602							err = req.Reply(true, nil)
603							if err != nil {
604								sc.Logger.Error("shell/exec reply", "err", err)
605								sesh.Fatal(err)
606								return
607							}
608
609							if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
610								sc.Logger.Error("exec middleware", "err", err)
611								sesh.Fatal(err)
612								return
613							}
614
615							err = sesh.Exit(0)
616							if err != nil {
617								sc.Logger.Error("subsystem exit", "err", err)
618							}
619
620							err = sesh.Close()
621							if err != nil {
622								sc.Logger.Error("subsystem close", "err", err)
623							}
624						case "pty-req":
625							sesh.mu.Lock()
626							found := sesh.pty != nil
627							sesh.mu.Unlock()
628							if found {
629								err := req.Reply(false, nil)
630								if err != nil {
631									sc.Logger.Error("pty-req reply", "err", err)
632								}
633								return
634							}
635
636							ptyReq, ok := parsePtyRequest(req.Payload)
637							if !ok {
638								err := req.Reply(false, nil)
639								if err != nil {
640									sc.Logger.Error("pty-req reply", "err", err)
641								}
642								return
643							}
644
645							sesh.mu.Lock()
646							sesh.pty = &ptyReq
647							sesh.winch = make(chan Window, 1)
648							sesh.mu.Unlock()
649
650							sesh.winch <- ptyReq.Window
651							err := req.Reply(ok, nil)
652							if err != nil {
653								sc.Logger.Error("pty-req reply", "err", err)
654							}
655						case "window-change":
656							sesh.mu.Lock()
657							found := sesh.pty != nil
658							sesh.mu.Unlock()
659
660							if !found {
661								err := req.Reply(false, nil)
662								if err != nil {
663									sc.Logger.Error("pty-req reply", "err", err)
664								}
665								return
666							}
667
668							win, ok := parseWinchRequest(req.Payload)
669							if ok {
670								sesh.mu.Lock()
671								sesh.pty.Window = win
672								sesh.winch <- win
673								sesh.mu.Unlock()
674							}
675
676							err := req.Reply(ok, nil)
677							if err != nil {
678								sc.Logger.Error("window-change reply", "err", err)
679							}
680						}
681					}()
682				}
683			}
684		}
685	}
686
687	server := &SSHServer{
688		Ctx:        cancelCtx,
689		CancelFunc: cancelFunc,
690		Logger:     logger,
691		Config:     config,
692		Conns:      syncmap.New[string, *SSHServerConn](),
693	}
694
695	return server
696}
697
698type PubKeyAuthHandler func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error)
699
700func NewSSHServerWithConfig(
701	ctx context.Context,
702	logger *slog.Logger,
703	app, host, port, promPort, hostKey string,
704	pubKeyAuthHandler PubKeyAuthHandler,
705	middleware, subsystemMiddleware []SSHServerMiddleware,
706	channelMiddleware map[string]SSHServerChannelMiddleware,
707) (*SSHServer, error) {
708	server := NewSSHServer(ctx, logger, &SSHServerConfig{
709		App:        app,
710		ListenAddr: fmt.Sprintf("%s:%s", host, port),
711		ServerConfig: &ssh.ServerConfig{
712			PublicKeyCallback: pubKeyAuthHandler,
713		},
714		Middleware:          middleware,
715		SubsystemMiddleware: subsystemMiddleware,
716		ChannelMiddleware:   channelMiddleware,
717	})
718
719	if promPort != "" {
720		server.Config.PromListenAddr = fmt.Sprintf("%s:%s", host, promPort)
721	}
722
723	if hostKey != "" {
724		pemBytes, err := os.ReadFile(hostKey)
725		if err != nil {
726			logger.Error("failed to read private key file", "error", err)
727			if !os.IsNotExist(err) {
728				return nil, err
729			}
730
731			logger.Info("generating new private key")
732
733			pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
734			if err != nil {
735				logger.Error("failed to generate private key", "error", err)
736				return nil, err
737			}
738
739			privb, err := ssh.MarshalPrivateKey(privKey, "")
740			if err != nil {
741				logger.Error("failed to marshal private key", "error", err)
742				return nil, err
743			}
744
745			block := &pem.Block{
746				Type:  "OPENSSH PRIVATE KEY",
747				Bytes: privb.Bytes,
748			}
749
750			if err = os.MkdirAll(path.Dir(hostKey), 0700); err != nil {
751				logger.Error("failed to create ssh_data directory", "error", err)
752				return nil, err
753			}
754
755			pemBytes = pem.EncodeToMemory(block)
756
757			if err = os.WriteFile(hostKey, pemBytes, 0600); err != nil {
758				logger.Error("failed to write private key", "error", err)
759				return nil, err
760			}
761
762			sshPubKey, err := ssh.NewPublicKey(pubKey)
763			if err != nil {
764				logger.Error("failed to create public key", "error", err)
765				return nil, err
766			}
767
768			pubb := ssh.MarshalAuthorizedKey(sshPubKey)
769			if err = os.WriteFile(fmt.Sprintf("%s.pub", hostKey), pubb, 0600); err != nil {
770				logger.Error("failed to write public key", "error", err)
771				return nil, err
772			}
773		}
774
775		signer, err := ssh.ParsePrivateKey(pemBytes)
776		if err != nil {
777			logger.Error("failed to parse private key", "error", err)
778			return nil, err
779		}
780
781		server.Config.AddHostKey(signer)
782	}
783
784	return server, nil
785}
786
787func KeysEqual(a, b ssh.PublicKey) bool {
788	if a == nil || b == nil {
789		return false
790	}
791
792	am := a.Marshal()
793	bm := b.Marshal()
794	return (len(am) == len(bm) && subtle.ConstantTimeCompare(am, bm) == 1)
795}