Eric Bower
·
2026-02-26
server.go
1package pssh
2
3import (
4 "bytes"
5 "context"
6 "crypto/ed25519"
7 "crypto/rand"
8 "crypto/subtle"
9 "encoding/base64"
10 "encoding/pem"
11 "errors"
12 "fmt"
13 "io"
14 "log/slog"
15 "net"
16 "net/http"
17 "os"
18 "path"
19 "sync"
20 "time"
21 "unicode/utf8"
22
23 "github.com/antoniomika/syncmap"
24 "github.com/go-andiamo/splitter"
25 "github.com/prometheus/client_golang/prometheus"
26 "github.com/prometheus/client_golang/prometheus/promauto"
27 "github.com/prometheus/client_golang/prometheus/promhttp"
28 "golang.org/x/crypto/ssh"
29)
30
31type SSHServerConn struct {
32 Ctx context.Context
33 CancelFunc context.CancelFunc
34 Logger *slog.Logger
35 Conn *ssh.ServerConn
36 SSHServer *SSHServer
37 Start time.Time
38
39 mu sync.Mutex
40}
41
42func (s *SSHServerConn) Context() context.Context {
43 s.mu.Lock()
44 defer s.mu.Unlock()
45
46 return s.Ctx
47}
48
49func (sc *SSHServerConn) Close() error {
50 sc.CancelFunc()
51 return nil
52}
53
54type SSHServerConnSession struct {
55 ssh.Channel
56 *SSHServerConn
57
58 Ctx context.Context
59 CancelFunc context.CancelFunc
60
61 pty *Pty
62 winch chan Window
63
64 mu sync.Mutex
65}
66
67// Deadline implements context.Context.
68func (s *SSHServerConnSession) Deadline() (deadline time.Time, ok bool) {
69 s.mu.Lock()
70 defer s.mu.Unlock()
71
72 return s.Ctx.Deadline()
73}
74
75// Done implements context.Context.
76func (s *SSHServerConnSession) Done() <-chan struct{} {
77 s.mu.Lock()
78 defer s.mu.Unlock()
79
80 return s.Ctx.Done()
81}
82
83// Err implements context.Context.
84func (s *SSHServerConnSession) Err() error {
85 s.mu.Lock()
86 defer s.mu.Unlock()
87
88 return s.Ctx.Err()
89}
90
91// Value implements context.Context.
92func (s *SSHServerConnSession) Value(key any) any {
93 s.mu.Lock()
94 defer s.mu.Unlock()
95
96 return s.Ctx.Value(key)
97}
98
99// SetValue implements context.Context.
100func (s *SSHServerConnSession) SetValue(key any, data any) {
101 s.mu.Lock()
102 defer s.mu.Unlock()
103
104 s.Ctx = context.WithValue(s.Ctx, key, data)
105}
106
107func (s *SSHServerConnSession) Context() context.Context {
108 s.mu.Lock()
109 defer s.mu.Unlock()
110
111 return s.Ctx
112}
113
114func (s *SSHServerConnSession) Permissions() *ssh.Permissions {
115 return s.Conn.Permissions
116}
117
118func (s *SSHServerConnSession) User() string {
119 return s.Conn.User()
120}
121
122func (s *SSHServerConnSession) PublicKey() ssh.PublicKey {
123 key, ok := s.Conn.Permissions.Extensions["pubkey"]
124 if !ok {
125 return nil
126 }
127
128 pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
129 if err != nil {
130 return nil
131 }
132 return pk
133}
134
135func (s *SSHServerConnSession) RemoteAddr() net.Addr {
136 return s.Conn.RemoteAddr()
137}
138
139func (s *SSHServerConnSession) Command() []string {
140 cmd, _ := s.Value("command").([]string)
141 return cmd
142}
143
144func (s *SSHServerConnSession) Close() error {
145 s.CancelFunc()
146 return s.Channel.Close()
147}
148
149func (s *SSHServerConnSession) Exit(code int) error {
150 status := struct{ Status uint32 }{uint32(code)}
151 _, err := s.SendRequest("exit-status", false, ssh.Marshal(&status))
152 return err
153}
154
155func (sesh *SSHServerConnSession) Errorf(f string, v ...interface{}) {
156 _, _ = fmt.Fprintf(sesh.Stderr(), f, v...)
157}
158
159func (sesh *SSHServerConnSession) Errorln(v ...interface{}) {
160 _, _ = fmt.Fprintln(sesh.Stderr(), v...)
161}
162
163func (sesh *SSHServerConnSession) Printf(f string, v ...interface{}) {
164 _, _ = fmt.Fprintf(sesh, f, v...)
165}
166
167func (sesh *SSHServerConnSession) Println(v ...interface{}) {
168 _, _ = fmt.Fprintln(sesh, v...)
169}
170
171func (s *SSHServerConnSession) Fatal(err error) {
172 _, _ = fmt.Fprintln(s.Stderr(), err)
173 _, _ = fmt.Fprintf(s.Stderr(), "\r")
174 _ = s.Exit(1)
175 _ = s.Close()
176}
177
178func (s *SSHServerConnSession) Pty() (*Pty, <-chan Window, bool) {
179 s.mu.Lock()
180 defer s.mu.Unlock()
181
182 if s.pty == nil {
183 return nil, nil, false
184 }
185
186 return s.pty, s.winch, true
187}
188
189// Write overrides the embedded Channel's Write to normalize line endings when PTY is allocated.
190func (s *SSHServerConnSession) Write(p []byte) (n int, err error) {
191 s.mu.Lock()
192 hasPty := s.pty != nil
193 s.mu.Unlock()
194
195 if !hasPty {
196 // No PTY, write as-is
197 return s.Channel.Write(p)
198 }
199
200 // When PTY is active, ensure every \n is preceded by \r.
201 // This ensures the cursor returns to column 0 before the newline.
202 var buf bytes.Buffer
203 for i := 0; i < len(p); i++ {
204 if p[i] == '\n' {
205 // Check if preceded by \r
206 if i == 0 || p[i-1] != '\r' {
207 buf.WriteByte('\r')
208 }
209 buf.WriteByte('\n')
210 } else {
211 buf.WriteByte(p[i])
212 }
213 }
214
215 normalized := buf.Bytes()
216 written, err := s.Channel.Write(normalized)
217
218 // Return the count based on original data length, not normalized
219 if written > len(p) {
220 written = len(p)
221 }
222 return written, err
223}
224
225var _ context.Context = &SSHServerConnSession{}
226
227func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
228 defer func() {
229 _ = sc.Close()
230 }()
231
232 for {
233 select {
234 case <-sc.Context().Done():
235 return nil
236 case newChan, ok := <-chans:
237 if !ok {
238 return nil
239 }
240
241 sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
242 chanFunc, ok := sc.SSHServer.Config.ChannelMiddleware[newChan.ChannelType()]
243 if !ok {
244 sc.Logger.Info("no channel middleware for type", "type", newChan.ChannelType())
245 continue
246 }
247
248 go func() {
249 err := chanFunc(newChan, sc)
250 if err != nil {
251 sc.Logger.Error("channel middleware", "err", err)
252 }
253 }()
254 case req, ok := <-reqs:
255 if !ok {
256 return nil
257 }
258 sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
259 switch req.Type {
260 case "keepalive@openssh.com":
261 sc.Logger.Info("keepalive reply")
262 err := req.Reply(true, nil)
263 if err != nil {
264 sc.Logger.Error("keepalive", "err", err)
265 }
266 }
267 }
268 }
269}
270
271func NewSSHServerConn(
272 ctx context.Context,
273 logger *slog.Logger,
274 conn *ssh.ServerConn,
275 server *SSHServer,
276) *SSHServerConn {
277 if ctx == nil {
278 ctx = context.Background()
279 }
280
281 cancelCtx, cancelFunc := context.WithCancel(ctx)
282
283 if logger == nil {
284 logger = slog.Default()
285 }
286
287 return &SSHServerConn{
288 Ctx: cancelCtx,
289 CancelFunc: cancelFunc,
290 Logger: logger,
291 Conn: conn,
292 SSHServer: server,
293 Start: time.Now(),
294 }
295}
296
297type SSHServerHandler func(*SSHServerConnSession) error
298type SSHServerMiddleware func(SSHServerHandler) SSHServerHandler
299type SSHServerChannelMiddleware func(ssh.NewChannel, *SSHServerConn) error
300
301type SSHServerConfig struct {
302 *ssh.ServerConfig
303 App string
304 ListenAddr string
305 PromListenAddr string
306 Middleware []SSHServerMiddleware
307 SubsystemMiddleware []SSHServerMiddleware
308 ChannelMiddleware map[string]SSHServerChannelMiddleware
309}
310
311type SSHServer struct {
312 Ctx context.Context
313 CancelFunc context.CancelFunc
314 Logger *slog.Logger
315 Config *SSHServerConfig
316 Listener net.Listener
317 Conns *syncmap.Map[string, *SSHServerConn]
318
319 SessionsCreated *prometheus.CounterVec
320 SessionsFinished *prometheus.CounterVec
321 SessionsDuration *prometheus.CounterVec
322
323 Mu sync.Mutex
324}
325
326func (s *SSHServer) ListenAndServe() error {
327 if s.Config.PromListenAddr != "" {
328 s.SessionsCreated = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
329 Name: "pssh_sessions_created_total",
330 Help: "The total number of sessions created",
331 ConstLabels: prometheus.Labels{
332 "app": s.Config.App,
333 },
334 }, []string{"command"})
335
336 s.SessionsFinished = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
337 Name: "pssh_sessions_finished_total",
338 Help: "The total number of sessions created",
339 ConstLabels: prometheus.Labels{
340 "app": s.Config.App,
341 },
342 }, []string{"command"})
343
344 s.SessionsDuration = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
345 Name: "pssh_sessions_duration_seconds",
346 Help: "The total sessions duration in seconds",
347 ConstLabels: prometheus.Labels{
348 "app": s.Config.App,
349 },
350 }, []string{"command"})
351
352 go func() {
353 mux := http.NewServeMux()
354 mux.Handle("/metrics", promhttp.Handler())
355
356 srv := &http.Server{Addr: s.Config.PromListenAddr, Handler: mux}
357
358 go func() {
359 <-s.Ctx.Done()
360 s.Logger.Info("Prometheus server shutting down")
361 _ = srv.Close()
362 }()
363
364 s.Logger.Info("Starting Prometheus server", "addr", s.Config.PromListenAddr)
365
366 err := srv.ListenAndServe()
367 if err != nil {
368 if errors.Is(err, http.ErrServerClosed) {
369 s.Logger.Info("Prometheus server shut down")
370 return
371 }
372
373 s.Logger.Error("Prometheus serve error", "err", err)
374 panic(err)
375 }
376 }()
377 }
378
379 listen, err := net.Listen("tcp", s.Config.ListenAddr)
380 if err != nil {
381 return err
382 }
383
384 s.Mu.Lock()
385 s.Listener = listen
386 s.Mu.Unlock()
387 defer func() {
388 _ = s.Listener.Close()
389 }()
390
391 go func() {
392 <-s.Ctx.Done()
393 _ = s.Close()
394 }()
395
396 var retErr error
397
398 for {
399 conn, err := s.Listener.Accept()
400 if err != nil {
401 s.Logger.Error("accept", "err", err)
402 if errors.Is(err, net.ErrClosed) {
403 retErr = err
404 break
405 }
406 continue
407 }
408
409 go func() {
410 if err := s.HandleConn(conn); err != nil && !errors.Is(err, io.EOF) {
411 s.Logger.Error("Error handling connection", "err", err, "remoteAddr", conn.RemoteAddr().String())
412 }
413 }()
414 }
415
416 if errors.Is(retErr, net.ErrClosed) {
417 return nil
418 }
419
420 return retErr
421}
422
423func (s *SSHServer) HandleConn(conn net.Conn) error {
424 defer func() {
425 _ = conn.Close()
426 }()
427
428 sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
429 if err != nil {
430 return err
431 }
432
433 newLogger := s.Logger.With(
434 "remoteAddr", conn.RemoteAddr().String(),
435 "sshUser", sshConn.User(),
436 )
437
438 if pubKey, ok := sshConn.Permissions.Extensions["pubkey"]; ok {
439 newLogger = newLogger.With("pubkey", pubKey)
440 }
441
442 newConn := NewSSHServerConn(
443 s.Ctx,
444 newLogger,
445 sshConn,
446 s,
447 )
448
449 s.Conns.Store(sshConn.RemoteAddr().String(), newConn)
450
451 err = newConn.Handle(chans, reqs)
452
453 s.Conns.Delete(sshConn.RemoteAddr().String())
454
455 return err
456}
457
458func (s *SSHServer) Close() error {
459 s.CancelFunc()
460 return s.Listener.Close()
461}
462
463func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
464 if ctx == nil {
465 ctx = context.Background()
466 }
467
468 cancelCtx, cancelFunc := context.WithCancel(ctx)
469
470 if logger == nil {
471 logger = slog.Default()
472 }
473
474 if config == nil {
475 config = &SSHServerConfig{}
476 }
477
478 if config.ChannelMiddleware == nil {
479 config.ChannelMiddleware = map[string]SSHServerChannelMiddleware{}
480 }
481
482 if _, ok := config.ChannelMiddleware["session"]; !ok {
483 config.ChannelMiddleware["session"] = func(newChan ssh.NewChannel, sc *SSHServerConn) error {
484 channel, requests, err := newChan.Accept()
485 if err != nil {
486 sc.Logger.Error("accept session channel", "err", err)
487 return err
488 }
489
490 ctx, cancelFunc := context.WithCancel(sc.Ctx)
491
492 sesh := &SSHServerConnSession{
493 Channel: channel,
494 SSHServerConn: sc,
495 Ctx: ctx,
496 CancelFunc: cancelFunc,
497 }
498
499 for {
500 select {
501 case <-sesh.Done():
502 return nil
503 case req, ok := <-requests:
504 if !ok {
505 return nil
506 }
507
508 go func() {
509 sc.Logger.Info("new session request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
510 switch req.Type {
511 case "subsystem":
512 if len(sc.SSHServer.Config.SubsystemMiddleware) == 0 {
513 err := req.Reply(false, nil)
514 if err != nil {
515 sc.Logger.Error("subsystem reply", "err", err)
516 }
517
518 err = sc.Close()
519 if err != nil {
520 sc.Logger.Error("subsystem close", "err", err)
521 }
522
523 sesh.Fatal(err)
524 return
525 }
526
527 h := func(*SSHServerConnSession) error { return nil }
528 for _, m := range sc.SSHServer.Config.SubsystemMiddleware {
529 h = m(h)
530 }
531
532 err := req.Reply(true, nil)
533 if err != nil {
534 sc.Logger.Error("subsystem reply", "err", err)
535 sesh.Fatal(err)
536 return
537 }
538
539 if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
540 sc.Logger.Error("subsystem middleware", "err", err)
541 sesh.Fatal(err)
542 return
543 }
544
545 err = sesh.Exit(0)
546 if err != nil {
547 sc.Logger.Error("subsystem exit", "err", err)
548 }
549
550 err = sesh.Close()
551 if err != nil {
552 sc.Logger.Error("subsystem close", "err", err)
553 }
554 case "shell", "exec":
555 if len(sc.SSHServer.Config.Middleware) == 0 {
556 err := req.Reply(false, nil)
557 if err != nil {
558 sc.Logger.Error("shell/exec reply", "err", err)
559 }
560 sesh.Fatal(err)
561 return
562 }
563
564 command := "shell"
565
566 if len(req.Payload) > 0 {
567 var payload = struct{ Value string }{}
568 err := ssh.Unmarshal(req.Payload, &payload)
569 if err != nil {
570 sc.Logger.Error("shell/exec unmarshal", "err", err)
571 sesh.Fatal(err)
572 return
573 }
574
575 commaSplitter, _ := splitter.NewSplitter(
576 ' ',
577 splitter.DoubleQuotes,
578 splitter.SingleQuotes,
579 )
580 command = payload.Value
581 cmdSlice, _ := commaSplitter.Split(command, splitter.StripQuotes)
582 sesh.SetValue("command", cmdSlice)
583 }
584
585 if !utf8.ValidString(command) {
586 command = base64.StdEncoding.EncodeToString([]byte(command))
587 }
588
589 if sc.SSHServer.Config.PromListenAddr != "" {
590 sc.SSHServer.SessionsCreated.WithLabelValues(command).Inc()
591 defer func() {
592 sc.SSHServer.SessionsFinished.WithLabelValues(command).Inc()
593 sc.SSHServer.SessionsDuration.WithLabelValues(command).Add(time.Since(sc.Start).Seconds())
594 }()
595 }
596
597 h := func(*SSHServerConnSession) error { return nil }
598 for _, m := range sc.SSHServer.Config.Middleware {
599 h = m(h)
600 }
601
602 err = req.Reply(true, nil)
603 if err != nil {
604 sc.Logger.Error("shell/exec reply", "err", err)
605 sesh.Fatal(err)
606 return
607 }
608
609 if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
610 sc.Logger.Error("exec middleware", "err", err)
611 sesh.Fatal(err)
612 return
613 }
614
615 err = sesh.Exit(0)
616 if err != nil {
617 sc.Logger.Error("subsystem exit", "err", err)
618 }
619
620 err = sesh.Close()
621 if err != nil {
622 sc.Logger.Error("subsystem close", "err", err)
623 }
624 case "pty-req":
625 sesh.mu.Lock()
626 found := sesh.pty != nil
627 sesh.mu.Unlock()
628 if found {
629 err := req.Reply(false, nil)
630 if err != nil {
631 sc.Logger.Error("pty-req reply", "err", err)
632 }
633 return
634 }
635
636 ptyReq, ok := parsePtyRequest(req.Payload)
637 if !ok {
638 err := req.Reply(false, nil)
639 if err != nil {
640 sc.Logger.Error("pty-req reply", "err", err)
641 }
642 return
643 }
644
645 sesh.mu.Lock()
646 sesh.pty = &ptyReq
647 sesh.winch = make(chan Window, 1)
648 sesh.mu.Unlock()
649
650 sesh.winch <- ptyReq.Window
651 err := req.Reply(ok, nil)
652 if err != nil {
653 sc.Logger.Error("pty-req reply", "err", err)
654 }
655 case "window-change":
656 sesh.mu.Lock()
657 found := sesh.pty != nil
658 sesh.mu.Unlock()
659
660 if !found {
661 err := req.Reply(false, nil)
662 if err != nil {
663 sc.Logger.Error("pty-req reply", "err", err)
664 }
665 return
666 }
667
668 win, ok := parseWinchRequest(req.Payload)
669 if ok {
670 sesh.mu.Lock()
671 sesh.pty.Window = win
672 sesh.winch <- win
673 sesh.mu.Unlock()
674 }
675
676 err := req.Reply(ok, nil)
677 if err != nil {
678 sc.Logger.Error("window-change reply", "err", err)
679 }
680 }
681 }()
682 }
683 }
684 }
685 }
686
687 server := &SSHServer{
688 Ctx: cancelCtx,
689 CancelFunc: cancelFunc,
690 Logger: logger,
691 Config: config,
692 Conns: syncmap.New[string, *SSHServerConn](),
693 }
694
695 return server
696}
697
698type PubKeyAuthHandler func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error)
699
700func NewSSHServerWithConfig(
701 ctx context.Context,
702 logger *slog.Logger,
703 app, host, port, promPort, hostKey string,
704 pubKeyAuthHandler PubKeyAuthHandler,
705 middleware, subsystemMiddleware []SSHServerMiddleware,
706 channelMiddleware map[string]SSHServerChannelMiddleware,
707) (*SSHServer, error) {
708 server := NewSSHServer(ctx, logger, &SSHServerConfig{
709 App: app,
710 ListenAddr: fmt.Sprintf("%s:%s", host, port),
711 ServerConfig: &ssh.ServerConfig{
712 PublicKeyCallback: pubKeyAuthHandler,
713 },
714 Middleware: middleware,
715 SubsystemMiddleware: subsystemMiddleware,
716 ChannelMiddleware: channelMiddleware,
717 })
718
719 if promPort != "" {
720 server.Config.PromListenAddr = fmt.Sprintf("%s:%s", host, promPort)
721 }
722
723 if hostKey != "" {
724 pemBytes, err := os.ReadFile(hostKey)
725 if err != nil {
726 logger.Error("failed to read private key file", "error", err)
727 if !os.IsNotExist(err) {
728 return nil, err
729 }
730
731 logger.Info("generating new private key")
732
733 pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
734 if err != nil {
735 logger.Error("failed to generate private key", "error", err)
736 return nil, err
737 }
738
739 privb, err := ssh.MarshalPrivateKey(privKey, "")
740 if err != nil {
741 logger.Error("failed to marshal private key", "error", err)
742 return nil, err
743 }
744
745 block := &pem.Block{
746 Type: "OPENSSH PRIVATE KEY",
747 Bytes: privb.Bytes,
748 }
749
750 if err = os.MkdirAll(path.Dir(hostKey), 0700); err != nil {
751 logger.Error("failed to create ssh_data directory", "error", err)
752 return nil, err
753 }
754
755 pemBytes = pem.EncodeToMemory(block)
756
757 if err = os.WriteFile(hostKey, pemBytes, 0600); err != nil {
758 logger.Error("failed to write private key", "error", err)
759 return nil, err
760 }
761
762 sshPubKey, err := ssh.NewPublicKey(pubKey)
763 if err != nil {
764 logger.Error("failed to create public key", "error", err)
765 return nil, err
766 }
767
768 pubb := ssh.MarshalAuthorizedKey(sshPubKey)
769 if err = os.WriteFile(fmt.Sprintf("%s.pub", hostKey), pubb, 0600); err != nil {
770 logger.Error("failed to write public key", "error", err)
771 return nil, err
772 }
773 }
774
775 signer, err := ssh.ParsePrivateKey(pemBytes)
776 if err != nil {
777 logger.Error("failed to parse private key", "error", err)
778 return nil, err
779 }
780
781 server.Config.AddHostKey(signer)
782 }
783
784 return server, nil
785}
786
787func KeysEqual(a, b ssh.PublicKey) bool {
788 if a == nil || b == nil {
789 return false
790 }
791
792 am := a.Marshal()
793 bm := b.Marshal()
794 return (len(am) == len(bm) && subtle.ConstantTimeCompare(am, bm) == 1)
795}