repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / shared
Eric Bower  ·  2026-01-25

ssh.go

  1package shared
  2
  3import (
  4	"fmt"
  5	"log/slog"
  6	"strings"
  7	"time"
  8
  9	"github.com/picosh/pico/pkg/db"
 10	"golang.org/x/crypto/ssh"
 11)
 12
// adminPrefix is the SSH username prefix that requests impersonation: when a
// client connects as "admin__<name>", PubkeyAuthHandler strips the prefix and
// treats the remainder as the name of the user to act as.
const adminPrefix = "admin__"
 14
// SshAuthHandler holds the dependencies needed to authenticate SSH public-key
// logins for a single service.
type SshAuthHandler struct {
	// DB resolves users and feature flags and records access-log entries.
	DB        AuthFindUser
	// Logger receives authentication diagnostics.
	Logger    *slog.Logger
	// Principal is passed to PubkeyCertVerify as the certificate principal to
	// accept, and is recorded as the Service of each access-log entry.
	Principal string
}
 20
// AuthFindUser is the subset of database operations that SshAuthHandler
// depends on.
type AuthFindUser interface {
	// FindUserByPubkey is called with the authenticated key text to resolve
	// the connecting user.
	FindUserByPubkey(key string) (*db.User, error)
	// FindUserByName is called with the impersonation target's name.
	FindUserByName(name string) (*db.User, error)
	// FindFeature is called with a user ID and a feature name; the handler
	// uses it to look up the "admin" feature before allowing impersonation.
	FindFeature(userID, name string) (*db.FeatureFlag, error)
	// InsertAccessLog is called once per successful key lookup to record the
	// login.
	InsertAccessLog(log *db.AccessLog) error
}
 27
 28func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger, principal string) *SshAuthHandler {
 29	return &SshAuthHandler{
 30		DB:        dbh,
 31		Logger:    logger,
 32		Principal: principal,
 33	}
 34}
 35
// AuthedPubkey is the result of PubkeyCertVerify: the key material and
// identity label a connection is attributed to.
type AuthedPubkey struct {
	// OrigPubkey is the KeyForKeyText form of the key exactly as the client
	// presented it (the certificate itself when one was used).
	OrigPubkey string
	// Pubkey is the key used to look up the user: the presented key for plain
	// public keys, or the certificate's signing key for certificates.
	Pubkey     string
	// Identity is "pubkey" for plain public keys, or the certificate's key ID
	// for certificates.
	Identity   string
}
 41
 42func PubkeyCertVerify(key ssh.PublicKey, srcPrincipal string) (*AuthedPubkey, error) {
 43	origPubkey := KeyForKeyText(key)
 44	authed := &AuthedPubkey{
 45		OrigPubkey: origPubkey,
 46		Pubkey:     origPubkey,
 47		Identity:   "pubkey",
 48	}
 49
 50	cert, ok := key.(*ssh.Certificate)
 51	if ok {
 52		if cert.CertType != ssh.UserCert {
 53			return nil, fmt.Errorf("ssh-cert has type %d", cert.CertType)
 54		}
 55
 56		found := false
 57		for _, princ := range cert.ValidPrincipals {
 58			if princ == "admin" || princ == srcPrincipal {
 59				found = true
 60				break
 61			}
 62		}
 63		if !found {
 64			return nil, fmt.Errorf("ssh-cert principals not valid")
 65		}
 66
 67		clock := time.Now
 68		unixNow := clock().Unix()
 69		if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
 70			return nil, fmt.Errorf("ssh-cert is not yet valid")
 71		}
 72		if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
 73			return nil, fmt.Errorf("ssh-cert has expired")
 74		}
 75
 76		authed.Pubkey = KeyForKeyText(cert.SignatureKey)
 77		authed.Identity = cert.KeyId
 78		return authed, nil
 79	}
 80
 81	return authed, nil
 82}
 83
 84func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
 85	log := r.Logger
 86	var user *db.User
 87	var err error
 88	authed, err := PubkeyCertVerify(key, r.Principal)
 89	if err != nil {
 90		return nil, err
 91	}
 92
 93	user, err = r.DB.FindUserByPubkey(authed.Pubkey)
 94	if err != nil {
 95		log.Error(
 96			"could not find user for key",
 97			"keyType", key.Type(),
 98			"key", string(key.Marshal()),
 99			"err", err,
100		)
101		return nil, err
102	}
103
104	if user.Name == "" {
105		log.Error("username is not set")
106		return nil, fmt.Errorf("username is not set")
107	}
108
109	if authed.Identity == "public" && user.PublicKey != nil && user.PublicKey.Name != "" {
110		authed.Identity = user.PublicKey.Name
111	}
112
113	log.Info("inserting access log", "principal", r.Principal, "identity", authed.Identity)
114	err = r.DB.InsertAccessLog(&db.AccessLog{
115		UserID:   user.ID,
116		Service:  r.Principal,
117		Identity: authed.Identity,
118		Pubkey:   authed.OrigPubkey,
119	})
120	if err != nil {
121		log.Error("cannot insert access log", "err", err)
122	}
123
124	// impersonation
125	var impID string
126	usr := conn.User()
127	if strings.HasPrefix(usr, adminPrefix) {
128		ff, err := r.DB.FindFeature(user.ID, "admin")
129		if err == nil && ff.IsValid() {
130			impersonate := strings.TrimPrefix(usr, adminPrefix)
131			impersonatedUser, err := r.DB.FindUserByName(impersonate)
132			if err == nil {
133				impID = user.ID
134				user = impersonatedUser
135			}
136		}
137	}
138
139	perms := &ssh.Permissions{
140		Extensions: map[string]string{
141			"user_id":  user.ID,
142			"pubkey":   authed.Pubkey,
143			"identity": authed.Identity,
144		},
145	}
146
147	if impID != "" {
148		perms.Extensions["imp_id"] = impID
149	}
150
151	return perms, nil
152}
153
154func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag {
155	ff, _ := dbpool.FindFeature(userID, "plus")
156	// we have free tiers so users might not have a feature flag
157	// in which case we set sane defaults
158	if ff == nil {
159		ff = db.NewFeatureFlag(
160			userID,
161			"plus",
162			cfg.MaxSize,
163			cfg.MaxAssetSize,
164			cfg.MaxSpecialFileSize,
165		)
166	}
167	// this is jank
168	ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize)
169	ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize)
170	ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize)
171	return ff
172}