Eric Bower
·
2025-07-21
server.go
1package pssh
2
3import (
4 "context"
5 "crypto/ed25519"
6 "crypto/rand"
7 "crypto/subtle"
8 "encoding/base64"
9 "encoding/pem"
10 "errors"
11 "fmt"
12 "io"
13 "log/slog"
14 "net"
15 "net/http"
16 "os"
17 "path"
18 "strings"
19 "sync"
20 "time"
21 "unicode/utf8"
22
23 "github.com/antoniomika/syncmap"
24 "github.com/prometheus/client_golang/prometheus"
25 "github.com/prometheus/client_golang/prometheus/promauto"
26 "github.com/prometheus/client_golang/prometheus/promhttp"
27 "golang.org/x/crypto/ssh"
28)
29
30type SSHServerConn struct {
31 Ctx context.Context
32 CancelFunc context.CancelFunc
33 Logger *slog.Logger
34 Conn *ssh.ServerConn
35 SSHServer *SSHServer
36 Start time.Time
37
38 mu sync.Mutex
39}
40
// Context returns the connection context, read under the connection lock.
func (sc *SSHServerConn) Context() context.Context {
	sc.mu.Lock()
	ctx := sc.Ctx
	sc.mu.Unlock()
	return ctx
}
47
48func (sc *SSHServerConn) Close() error {
49 sc.CancelFunc()
50 return nil
51}
52
53type SSHServerConnSession struct {
54 ssh.Channel
55 *SSHServerConn
56
57 Ctx context.Context
58 CancelFunc context.CancelFunc
59
60 pty *Pty
61 winch chan Window
62
63 mu sync.Mutex
64}
65
// Deadline implements context.Context by reporting the session context's
// deadline, read while holding the session lock.
func (s *SSHServerConnSession) Deadline() (deadline time.Time, ok bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	deadline, ok = s.Ctx.Deadline()
	return deadline, ok
}
73
74// Done implements context.Context.
75func (s *SSHServerConnSession) Done() <-chan struct{} {
76 s.mu.Lock()
77 defer s.mu.Unlock()
78
79 return s.Ctx.Done()
80}
81
82// Err implements context.Context.
83func (s *SSHServerConnSession) Err() error {
84 s.mu.Lock()
85 defer s.mu.Unlock()
86
87 return s.Ctx.Err()
88}
89
90// Value implements context.Context.
91func (s *SSHServerConnSession) Value(key any) any {
92 s.mu.Lock()
93 defer s.mu.Unlock()
94
95 return s.Ctx.Value(key)
96}
97
// SetValue stores data under key in the session context. It does this by
// replacing Ctx with a context.WithValue child while holding the session
// lock.
//
// SetValue is NOT part of the context.Context interface; it is a
// pssh-specific helper. The new value is visible through Value and through
// contexts returned by Context() afterwards. Contexts obtained from
// Context() before the call are not updated. As with context.WithValue, key
// must be non-nil and comparable.
func (s *SSHServerConnSession) SetValue(key any, data any) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.Ctx = context.WithValue(s.Ctx, key, data)
}
105
// Context returns the current session context (including any values added
// with SetValue), read while holding the session lock.
func (s *SSHServerConnSession) Context() context.Context {
	s.mu.Lock()
	ctx := s.Ctx
	s.mu.Unlock()
	return ctx
}
112
113func (s *SSHServerConnSession) Permissions() *ssh.Permissions {
114 return s.Conn.Permissions
115}
116
117func (s *SSHServerConnSession) User() string {
118 return s.Conn.User()
119}
120
121func (s *SSHServerConnSession) PublicKey() ssh.PublicKey {
122 key, ok := s.Conn.Permissions.Extensions["pubkey"]
123 if !ok {
124 return nil
125 }
126
127 pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
128 if err != nil {
129 return nil
130 }
131 return pk
132}
133
// RemoteAddr reports the client's network address as seen by the SSH
// connection.
func (s *SSHServerConnSession) RemoteAddr() net.Addr {
	conn := s.Conn
	return conn.RemoteAddr()
}
137
138func (s *SSHServerConnSession) Command() []string {
139 cmd, _ := s.Value("command").([]string)
140 return cmd
141}
142
143func (s *SSHServerConnSession) Close() error {
144 s.CancelFunc()
145 return s.Channel.Close()
146}
147
148func (s *SSHServerConnSession) Exit(code int) error {
149 status := struct{ Status uint32 }{uint32(code)}
150 _, err := s.SendRequest("exit-status", false, ssh.Marshal(&status))
151 return err
152}
153
// Fatal writes err followed by a carriage return to the session's stderr
// stream, sends exit status 1 and closes the session. This is a best-effort
// teardown path, so every write/exit/close error is deliberately discarded.
func (s *SSHServerConnSession) Fatal(err error) {
	stderr := s.Stderr()
	_, _ = fmt.Fprintln(stderr, err)
	_, _ = fmt.Fprintf(stderr, "\r")

	_ = s.Exit(1)
	_ = s.Close()
}
160
161func (s *SSHServerConnSession) Pty() (*Pty, <-chan Window, bool) {
162 s.mu.Lock()
163 defer s.mu.Unlock()
164
165 if s.pty == nil {
166 return nil, nil, false
167 }
168
169 return s.pty, s.winch, true
170}
171
// Compile-time check that *SSHServerConnSession satisfies context.Context.
var _ context.Context = (*SSHServerConnSession)(nil)
173
174func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
175 defer func() {
176 _ = sc.Close()
177 }()
178
179 for {
180 select {
181 case <-sc.Context().Done():
182 return nil
183 case newChan, ok := <-chans:
184 if !ok {
185 return nil
186 }
187
188 sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
189 chanFunc, ok := sc.SSHServer.Config.ChannelMiddleware[newChan.ChannelType()]
190 if !ok {
191 sc.Logger.Info("no channel middleware for type", "type", newChan.ChannelType())
192 continue
193 }
194
195 go func() {
196 err := chanFunc(newChan, sc)
197 if err != nil {
198 sc.Logger.Error("channel middleware", "err", err)
199 }
200 }()
201 case req, ok := <-reqs:
202 if !ok {
203 return nil
204 }
205 sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
206 switch req.Type {
207 case "keepalive@openssh.com":
208 sc.Logger.Info("keepalive reply")
209 err := req.Reply(true, nil)
210 if err != nil {
211 sc.Logger.Error("keepalive", "err", err)
212 }
213 }
214 }
215 }
216}
217
// NewSSHServerConn wraps conn in an SSHServerConn owned by server.
//
// A nil ctx is replaced by context.Background() and a nil logger by
// slog.Default(). The connection context is a cancellable child of ctx, and
// Start is stamped with the current time.
func NewSSHServerConn(
	ctx context.Context,
	logger *slog.Logger,
	conn *ssh.ServerConn,
	server *SSHServer,
) *SSHServerConn {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}

	log := logger
	if log == nil {
		log = slog.Default()
	}

	connCtx, cancel := context.WithCancel(parent)

	sc := &SSHServerConn{
		Ctx:        connCtx,
		CancelFunc: cancel,
		Logger:     log,
		Conn:       conn,
		SSHServer:  server,
		Start:      time.Now(),
	}
	return sc
}
243
244type SSHServerHandler func(*SSHServerConnSession) error
245type SSHServerMiddleware func(SSHServerHandler) SSHServerHandler
246type SSHServerChannelMiddleware func(ssh.NewChannel, *SSHServerConn) error
247
248type SSHServerConfig struct {
249 *ssh.ServerConfig
250 App string
251 ListenAddr string
252 PromListenAddr string
253 Middleware []SSHServerMiddleware
254 SubsystemMiddleware []SSHServerMiddleware
255 ChannelMiddleware map[string]SSHServerChannelMiddleware
256}
257
258type SSHServer struct {
259 Ctx context.Context
260 CancelFunc context.CancelFunc
261 Logger *slog.Logger
262 Config *SSHServerConfig
263 Listener net.Listener
264 Conns *syncmap.Map[string, *SSHServerConn]
265
266 SessionsCreated *prometheus.CounterVec
267 SessionsFinished *prometheus.CounterVec
268 SessionsDuration *prometheus.CounterVec
269}
270
// ListenAndServe listens on Config.ListenAddr and serves each accepted
// connection in its own goroutine via HandleConn. It returns nil once the
// listener is closed (by Close or by cancellation of s.Ctx). A failure to
// open the listener is returned as-is.
//
// When Config.PromListenAddr is set, it first registers the session
// counters with the default Prometheus registerer. It then serves /metrics
// on that address until s.Ctx is cancelled.
//
// NOTE(review): promauto panics on duplicate registration, so calling this
// twice in one process with metrics enabled panics. A failure to start the
// metrics server also panics the process.
func (s *SSHServer) ListenAndServe() error {
	if s.Config.PromListenAddr != "" {
		factory := promauto.With(prometheus.DefaultRegisterer)
		newCounter := func(name, help string) *prometheus.CounterVec {
			return factory.NewCounterVec(prometheus.CounterOpts{
				Name: name,
				Help: help,
				ConstLabels: prometheus.Labels{
					"app": s.Config.App,
				},
			}, []string{"command"})
		}

		s.SessionsCreated = newCounter("pssh_sessions_created_total", "The total number of sessions created")
		s.SessionsFinished = newCounter("pssh_sessions_finished_total", "The total number of sessions finished")
		s.SessionsDuration = newCounter("pssh_sessions_duration_seconds", "The total sessions duration in seconds")

		go func() {
			mux := http.NewServeMux()
			mux.Handle("/metrics", promhttp.Handler())

			srv := &http.Server{
				Addr:              s.Config.PromListenAddr,
				Handler:           mux,
				ReadHeaderTimeout: 10 * time.Second, // bound slow header reads
			}

			go func() {
				<-s.Ctx.Done()
				s.Logger.Info("Prometheus server shutting down")
				_ = srv.Close()
			}()

			s.Logger.Info("Starting Prometheus server", "addr", s.Config.PromListenAddr)

			err := srv.ListenAndServe()
			if err != nil {
				if errors.Is(err, http.ErrServerClosed) {
					s.Logger.Info("Prometheus server shut down")
					return
				}

				s.Logger.Error("Prometheus serve error", "err", err)
				panic(err)
			}
		}()
	}

	listen, err := net.Listen("tcp", s.Config.ListenAddr)
	if err != nil {
		return err
	}

	s.Listener = listen
	defer func() {
		_ = s.Listener.Close()
	}()

	go func() {
		<-s.Ctx.Done()
		_ = s.Close()
	}()

	// Backoff for repeated accept failures (same scheme as net/http's
	// Server.Serve) so a persistent error such as EMFILE does not spin the
	// CPU and flood the log.
	var backoff time.Duration

	for {
		conn, err := s.Listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// The listener was closed on purpose: normal shutdown.
				return nil
			}

			s.Logger.Error("accept", "err", err)
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff *= 2
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		go func() {
			if err := s.HandleConn(conn); err != nil && !errors.Is(err, io.EOF) {
				s.Logger.Error("Error handling connection", "err", err, "remoteAddr", conn.RemoteAddr().String())
			}
		}()
	}
}
365
// HandleConn performs the SSH handshake on conn and serves the resulting
// connection, blocking until SSHServerConn.Handle returns.
//
// While the connection is served, it is registered in s.Conns under its
// remote address. conn is always closed on return. Handshake errors are
// returned unchanged.
func (s *SSHServer) HandleConn(conn net.Conn) error {
	defer func() {
		_ = conn.Close()
	}()

	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
	if err != nil {
		return err
	}

	logger := s.Logger.With(
		"remoteAddr", conn.RemoteAddr().String(),
		"sshUser", sshConn.User(),
	)

	// Permissions can be nil, e.g. with NoClientAuth and no
	// NoClientAuthCallback, or a callback returning nil permissions.
	if perms := sshConn.Permissions; perms != nil {
		if pubKey, ok := perms.Extensions["pubkey"]; ok {
			logger = logger.With("pubkey", pubKey)
		}
	}

	newConn := NewSSHServerConn(s.Ctx, logger, sshConn, s)

	key := sshConn.RemoteAddr().String()
	s.Conns.Store(key, newConn)
	defer s.Conns.Delete(key)

	return newConn.Handle(chans, reqs)
}
400
401func (s *SSHServer) Close() error {
402 s.CancelFunc()
403 return s.Listener.Close()
404}
405
// NewSSHServer creates an SSHServer whose context is a cancellable child of
// ctx. A nil ctx falls back to context.Background(), a nil logger to
// slog.Default() and a nil config to an empty SSHServerConfig.
//
// config is modified in place: ChannelMiddleware is allocated if nil, and
// defaultSessionChannelMiddleware is installed for "session" channels unless
// the caller already registered a handler for that type.
func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
	if ctx == nil {
		ctx = context.Background()
	}

	cancelCtx, cancelFunc := context.WithCancel(ctx)

	if logger == nil {
		logger = slog.Default()
	}

	if config == nil {
		config = &SSHServerConfig{}
	}

	if config.ChannelMiddleware == nil {
		config.ChannelMiddleware = map[string]SSHServerChannelMiddleware{}
	}

	if _, ok := config.ChannelMiddleware["session"]; !ok {
		config.ChannelMiddleware["session"] = defaultSessionChannelMiddleware
	}

	return &SSHServer{
		Ctx:        cancelCtx,
		CancelFunc: cancelFunc,
		Logger:     logger,
		Config:     config,
		Conns:      syncmap.New[string, *SSHServerConn](),
	}
}

// defaultSessionChannelMiddleware accepts a "session" channel, wraps it in an
// SSHServerConnSession whose context is a child of the connection context,
// and services the channel's requests until the session is cancelled or the
// peer closes the channel.
//
// pty-req and window-change are applied synchronously, in arrival order, so
// a pty requested before "shell"/"exec" is visible to the handler. The
// long-running shell/exec/subsystem handlers run in their own goroutines.
func defaultSessionChannelMiddleware(newChan ssh.NewChannel, sc *SSHServerConn) error {
	channel, requests, err := newChan.Accept()
	if err != nil {
		sc.Logger.Error("accept session channel", "err", err)
		return err
	}

	ctx, cancelFunc := context.WithCancel(sc.Context())
	// Release the session context when the request loop ends (e.g. the peer
	// closed the channel) so handlers waiting on it are woken.
	defer cancelFunc()

	sesh := &SSHServerConnSession{
		Channel:       channel,
		SSHServerConn: sc,
		Ctx:           ctx,
		CancelFunc:    cancelFunc,
	}

	for {
		select {
		case <-sesh.Done():
			// x/crypto/ssh requires the request stream to keep being
			// serviced; drain anything still sent until it is closed.
			go ssh.DiscardRequests(requests)
			return nil
		case req, ok := <-requests:
			if !ok {
				return nil
			}

			sc.Logger.Info("new session request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
			switch req.Type {
			case "pty-req":
				handleSessionPtyRequest(sc, sesh, req)
			case "window-change":
				handleSessionWindowChange(sc, sesh, req)
			case "subsystem":
				go handleSessionSubsystem(sc, sesh, req)
			case "shell", "exec":
				go handleSessionExec(sc, sesh, req)
			default:
				// RFC 4254: unrecognised requests that want a reply must be
				// refused rather than left unanswered.
				if req.WantReply {
					if err := req.Reply(false, nil); err != nil {
						sc.Logger.Error("session request reply", "err", err)
					}
				}
			}
		}
	}
}

// chainSessionMiddleware folds mw around a no-op handler. Each middleware
// wraps the result of the previous one, so the last element of mw becomes
// the outermost wrapper.
func chainSessionMiddleware(mw []SSHServerMiddleware) SSHServerHandler {
	h := SSHServerHandler(func(*SSHServerConnSession) error { return nil })
	for _, m := range mw {
		h = m(h)
	}
	return h
}

// handleSessionSubsystem serves a "subsystem" request with
// Config.SubsystemMiddleware, then sends exit status 0 and closes the
// session.
//
// With no subsystem middleware configured, the request is refused, the
// client is told why, and the whole connection is closed.
func handleSessionSubsystem(sc *SSHServerConn, sesh *SSHServerConnSession, req *ssh.Request) {
	mw := sc.SSHServer.Config.SubsystemMiddleware
	if len(mw) == 0 {
		if err := req.Reply(false, nil); err != nil {
			sc.Logger.Error("subsystem reply", "err", err)
		}
		// Report to the client before cancelling the connection context.
		sesh.Fatal(errors.New("subsystem not supported"))
		if err := sc.Close(); err != nil {
			sc.Logger.Error("subsystem close", "err", err)
		}
		return
	}

	h := chainSessionMiddleware(mw)

	if err := req.Reply(true, nil); err != nil {
		sc.Logger.Error("subsystem reply", "err", err)
		sesh.Fatal(err)
		return
	}

	if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
		sc.Logger.Error("subsystem middleware", "err", err)
		sesh.Fatal(err)
		return
	}

	if err := sesh.Exit(0); err != nil {
		sc.Logger.Error("subsystem exit", "err", err)
	}
	if err := sesh.Close(); err != nil {
		sc.Logger.Error("subsystem close", "err", err)
	}
}

// handleSessionExec serves a "shell" or "exec" request with
// Config.Middleware, then sends exit status 0 and closes the session. For
// exec, the command is stored on the session under the "command" value,
// split on whitespace.
func handleSessionExec(sc *SSHServerConn, sesh *SSHServerConnSession, req *ssh.Request) {
	server := sc.SSHServer
	if len(server.Config.Middleware) == 0 {
		if err := req.Reply(false, nil); err != nil {
			sc.Logger.Error("shell/exec reply", "err", err)
		}
		sesh.Fatal(errors.New("shell/exec not supported"))
		return
	}

	command := "shell"
	if len(req.Payload) > 0 {
		var payload struct{ Value string }
		if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
			sc.Logger.Error("shell/exec unmarshal", "err", err)
			// Answer the request before tearing the session down so the
			// client is not left waiting for a reply.
			_ = req.Reply(false, nil)
			sesh.Fatal(err)
			return
		}
		command = payload.Value
		sesh.SetValue("command", strings.Fields(payload.Value))
	}

	// Prometheus label values must be valid UTF-8.
	if !utf8.ValidString(command) {
		command = base64.StdEncoding.EncodeToString([]byte(command))
	}

	// The counters exist only if ListenAndServe created them; check them so
	// connections served via HandleConn alone cannot hit a nil pointer.
	if server.Config.PromListenAddr != "" && server.SessionsCreated != nil &&
		server.SessionsFinished != nil && server.SessionsDuration != nil {
		start := time.Now() // duration of this session, not the whole connection
		server.SessionsCreated.WithLabelValues(command).Inc()
		defer func() {
			server.SessionsFinished.WithLabelValues(command).Inc()
			server.SessionsDuration.WithLabelValues(command).Add(time.Since(start).Seconds())
		}()
	}

	h := chainSessionMiddleware(server.Config.Middleware)

	if err := req.Reply(true, nil); err != nil {
		sc.Logger.Error("shell/exec reply", "err", err)
		sesh.Fatal(err)
		return
	}

	if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
		sc.Logger.Error("exec middleware", "err", err)
		sesh.Fatal(err)
		return
	}

	if err := sesh.Exit(0); err != nil {
		sc.Logger.Error("shell/exec exit", "err", err)
	}
	if err := sesh.Close(); err != nil {
		sc.Logger.Error("shell/exec close", "err", err)
	}
}

// handleSessionPtyRequest attaches a pseudo-terminal to the session. Only
// one pty per session is accepted; a second request or an unparsable payload
// is refused. On success, the initial window size is queued on the
// session's window channel.
func handleSessionPtyRequest(sc *SSHServerConn, sesh *SSHServerConnSession, req *ssh.Request) {
	ptyReq, ok := parsePtyRequest(req.Payload)
	if ok {
		sesh.mu.Lock()
		if sesh.pty != nil {
			ok = false
		} else {
			sesh.pty = &ptyReq
			sesh.winch = make(chan Window, 1)
			// Fresh channel with capacity 1: this send cannot block.
			sesh.winch <- ptyReq.Window
		}
		sesh.mu.Unlock()
	}

	if err := req.Reply(ok, nil); err != nil {
		sc.Logger.Error("pty-req reply", "err", err)
	}
}

// handleSessionWindowChange records a new window size for the session's pty
// and publishes it on the window channel. It is refused if no pty is
// attached or the payload cannot be parsed.
//
// The window channel holds at most one pending size. An unread stale size is
// replaced by the newest one, so the send never blocks while the session
// lock is held.
func handleSessionWindowChange(sc *SSHServerConn, sesh *SSHServerConnSession, req *ssh.Request) {
	win, ok := parseWinchRequest(req.Payload)

	sesh.mu.Lock()
	if sesh.pty == nil {
		ok = false
	}
	if ok {
		sesh.pty.Window = win
		// All sends happen here or in handleSessionPtyRequest, both on the
		// request-loop goroutine under mu. After draining, there is room.
		select {
		case <-sesh.winch:
		default:
		}
		sesh.winch <- win
	}
	sesh.mu.Unlock()

	if err := req.Reply(ok, nil); err != nil {
		sc.Logger.Error("window-change reply", "err", err)
	}
}
635
636type PubKeyAuthHandler func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error)
637
// NewSSHServerWithConfig builds an SSHServerConfig from discrete arguments
// and passes it to NewSSHServer.
//
// The SSH listener address is host:port. When promPort is non-empty, the
// Prometheus endpoint address is host:promPort.
//
// When hostKey is non-empty, it names a PEM-encoded private key file:
//   - If the file does not exist, a new ed25519 key is generated and written
//     there with mode 0600. Its authorized_keys-format public half goes to
//     hostKey+".pub". Missing parent directories are created.
//   - The key is then registered as the server's host key.
//
// Any other read, generate, write or parse failure is returned, and the
// server context created by NewSSHServer is cancelled.
func NewSSHServerWithConfig(
	ctx context.Context,
	logger *slog.Logger,
	app, host, port, promPort, hostKey string,
	pubKeyAuthHandler PubKeyAuthHandler,
	middleware, subsystemMiddleware []SSHServerMiddleware,
	channelMiddleware map[string]SSHServerChannelMiddleware,
) (*SSHServer, error) {
	server := NewSSHServer(ctx, logger, &SSHServerConfig{
		App:        app,
		ListenAddr: fmt.Sprintf("%s:%s", host, port),
		ServerConfig: &ssh.ServerConfig{
			PublicKeyCallback: pubKeyAuthHandler,
		},
		Middleware:          middleware,
		SubsystemMiddleware: subsystemMiddleware,
		ChannelMiddleware:   channelMiddleware,
	})

	// NewSSHServer substitutes slog.Default() for a nil logger. Use the
	// resolved logger so the error paths below cannot dereference nil.
	logger = server.Logger

	if promPort != "" {
		server.Config.PromListenAddr = fmt.Sprintf("%s:%s", host, promPort)
	}

	if hostKey == "" {
		return server, nil
	}

	// fail logs err, releases the server context (the server is not returned
	// to the caller) and produces the error result.
	fail := func(msg string, err error) (*SSHServer, error) {
		logger.Error(msg, "error", err)
		server.CancelFunc()
		return nil, err
	}

	pemBytes, err := os.ReadFile(hostKey)
	if err != nil {
		if !os.IsNotExist(err) {
			return fail("failed to read private key file", err)
		}
		logger.Error("failed to read private key file", "error", err)
		logger.Info("generating new private key")

		pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return fail("failed to generate private key", err)
		}

		privb, err := ssh.MarshalPrivateKey(privKey, "")
		if err != nil {
			return fail("failed to marshal private key", err)
		}

		if err := os.MkdirAll(path.Dir(hostKey), 0700); err != nil {
			return fail("failed to create ssh_data directory", err)
		}

		pemBytes = pem.EncodeToMemory(&pem.Block{
			Type:  "OPENSSH PRIVATE KEY",
			Bytes: privb.Bytes,
		})

		if err := os.WriteFile(hostKey, pemBytes, 0600); err != nil {
			return fail("failed to write private key", err)
		}

		sshPubKey, err := ssh.NewPublicKey(pubKey)
		if err != nil {
			return fail("failed to create public key", err)
		}

		if err := os.WriteFile(hostKey+".pub", ssh.MarshalAuthorizedKey(sshPubKey), 0600); err != nil {
			return fail("failed to write public key", err)
		}
	}

	signer, err := ssh.ParsePrivateKey(pemBytes)
	if err != nil {
		return fail("failed to parse private key", err)
	}

	server.Config.AddHostKey(signer)

	return server, nil
}
724
// KeysEqual reports whether a and b have identical wire encodings. The
// encodings are compared in constant time. A nil key never equals anything,
// including another nil key.
func KeysEqual(a, b ssh.PublicKey) bool {
	if a == nil || b == nil {
		return false
	}
	// ConstantTimeCompare already returns 0 when the lengths differ.
	return subtle.ConstantTimeCompare(a.Marshal(), b.Marshal()) == 1
}