repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / pssh
Antonio Mika  ·  2025-04-10

logger.go

  1package pssh
  2
  3import (
  4	"log/slog"
  5	"time"
  6
  7	"github.com/picosh/pico/pkg/db"
  8)
  9
 10type ctxLoggerKey struct{}
 11type ctxUserKey struct{}
 12
 13type FindUserInterface interface {
 14	FindUser(string) (*db.User, error)
 15}
 16
 17type GetLoggerInterface interface {
 18	GetLogger(s *SSHServerConnSession) *slog.Logger
 19}
 20
 21func LogMiddleware(getLogger GetLoggerInterface, db FindUserInterface) SSHServerMiddleware {
 22	return func(sshHandler SSHServerHandler) SSHServerHandler {
 23		return func(s *SSHServerConnSession) error {
 24			ct := time.Now()
 25
 26			logger := GetLogger(s)
 27			if logger == slog.Default() {
 28				logger = getLogger.GetLogger(s)
 29
 30				user := GetUser(s)
 31				if user == nil {
 32					userID, ok := s.Permissions().Extensions["user_id"]
 33					if ok {
 34						user, err := db.FindUser(userID)
 35						if err == nil && user != nil {
 36							logger = logger.With(
 37								"user", user.Name,
 38								"userId", user.ID,
 39								"ip", s.RemoteAddr().String(),
 40							)
 41							s.SetValue(ctxUserKey{}, user)
 42						}
 43					} else {
 44						logger.Error("`user_id` not set in permissions")
 45					}
 46
 47				}
 48
 49				s.SetValue(ctxLoggerKey{}, logger)
 50			}
 51
 52			pty, _, ok := s.Pty()
 53
 54			width, height := 0, 0
 55			term := ""
 56			if pty != nil {
 57				term = pty.Term
 58				width = pty.Window.Width
 59				height = pty.Window.Height
 60			}
 61
 62			logger.Info(
 63				"connect",
 64				"sshUser", s.User(),
 65				"pty", ok,
 66				"term", term,
 67				"windowWidth", width,
 68				"windowHeight", height,
 69			)
 70
 71			err := sshHandler(s)
 72			if err != nil {
 73				logger.Error("error", "err", err)
 74			}
 75
 76			if pty != nil {
 77				term = pty.Term
 78				width = pty.Window.Width
 79				height = pty.Window.Height
 80			}
 81
 82			logger.Info(
 83				"disconnect",
 84				"sshUser", s.User(),
 85				"pty", ok,
 86				"term", term,
 87				"windowWidth", width,
 88				"windowHeight", height,
 89				"duration", time.Since(ct),
 90				"err", err,
 91			)
 92
 93			return err
 94		}
 95	}
 96}
 97
 98func GetLogger(s *SSHServerConnSession) *slog.Logger {
 99	logger := slog.Default()
100	if s == nil {
101		return logger
102	}
103
104	if v, ok := s.Context().Value(ctxLoggerKey{}).(*slog.Logger); ok {
105		return v
106	}
107
108	return logger
109}
110
111func GetUser(s *SSHServerConnSession) *db.User {
112	if v, ok := s.Context().Value(ctxUserKey{}).(*db.User); ok {
113		return v
114	}
115
116	return nil
117}