Eric Bower
·
2026-01-25
ssh.go
1package shared
2
3import (
4 "fmt"
5 "log/slog"
6 "strings"
7 "time"
8
9 "github.com/picosh/pico/pkg/db"
10 "golang.org/x/crypto/ssh"
11)
12
// adminPrefix marks an SSH login name as an impersonation request: an admin
// connecting as "admin__<name>" asks to act as the user <name>.
const adminPrefix = "admin__"
14
15type SshAuthHandler struct {
16 DB AuthFindUser
17 Logger *slog.Logger
18 Principal string
19}
20
21type AuthFindUser interface {
22 FindUserByPubkey(key string) (*db.User, error)
23 FindUserByName(name string) (*db.User, error)
24 FindFeature(userID, name string) (*db.FeatureFlag, error)
25 InsertAccessLog(log *db.AccessLog) error
26}
27
28func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger, principal string) *SshAuthHandler {
29 return &SshAuthHandler{
30 DB: dbh,
31 Logger: logger,
32 Principal: principal,
33 }
34}
35
// AuthedPubkey is the result of verifying a presented SSH key.
type AuthedPubkey struct {
	// OrigPubkey is the text form of the key exactly as the client presented
	// it (for certificates, the certificate itself).
	OrigPubkey string
	// Pubkey is the text form of the key used to look up the account: the
	// presented key, or a certificate's signing key.
	Pubkey string
	// Identity is a label for the credential: "pubkey" for plain keys, or a
	// certificate's KeyId.
	Identity string
}
41
42func PubkeyCertVerify(key ssh.PublicKey, srcPrincipal string) (*AuthedPubkey, error) {
43 origPubkey := KeyForKeyText(key)
44 authed := &AuthedPubkey{
45 OrigPubkey: origPubkey,
46 Pubkey: origPubkey,
47 Identity: "pubkey",
48 }
49
50 cert, ok := key.(*ssh.Certificate)
51 if ok {
52 if cert.CertType != ssh.UserCert {
53 return nil, fmt.Errorf("ssh-cert has type %d", cert.CertType)
54 }
55
56 found := false
57 for _, princ := range cert.ValidPrincipals {
58 if princ == "admin" || princ == srcPrincipal {
59 found = true
60 break
61 }
62 }
63 if !found {
64 return nil, fmt.Errorf("ssh-cert principals not valid")
65 }
66
67 clock := time.Now
68 unixNow := clock().Unix()
69 if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
70 return nil, fmt.Errorf("ssh-cert is not yet valid")
71 }
72 if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
73 return nil, fmt.Errorf("ssh-cert has expired")
74 }
75
76 authed.Pubkey = KeyForKeyText(cert.SignatureKey)
77 authed.Identity = cert.KeyId
78 return authed, nil
79 }
80
81 return authed, nil
82}
83
84func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
85 log := r.Logger
86 var user *db.User
87 var err error
88 authed, err := PubkeyCertVerify(key, r.Principal)
89 if err != nil {
90 return nil, err
91 }
92
93 user, err = r.DB.FindUserByPubkey(authed.Pubkey)
94 if err != nil {
95 log.Error(
96 "could not find user for key",
97 "keyType", key.Type(),
98 "key", string(key.Marshal()),
99 "err", err,
100 )
101 return nil, err
102 }
103
104 if user.Name == "" {
105 log.Error("username is not set")
106 return nil, fmt.Errorf("username is not set")
107 }
108
109 if authed.Identity == "public" && user.PublicKey != nil && user.PublicKey.Name != "" {
110 authed.Identity = user.PublicKey.Name
111 }
112
113 log.Info("inserting access log", "principal", r.Principal, "identity", authed.Identity)
114 err = r.DB.InsertAccessLog(&db.AccessLog{
115 UserID: user.ID,
116 Service: r.Principal,
117 Identity: authed.Identity,
118 Pubkey: authed.OrigPubkey,
119 })
120 if err != nil {
121 log.Error("cannot insert access log", "err", err)
122 }
123
124 // impersonation
125 var impID string
126 usr := conn.User()
127 if strings.HasPrefix(usr, adminPrefix) {
128 ff, err := r.DB.FindFeature(user.ID, "admin")
129 if err == nil && ff.IsValid() {
130 impersonate := strings.TrimPrefix(usr, adminPrefix)
131 impersonatedUser, err := r.DB.FindUserByName(impersonate)
132 if err == nil {
133 impID = user.ID
134 user = impersonatedUser
135 }
136 }
137 }
138
139 perms := &ssh.Permissions{
140 Extensions: map[string]string{
141 "user_id": user.ID,
142 "pubkey": authed.Pubkey,
143 "identity": authed.Identity,
144 },
145 }
146
147 if impID != "" {
148 perms.Extensions["imp_id"] = impID
149 }
150
151 return perms, nil
152}
153
154func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag {
155 ff, _ := dbpool.FindFeature(userID, "plus")
156 // we have free tiers so users might not have a feature flag
157 // in which case we set sane defaults
158 if ff == nil {
159 ff = db.NewFeatureFlag(
160 userID,
161 "plus",
162 cfg.MaxSize,
163 cfg.MaxAssetSize,
164 cfg.MaxSpecialFileSize,
165 )
166 }
167 // this is jank
168 ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize)
169 ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize)
170 ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize)
171 return ff
172}