repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / pssh
Antonio Mika  ·  2025-04-30

logger.go

  1package pssh
  2
  3import (
  4	"log/slog"
  5	"time"
  6
  7	"github.com/picosh/pico/pkg/db"
  8)
  9
 10type ctxLoggerKey struct{}
 11type ctxUserKey struct{}
 12
 13type FindUserInterface interface {
 14	FindUser(string) (*db.User, error)
 15	FindUserByPubkey(string) (*db.User, error)
 16}
 17
 18type GetLoggerInterface interface {
 19	GetLogger(s *SSHServerConnSession) *slog.Logger
 20}
 21
 22func LogMiddleware(getLogger GetLoggerInterface, database FindUserInterface) SSHServerMiddleware {
 23	return func(sshHandler SSHServerHandler) SSHServerHandler {
 24		return func(s *SSHServerConnSession) error {
 25			ct := time.Now()
 26
 27			logger := GetLogger(s)
 28			if logger == slog.Default() || logger == s.Logger {
 29				logger = getLogger.GetLogger(s)
 30
 31				user := GetUser(s)
 32				if user == nil {
 33					_, impersonated := s.Permissions().Extensions["imp_id"]
 34
 35					var user *db.User
 36					var err error
 37					var found bool
 38
 39					if !impersonated {
 40						pubKey, ok := s.Permissions().Extensions["pubkey"]
 41						if ok {
 42							user, err = database.FindUserByPubkey(pubKey)
 43							found = true
 44						}
 45					} else {
 46						userID, ok := s.Permissions().Extensions["user_id"]
 47						if ok {
 48							user, err = database.FindUser(userID)
 49							found = true
 50						}
 51					}
 52
 53					if found {
 54						if err == nil && user != nil {
 55							logger = logger.With(
 56								"user", user.Name,
 57								"userId", user.ID,
 58								"ip", s.RemoteAddr().String(),
 59							)
 60
 61							SetUser(s, user)
 62						} else {
 63							logger.Error("`user` not set in permissions", "err", err)
 64						}
 65					}
 66				}
 67
 68				SetLogger(s, logger)
 69			}
 70
 71			pty, _, ok := s.Pty()
 72
 73			width, height := 0, 0
 74			term := ""
 75			if pty != nil {
 76				term = pty.Term
 77				width = pty.Window.Width
 78				height = pty.Window.Height
 79			}
 80
 81			logger.Info(
 82				"connect",
 83				"sshUser", s.User(),
 84				"pty", ok,
 85				"term", term,
 86				"windowWidth", width,
 87				"windowHeight", height,
 88			)
 89
 90			err := sshHandler(s)
 91			if err != nil {
 92				logger.Error("error", "err", err)
 93			}
 94
 95			if pty != nil {
 96				term = pty.Term
 97				width = pty.Window.Width
 98				height = pty.Window.Height
 99			}
100
101			logger.Info(
102				"disconnect",
103				"sshUser", s.User(),
104				"pty", ok,
105				"term", term,
106				"windowWidth", width,
107				"windowHeight", height,
108				"duration", time.Since(ct),
109				"err", err,
110			)
111
112			return err
113		}
114	}
115}
116
117func GetLogger(s *SSHServerConnSession) *slog.Logger {
118	logger := slog.Default()
119	if s == nil {
120		return logger
121	}
122
123	logger = s.Logger
124
125	if v, ok := s.Context().Value(ctxLoggerKey{}).(*slog.Logger); ok {
126		return v
127	}
128
129	return logger
130}
131
132func SetLogger(s *SSHServerConnSession, logger *slog.Logger) {
133	if s == nil {
134		return
135	}
136
137	s.SetValue(ctxLoggerKey{}, logger)
138}
139
140func GetUser(s *SSHServerConnSession) *db.User {
141	if v, ok := s.Context().Value(ctxUserKey{}).(*db.User); ok {
142		return v
143	}
144
145	return nil
146}
147
148func SetUser(s *SSHServerConnSession, user *db.User) {
149	if s == nil {
150		return
151	}
152
153	s.SetValue(ctxUserKey{}, user)
154}