Antonio Mika
·
2025-04-10
server.go
1package pssh
2
3import (
4 "context"
5 "crypto/ed25519"
6 "crypto/rand"
7 "crypto/subtle"
8 "encoding/base64"
9 "encoding/pem"
10 "errors"
11 "fmt"
12 "io"
13 "log/slog"
14 "net"
15 "net/http"
16 "os"
17 "path"
18 "strings"
19 "sync"
20 "time"
21 "unicode/utf8"
22
23 "github.com/antoniomika/syncmap"
24 "github.com/prometheus/client_golang/prometheus"
25 "github.com/prometheus/client_golang/prometheus/promauto"
26 "github.com/prometheus/client_golang/prometheus/promhttp"
27 "golang.org/x/crypto/ssh"
28)
29
30type SSHServerConn struct {
31 Ctx context.Context
32 CancelFunc context.CancelFunc
33 Logger *slog.Logger
34 Conn *ssh.ServerConn
35 SSHServer *SSHServer
36 Start time.Time
37
38 mu sync.Mutex
39}
40
41func (s *SSHServerConn) Context() context.Context {
42 s.mu.Lock()
43 defer s.mu.Unlock()
44
45 return s.Ctx
46}
47
48func (sc *SSHServerConn) Close() error {
49 sc.CancelFunc()
50 return nil
51}
52
53type SSHServerConnSession struct {
54 ssh.Channel
55 *SSHServerConn
56
57 Ctx context.Context
58 CancelFunc context.CancelFunc
59
60 pty *Pty
61 winch chan Window
62
63 mu sync.Mutex
64}
65
66// Deadline implements context.Context.
67func (s *SSHServerConnSession) Deadline() (deadline time.Time, ok bool) {
68 s.mu.Lock()
69 defer s.mu.Unlock()
70
71 return s.Ctx.Deadline()
72}
73
74// Done implements context.Context.
75func (s *SSHServerConnSession) Done() <-chan struct{} {
76 s.mu.Lock()
77 defer s.mu.Unlock()
78
79 return s.Ctx.Done()
80}
81
82// Err implements context.Context.
83func (s *SSHServerConnSession) Err() error {
84 s.mu.Lock()
85 defer s.mu.Unlock()
86
87 return s.Ctx.Err()
88}
89
90// Value implements context.Context.
91func (s *SSHServerConnSession) Value(key any) any {
92 s.mu.Lock()
93 defer s.mu.Unlock()
94
95 return s.Ctx.Value(key)
96}
97
98// SetValue implements context.Context.
99func (s *SSHServerConnSession) SetValue(key any, data any) {
100 s.mu.Lock()
101 defer s.mu.Unlock()
102
103 s.Ctx = context.WithValue(s.Ctx, key, data)
104}
105
106func (s *SSHServerConnSession) Context() context.Context {
107 s.mu.Lock()
108 defer s.mu.Unlock()
109
110 return s.Ctx
111}
112
113func (s *SSHServerConnSession) Permissions() *ssh.Permissions {
114 return s.Conn.Permissions
115}
116
117func (s *SSHServerConnSession) User() string {
118 return s.Conn.User()
119}
120
121func (s *SSHServerConnSession) PublicKey() ssh.PublicKey {
122 key, ok := s.Conn.Permissions.Extensions["pubkey"]
123 if !ok {
124 return nil
125 }
126
127 pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
128 if err != nil {
129 return nil
130 }
131 return pk
132}
133
134func (s *SSHServerConnSession) RemoteAddr() net.Addr {
135 return s.Conn.RemoteAddr()
136}
137
138func (s *SSHServerConnSession) Command() []string {
139 cmd, _ := s.Value("command").([]string)
140 return cmd
141}
142
143func (s *SSHServerConnSession) Close() error {
144 s.CancelFunc()
145 return s.Channel.Close()
146}
147
148func (s *SSHServerConnSession) Exit(code int) error {
149 status := struct{ Status uint32 }{uint32(code)}
150 _, err := s.Channel.SendRequest("exit-status", false, ssh.Marshal(&status))
151 return err
152}
153
154func (s *SSHServerConnSession) Fatal(err error) {
155 fmt.Fprintln(s.Stderr(), err)
156 fmt.Fprintf(s.Stderr(), "\r")
157 _ = s.Exit(1)
158 _ = s.Close()
159}
160
161func (s *SSHServerConnSession) Pty() (*Pty, <-chan Window, bool) {
162 s.mu.Lock()
163 defer s.mu.Unlock()
164
165 if s.pty == nil {
166 return nil, nil, false
167 }
168
169 return s.pty, s.winch, true
170}
171
172var _ context.Context = &SSHServerConnSession{}
173
174func (sc *SSHServerConn) Handle(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) error {
175 defer sc.Close()
176
177 for {
178 select {
179 case <-sc.Context().Done():
180 return nil
181 case newChan, ok := <-chans:
182 if !ok {
183 return nil
184 }
185
186 sc.Logger.Info("new channel", "type", newChan.ChannelType(), "extraData", newChan.ExtraData())
187 chanFunc, ok := sc.SSHServer.Config.ChannelMiddleware[newChan.ChannelType()]
188 if !ok {
189 sc.Logger.Info("no channel middleware for type", "type", newChan.ChannelType())
190 continue
191 }
192
193 go func() {
194 err := chanFunc(newChan, sc)
195 if err != nil {
196 sc.Logger.Error("channel middleware", "err", err)
197 }
198 }()
199 case req, ok := <-reqs:
200 if !ok {
201 return nil
202 }
203 sc.Logger.Info("new request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
204 }
205 }
206}
207
208func NewSSHServerConn(
209 ctx context.Context,
210 logger *slog.Logger,
211 conn *ssh.ServerConn,
212 server *SSHServer,
213) *SSHServerConn {
214 if ctx == nil {
215 ctx = context.Background()
216 }
217
218 cancelCtx, cancelFunc := context.WithCancel(ctx)
219
220 if logger == nil {
221 logger = slog.Default()
222 }
223
224 return &SSHServerConn{
225 Ctx: cancelCtx,
226 CancelFunc: cancelFunc,
227 Logger: logger,
228 Conn: conn,
229 SSHServer: server,
230 Start: time.Now(),
231 }
232}
233
234type SSHServerHandler func(*SSHServerConnSession) error
235type SSHServerMiddleware func(SSHServerHandler) SSHServerHandler
236type SSHServerChannelMiddleware func(ssh.NewChannel, *SSHServerConn) error
237
238type SSHServerConfig struct {
239 *ssh.ServerConfig
240 App string
241 ListenAddr string
242 PromListenAddr string
243 Middleware []SSHServerMiddleware
244 SubsystemMiddleware []SSHServerMiddleware
245 ChannelMiddleware map[string]SSHServerChannelMiddleware
246}
247
248type SSHServer struct {
249 Ctx context.Context
250 CancelFunc context.CancelFunc
251 Logger *slog.Logger
252 Config *SSHServerConfig
253 Listener net.Listener
254 Conns *syncmap.Map[string, *SSHServerConn]
255
256 SessionsCreated *prometheus.CounterVec
257 SessionsFinished *prometheus.CounterVec
258 SessionsDuration *prometheus.CounterVec
259}
260
261func (s *SSHServer) ListenAndServe() error {
262 if s.Config.PromListenAddr != "" {
263 s.SessionsCreated = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
264 Name: "pssh_sessions_created_total",
265 Help: "The total number of sessions created",
266 ConstLabels: prometheus.Labels{
267 "app": s.Config.App,
268 },
269 }, []string{"command"})
270
271 s.SessionsFinished = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
272 Name: "pssh_sessions_finished_total",
273 Help: "The total number of sessions created",
274 ConstLabels: prometheus.Labels{
275 "app": s.Config.App,
276 },
277 }, []string{"command"})
278
279 s.SessionsDuration = promauto.With(prometheus.DefaultRegisterer).NewCounterVec(prometheus.CounterOpts{
280 Name: "pssh_sessions_duration_seconds",
281 Help: "The total sessions duration in seconds",
282 ConstLabels: prometheus.Labels{
283 "app": s.Config.App,
284 },
285 }, []string{"command"})
286
287 go func() {
288 mux := http.NewServeMux()
289 mux.Handle("/metrics", promhttp.Handler())
290
291 srv := &http.Server{Addr: s.Config.PromListenAddr, Handler: mux}
292
293 go func() {
294 <-s.Ctx.Done()
295 s.Logger.Info("Prometheus server shutting down")
296 srv.Close()
297 }()
298
299 s.Logger.Info("Starting Prometheus server", "addr", s.Config.PromListenAddr)
300
301 err := srv.ListenAndServe()
302 if err != nil {
303 if errors.Is(err, http.ErrServerClosed) {
304 s.Logger.Info("Prometheus server shut down")
305 return
306 }
307
308 s.Logger.Error("Prometheus serve error", "err", err)
309 panic(err)
310 }
311 }()
312 }
313
314 listen, err := net.Listen("tcp", s.Config.ListenAddr)
315 if err != nil {
316 return err
317 }
318
319 s.Listener = listen
320 defer s.Listener.Close()
321
322 go func() {
323 <-s.Ctx.Done()
324 s.Close()
325 }()
326
327 var retErr error
328
329 for {
330 conn, err := s.Listener.Accept()
331 if err != nil {
332 s.Logger.Error("accept", "err", err)
333 if errors.Is(err, net.ErrClosed) {
334 retErr = err
335 break
336 }
337 continue
338 }
339
340 go func() {
341 if err := s.HandleConn(conn); err != nil && !errors.Is(err, io.EOF) {
342 s.Logger.Error("Error handling connection", "err", err, "remoteAddr", conn.RemoteAddr().String())
343 }
344 }()
345 }
346
347 if errors.Is(retErr, net.ErrClosed) {
348 return nil
349 }
350
351 return retErr
352}
353
354func (s *SSHServer) HandleConn(conn net.Conn) error {
355 defer conn.Close()
356
357 sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config.ServerConfig)
358 if err != nil {
359 return err
360 }
361
362 newLogger := s.Logger.With(
363 "remoteAddr", conn.RemoteAddr().String(),
364 "user", sshConn.User(),
365 "pubkey", sshConn.Permissions.Extensions["pubkey"],
366 )
367
368 newConn := NewSSHServerConn(
369 s.Ctx,
370 newLogger,
371 sshConn,
372 s,
373 )
374
375 s.Conns.Store(sshConn.RemoteAddr().String(), newConn)
376
377 err = newConn.Handle(chans, reqs)
378
379 s.Conns.Delete(sshConn.RemoteAddr().String())
380
381 return err
382}
383
384func (s *SSHServer) Close() error {
385 s.CancelFunc()
386 return s.Listener.Close()
387}
388
389func NewSSHServer(ctx context.Context, logger *slog.Logger, config *SSHServerConfig) *SSHServer {
390 if ctx == nil {
391 ctx = context.Background()
392 }
393
394 cancelCtx, cancelFunc := context.WithCancel(ctx)
395
396 if logger == nil {
397 logger = slog.Default()
398 }
399
400 if config == nil {
401 config = &SSHServerConfig{}
402 }
403
404 if config.ChannelMiddleware == nil {
405 config.ChannelMiddleware = map[string]SSHServerChannelMiddleware{}
406 }
407
408 if _, ok := config.ChannelMiddleware["session"]; !ok {
409 config.ChannelMiddleware["session"] = func(newChan ssh.NewChannel, sc *SSHServerConn) error {
410 channel, requests, err := newChan.Accept()
411 if err != nil {
412 sc.Logger.Error("accept session channel", "err", err)
413 return err
414 }
415
416 ctx, cancelFunc := context.WithCancel(sc.Ctx)
417
418 sesh := &SSHServerConnSession{
419 Channel: channel,
420 SSHServerConn: sc,
421 Ctx: ctx,
422 CancelFunc: cancelFunc,
423 }
424
425 for {
426 select {
427 case <-sesh.Done():
428 return nil
429 case req, ok := <-requests:
430 if !ok {
431 return nil
432 }
433
434 go func() {
435 sc.Logger.Info("new session request", "type", req.Type, "wantReply", req.WantReply, "payload", req.Payload)
436 switch req.Type {
437 case "subsystem":
438 if len(sc.SSHServer.Config.SubsystemMiddleware) == 0 {
439 err := req.Reply(false, nil)
440 if err != nil {
441 sc.Logger.Error("subsystem reply", "err", err)
442 }
443
444 err = sc.Close()
445 if err != nil {
446 sc.Logger.Error("subsystem close", "err", err)
447 }
448
449 sesh.Fatal(err)
450 return
451 }
452
453 h := func(*SSHServerConnSession) error { return nil }
454 for _, m := range sc.SSHServer.Config.SubsystemMiddleware {
455 h = m(h)
456 }
457
458 err := req.Reply(true, nil)
459 if err != nil {
460 sc.Logger.Error("subsystem reply", "err", err)
461 sesh.Fatal(err)
462 return
463 }
464
465 if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
466 sc.Logger.Error("subsystem middleware", "err", err)
467 sesh.Fatal(err)
468 return
469 }
470
471 err = sesh.Exit(0)
472 if err != nil {
473 sc.Logger.Error("subsystem exit", "err", err)
474 }
475
476 err = sesh.Close()
477 if err != nil {
478 sc.Logger.Error("subsystem close", "err", err)
479 }
480 case "shell", "exec":
481 if len(sc.SSHServer.Config.Middleware) == 0 {
482 err := req.Reply(false, nil)
483 if err != nil {
484 sc.Logger.Error("shell/exec reply", "err", err)
485 }
486 sesh.Fatal(err)
487 return
488 }
489
490 command := "shell"
491
492 if len(req.Payload) > 0 {
493 var payload = struct{ Value string }{}
494 err := ssh.Unmarshal(req.Payload, &payload)
495 if err != nil {
496 sc.Logger.Error("shell/exec unmarshal", "err", err)
497 sesh.Fatal(err)
498 return
499 }
500
501 command = payload.Value
502
503 sesh.SetValue("command", strings.Fields(payload.Value))
504 }
505
506 if !utf8.ValidString(command) {
507 command = base64.StdEncoding.EncodeToString([]byte(command))
508 }
509
510 if sc.SSHServer.Config.PromListenAddr != "" {
511 sc.SSHServer.SessionsCreated.WithLabelValues(command).Inc()
512 defer func() {
513 sc.SSHServer.SessionsFinished.WithLabelValues(command).Inc()
514 sc.SSHServer.SessionsDuration.WithLabelValues(command).Add(time.Since(sc.Start).Seconds())
515 }()
516 }
517
518 h := func(*SSHServerConnSession) error { return nil }
519 for _, m := range sc.SSHServer.Config.Middleware {
520 h = m(h)
521 }
522
523 err = req.Reply(true, nil)
524 if err != nil {
525 sc.Logger.Error("shell/exec reply", "err", err)
526 sesh.Fatal(err)
527 return
528 }
529
530 if err := h(sesh); err != nil && !errors.Is(err, io.EOF) {
531 sc.Logger.Error("exec middleware", "err", err)
532 sesh.Fatal(err)
533 return
534 }
535
536 err = sesh.Exit(0)
537 if err != nil {
538 sc.Logger.Error("subsystem exit", "err", err)
539 }
540
541 err = sesh.Close()
542 if err != nil {
543 sc.Logger.Error("subsystem close", "err", err)
544 }
545 case "pty-req":
546 sesh.mu.Lock()
547 found := sesh.pty != nil
548 sesh.mu.Unlock()
549 if found {
550 err := req.Reply(false, nil)
551 if err != nil {
552 sc.Logger.Error("pty-req reply", "err", err)
553 }
554 return
555 }
556
557 ptyReq, ok := parsePtyRequest(req.Payload)
558 if !ok {
559 err := req.Reply(false, nil)
560 if err != nil {
561 sc.Logger.Error("pty-req reply", "err", err)
562 }
563 return
564 }
565
566 sesh.mu.Lock()
567 sesh.pty = &ptyReq
568 sesh.winch = make(chan Window, 1)
569 sesh.mu.Unlock()
570
571 sesh.winch <- ptyReq.Window
572 err := req.Reply(ok, nil)
573 if err != nil {
574 sc.Logger.Error("pty-req reply", "err", err)
575 }
576 case "window-change":
577 sesh.mu.Lock()
578 found := sesh.pty != nil
579 sesh.mu.Unlock()
580
581 if !found {
582 err := req.Reply(false, nil)
583 if err != nil {
584 sc.Logger.Error("pty-req reply", "err", err)
585 }
586 return
587 }
588
589 win, ok := parseWinchRequest(req.Payload)
590 if ok {
591 sesh.mu.Lock()
592 sesh.pty.Window = win
593 sesh.winch <- win
594 sesh.mu.Unlock()
595 }
596
597 err := req.Reply(ok, nil)
598 if err != nil {
599 sc.Logger.Error("window-change reply", "err", err)
600 }
601 }
602 }()
603 }
604 }
605 }
606 }
607
608 server := &SSHServer{
609 Ctx: cancelCtx,
610 CancelFunc: cancelFunc,
611 Logger: logger,
612 Config: config,
613 Conns: syncmap.New[string, *SSHServerConn](),
614 }
615
616 return server
617}
618
619type PubKeyAuthHandler func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error)
620
621func NewSSHServerWithConfig(
622 ctx context.Context,
623 logger *slog.Logger,
624 app, host, port, promPort, hostKey string,
625 pubKeyAuthHandler PubKeyAuthHandler,
626 middleware, subsystemMiddleware []SSHServerMiddleware,
627 channelMiddleware map[string]SSHServerChannelMiddleware,
628) (*SSHServer, error) {
629 server := NewSSHServer(ctx, logger, &SSHServerConfig{
630 App: app,
631 ListenAddr: fmt.Sprintf("%s:%s", host, port),
632 ServerConfig: &ssh.ServerConfig{
633 PublicKeyCallback: pubKeyAuthHandler,
634 },
635 Middleware: middleware,
636 SubsystemMiddleware: subsystemMiddleware,
637 ChannelMiddleware: channelMiddleware,
638 })
639
640 if promPort != "" {
641 server.Config.PromListenAddr = fmt.Sprintf("%s:%s", host, promPort)
642 }
643
644 if hostKey != "" {
645 pemBytes, err := os.ReadFile(hostKey)
646 if err != nil {
647 logger.Error("failed to read private key file", "error", err)
648 if !os.IsNotExist(err) {
649 return nil, err
650 }
651
652 logger.Info("generating new private key")
653
654 pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
655 if err != nil {
656 logger.Error("failed to generate private key", "error", err)
657 return nil, err
658 }
659
660 privb, err := ssh.MarshalPrivateKey(privKey, "")
661 if err != nil {
662 logger.Error("failed to marshal private key", "error", err)
663 return nil, err
664 }
665
666 block := &pem.Block{
667 Type: "OPENSSH PRIVATE KEY",
668 Bytes: privb.Bytes,
669 }
670
671 if err = os.MkdirAll(path.Dir(hostKey), 0700); err != nil {
672 logger.Error("failed to create ssh_data directory", "error", err)
673 return nil, err
674 }
675
676 pemBytes = pem.EncodeToMemory(block)
677
678 if err = os.WriteFile(hostKey, pemBytes, 0600); err != nil {
679 logger.Error("failed to write private key", "error", err)
680 return nil, err
681 }
682
683 sshPubKey, err := ssh.NewPublicKey(pubKey)
684 if err != nil {
685 logger.Error("failed to create public key", "error", err)
686 return nil, err
687 }
688
689 pubb := ssh.MarshalAuthorizedKey(sshPubKey)
690 if err = os.WriteFile(fmt.Sprintf("%s.pub", hostKey), pubb, 0600); err != nil {
691 logger.Error("failed to write public key", "error", err)
692 return nil, err
693 }
694 }
695
696 signer, err := ssh.ParsePrivateKey(pemBytes)
697 if err != nil {
698 logger.Error("failed to parse private key", "error", err)
699 return nil, err
700 }
701
702 server.Config.AddHostKey(signer)
703 }
704
705 return server, nil
706}
707
708func KeysEqual(a, b ssh.PublicKey) bool {
709 if a == nil || b == nil {
710 return false
711 }
712
713 am := a.Marshal()
714 bm := b.Marshal()
715 return (len(am) == len(bm) && subtle.ConstantTimeCompare(am, bm) == 1)
716}