- commit
- 461e887
- parent
- 9d5a8dd
- author
- Eric Bower
- date
- 2025-03-28 11:53:44 -0400 EDT
feat(pgs): admins can impersonate This change will let admins impersonate any user for the pgs cli
14 files changed,
+64,
-52
+3,
-3
1@@ -303,7 +303,7 @@ func userHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
2 "publicKey", data.PublicKey,
3 )
4
5- user, err := apiConfig.Dbpool.FindUserForName(data.Username)
6+ user, err := apiConfig.Dbpool.FindUserByName(data.Username)
7 if err != nil {
8 apiConfig.Cfg.Logger.Error(err.Error())
9 http.Error(w, err.Error(), http.StatusNotFound)
10@@ -461,7 +461,7 @@ func paymentWebhookHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
11 status := event.Data.Attr.Status
12 txID := fmt.Sprint(event.Data.Attr.OrderNumber)
13
14- user, err := apiConfig.Dbpool.FindUserForName(username)
15+ user, err := apiConfig.Dbpool.FindUserByName(username)
16 if err != nil {
17 logger.Error("no user found with username", "username", username)
18 w.WriteHeader(http.StatusOK)
19@@ -624,7 +624,7 @@ func deserializeCaddyAccessLog(dbpool db.DB, access *AccessLog) (*db.AnalyticsVi
20 }
21
22 // get user ID
23- user, err := dbpool.FindUserForName(props.Username)
24+ user, err := dbpool.FindUserByName(props.Username)
25 if err != nil {
26 return nil, fmt.Errorf("could not find user for name %s: %w", props.Username, err)
27 }
+2,
-2
1@@ -220,7 +220,7 @@ func (a *AuthDb) AddPicoPlusUser(username, email, from, txid string) error {
2 return nil
3 }
4
5-func (a *AuthDb) FindUserForName(username string) (*db.User, error) {
6+func (a *AuthDb) FindUserByName(username string) (*db.User, error) {
7 return &db.User{ID: testUserID, Name: username}, nil
8 }
9
10@@ -243,7 +243,7 @@ func (a *AuthDb) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
11 return []*db.PublicKey{{ID: "1", UserID: user.ID, Name: "my-key", Key: "nice-pubkey", CreatedAt: &time.Time{}}}, nil
12 }
13
14-func (a *AuthDb) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
15+func (a *AuthDb) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
16 now := time.Date(2021, 8, 15, 14, 30, 45, 100, time.UTC)
17 oneDayWarning := now.AddDate(0, 0, 1)
18 return &db.FeatureFlag{ID: "2", UserID: userID, Name: "plus", ExpiresAt: &oneDayWarning, CreatedAt: &now}, nil
+3,
-3
1@@ -78,7 +78,7 @@ func blogHandler(w http.ResponseWriter, r *http.Request) {
2 logger := blogger.With("user", username)
3 cfg := shared.GetCfg(r)
4
5- user, err := dbpool.FindUserForName(username)
6+ user, err := dbpool.FindUserByName(username)
7 if err != nil {
8 logger.Info("user not found")
9 http.Error(w, "user not found", http.StatusNotFound)
10@@ -170,7 +170,7 @@ func postHandler(w http.ResponseWriter, r *http.Request) {
11 blogger := shared.GetLogger(r)
12 logger := blogger.With("slug", slug, "user", username)
13
14- user, err := dbpool.FindUserForName(username)
15+ user, err := dbpool.FindUserByName(username)
16 if err != nil {
17 logger.Info("paste not found")
18 http.Error(w, "paste not found", http.StatusNotFound)
19@@ -271,7 +271,7 @@ func postHandlerRaw(w http.ResponseWriter, r *http.Request) {
20 blogger := shared.GetLogger(r)
21 logger := blogger.With("user", username, "slug", slug)
22
23- user, err := dbpool.FindUserForName(username)
24+ user, err := dbpool.FindUserByName(username)
25 if err != nil {
26 logger.Info("user not found")
27 http.Error(w, "user not found", http.StatusNotFound)
+4,
-16
1@@ -10,26 +10,14 @@ import (
2 "github.com/picosh/pico/pkg/db"
3 "github.com/picosh/pico/pkg/pssh"
4 sendutils "github.com/picosh/pico/pkg/send/utils"
5- "github.com/picosh/utils"
6 )
7
8 func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
9- if s.PublicKey() == nil {
10- return nil, fmt.Errorf("key not found")
11+ userID, ok := s.Conn.Permissions.Extensions["user_id"]
12+ if !ok {
13+ return nil, fmt.Errorf("`user_id` extension not found")
14 }
15-
16- key := utils.KeyForKeyText(s.PublicKey())
17-
18- user, err := dbpool.FindUserByPubkey(key)
19- if err != nil {
20- return nil, err
21- }
22-
23- if user.Name == "" {
24- return nil, fmt.Errorf("must have username set")
25- }
26-
27- return user, nil
28+ return dbpool.FindUser(userID)
29 }
30
31 type arrayFlags []string
+2,
-2
1@@ -117,10 +117,10 @@ func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
2 return err
3 }
4
5- ff, err := dbpool.FindFeatureForUser(user.ID, "plus")
6+ ff, err := dbpool.FindFeature(user.ID, "plus")
7 if err != nil {
8 handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
9- ff, err = dbpool.FindFeatureForUser(user.ID, "bouncer")
10+ ff, err = dbpool.FindFeature(user.ID, "bouncer")
11 if err != nil {
12 handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
13 sesh.Fatal(err)
+6,
-6
1@@ -125,7 +125,7 @@ func blogStyleHandler(w http.ResponseWriter, r *http.Request) {
2 logger := shared.GetLogger(r)
3 cfg := shared.GetCfg(r)
4
5- user, err := dbpool.FindUserForName(username)
6+ user, err := dbpool.FindUserByName(username)
7 if err != nil {
8 logger.Info("blog not found", "user", username)
9 http.Error(w, "blog not found", http.StatusNotFound)
10@@ -155,7 +155,7 @@ func blogHandler(w http.ResponseWriter, r *http.Request) {
11 logger := shared.GetLogger(r)
12 cfg := shared.GetCfg(r)
13
14- user, err := dbpool.FindUserForName(username)
15+ user, err := dbpool.FindUserByName(username)
16 if err != nil {
17 logger.Info("blog not found", "user", username)
18 http.Error(w, "blog not found", http.StatusNotFound)
19@@ -301,7 +301,7 @@ func postRawHandler(w http.ResponseWriter, r *http.Request) {
20 logger := shared.GetLogger(r)
21 logger = logger.With("slug", slug)
22
23- user, err := dbpool.FindUserForName(username)
24+ user, err := dbpool.FindUserByName(username)
25 if err != nil {
26 logger.Info("blog not found", "user", username)
27 http.Error(w, "blog not found", http.StatusNotFound)
28@@ -341,7 +341,7 @@ func postHandler(w http.ResponseWriter, r *http.Request) {
29 dbpool := shared.GetDB(r)
30 logger := shared.GetLogger(r)
31
32- user, err := dbpool.FindUserForName(username)
33+ user, err := dbpool.FindUserByName(username)
34 if err != nil {
35 logger.Info("blog not found", "user", username)
36 http.Error(w, "blog not found", http.StatusNotFound)
37@@ -589,7 +589,7 @@ func rssBlogHandler(w http.ResponseWriter, r *http.Request) {
38 logger := shared.GetLogger(r)
39 cfg := shared.GetCfg(r)
40
41- user, err := dbpool.FindUserForName(username)
42+ user, err := dbpool.FindUserByName(username)
43 if err != nil {
44 logger.Info("rss feed not found", "user", username)
45 http.Error(w, "rss feed not found", http.StatusNotFound)
46@@ -852,7 +852,7 @@ func imgRequest(w http.ResponseWriter, r *http.Request) {
47 logger := shared.GetLogger(r)
48 dbpool := shared.GetDB(r)
49 username := shared.GetUsernameFromRequest(r)
50- user, err := dbpool.FindUserForName(username)
51+ user, err := dbpool.FindUserByName(username)
52 if err != nil {
53 logger.Error("could not find user", "username", username)
54 http.Error(w, "could find user", http.StatusNotFound)
+3,
-3
1@@ -212,7 +212,7 @@ type Token struct {
2 type FeatureFlag struct {
3 ID string `json:"id" db:"id"`
4 UserID string `json:"user_id" db:"user_id"`
5- PaymentHistoryID string `json:"payment_history_id" db:"payment_history_id"`
6+ PaymentHistoryID sql.NullString `json:"payment_history_id" db:"payment_history_id"`
7 Name string `json:"name" db:"name"`
8 CreatedAt *time.Time `json:"created_at" db:"created_at"`
9 ExpiresAt *time.Time `json:"expires_at" db:"expires_at"`
10@@ -370,7 +370,7 @@ type DB interface {
11 RemoveKeys(pubkeyIDs []string) error
12
13 FindUsers() ([]*User, error)
14- FindUserForName(name string) (*User, error)
15+ FindUserByName(name string) (*User, error)
16 FindUserForNameAndKey(name string, pubkey string) (*User, error)
17 FindUserForKey(name string, pubkey string) (*User, error)
18 FindUserByPubkey(pubkey string) (*User, error)
19@@ -414,7 +414,7 @@ type DB interface {
20 FindVisitSiteList(opts *SummaryOpts) ([]*VisitUrl, error)
21
22 AddPicoPlusUser(username, email, paymentType, txId string) error
23- FindFeatureForUser(userID string, feature string) (*FeatureFlag, error)
24+ FindFeature(userID string, feature string) (*FeatureFlag, error)
25 FindFeaturesForUser(userID string) ([]*FeatureFlag, error)
26 HasFeatureForUser(userID string, feature string) bool
27 FindTotalSizeForUser(userID string) (int, error)
+9,
-9
1@@ -617,14 +617,14 @@ func (me *PsqlDB) ValidateName(name string) (bool, error) {
2 if !v {
3 return false, fmt.Errorf("%s is invalid: %w", lower, db.ErrNameInvalid)
4 }
5- user, _ := me.FindUserForName(lower)
6+ user, _ := me.FindUserByName(lower)
7 if user == nil {
8 return true, nil
9 }
10 return false, fmt.Errorf("%s already taken: %w", lower, db.ErrNameTaken)
11 }
12
13-func (me *PsqlDB) FindUserForName(name string) (*db.User, error) {
14+func (me *PsqlDB) FindUserByName(name string) (*db.User, error) {
15 user := &db.User{}
16 r := me.Db.QueryRow(sqlSelectUserForName, strings.ToLower(name))
17 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
18@@ -1457,7 +1457,7 @@ func (me *PsqlDB) FindTagsForPost(postID string) ([]string, error) {
19 return tags, nil
20 }
21
22-func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
23+func (me *PsqlDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
24 ff := &db.FeatureFlag{}
25 // payment history is allowed to be null
26 // https://devtidbits.com/2020/08/03/go-sql-error-converting-null-to-string-is-unsupported/
27@@ -1475,7 +1475,7 @@ func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.Feature
28 return nil, err
29 }
30
31- ff.PaymentHistoryID = paymentHistoryID.String
32+ ff.PaymentHistoryID = paymentHistoryID
33
34 return ff, nil
35 }
36@@ -1507,7 +1507,7 @@ func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error)
37 if err != nil {
38 return features, err
39 }
40- ff.PaymentHistoryID = paymentHistoryID.String
41+ ff.PaymentHistoryID = paymentHistoryID
42
43 features = append(features, ff)
44 }
45@@ -1518,7 +1518,7 @@ func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error)
46 }
47
48 func (me *PsqlDB) HasFeatureForUser(userID string, feature string) bool {
49- ff, err := me.FindFeatureForUser(userID, feature)
50+ ff, err := me.FindFeature(userID, feature)
51 if err != nil {
52 return false
53 }
54@@ -1695,7 +1695,7 @@ func (me *PsqlDB) InsertFeature(userID, name string, expiresAt time.Time) (*db.F
55 return nil, err
56 }
57
58- feature, err := me.FindFeatureForUser(userID, name)
59+ feature, err := me.FindFeature(userID, name)
60 if err != nil {
61 return nil, err
62 }
63@@ -1709,7 +1709,7 @@ func (me *PsqlDB) RemoveFeature(userID string, name string) error {
64 }
65
66 func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
67- ff, _ := me.FindFeatureForUser(userID, name)
68+ ff, _ := me.FindFeature(userID, name)
69 if ff == nil {
70 t := time.Now()
71 return t.AddDate(1, 0, 0)
72@@ -1718,7 +1718,7 @@ func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
73 }
74
75 func (me *PsqlDB) AddPicoPlusUser(username, email, paymentType, txId string) error {
76- user, err := me.FindUserForName(username)
77+ user, err := me.FindUserByName(username)
78 if err != nil {
79 return err
80 }
+2,
-2
1@@ -77,7 +77,7 @@ func (me *StubDB) ValidateName(name string) (bool, error) {
2 return false, notImpl
3 }
4
5-func (me *StubDB) FindUserForName(name string) (*db.User, error) {
6+func (me *StubDB) FindUserByName(name string) (*db.User, error) {
7 return nil, notImpl
8 }
9
10@@ -189,7 +189,7 @@ func (me *StubDB) FindTagsForPost(postID string) ([]string, error) {
11 return []string{}, notImpl
12 }
13
14-func (me *StubDB) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
15+func (me *StubDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
16 return nil, notImpl
17 }
18
1@@ -76,7 +76,7 @@ func CheckHandler(w http.ResponseWriter, r *http.Request) {
2 if !strings.Contains(hostDomain, appDomain) {
3 subdomain := GetCustomDomain(hostDomain, cfg.Space)
4 if subdomain != "" {
5- u, err := dbpool.FindUserForName(subdomain)
6+ u, err := dbpool.FindUserByName(subdomain)
7 if u != nil && err == nil {
8 w.WriteHeader(http.StatusOK)
9 return
1@@ -90,7 +90,7 @@ func UserFeed(me db.DB, user *db.User, token string) (*feeds.Feed, error) {
2 var feedItems []*feeds.Item
3
4 now := time.Now()
5- ff, err := me.FindFeatureForUser(user.ID, "plus")
6+ ff, err := me.FindFeature(user.ID, "plus")
7 if err != nil {
8 // still want to send an empty feed
9 } else {
1@@ -3,6 +3,7 @@ package shared
2 import (
3 "fmt"
4 "log/slog"
5+ "strings"
6
7 "github.com/picosh/pico/pkg/db"
8 "github.com/picosh/utils"
9@@ -16,6 +17,8 @@ type SshAuthHandler struct {
10
11 type AuthFindUser interface {
12 FindUserByPubkey(key string) (*db.User, error)
13+ FindUserByName(name string) (*db.User, error)
14+ FindFeature(userID, name string) (*db.FeatureFlag, error)
15 }
16
17 func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger) *SshAuthHandler {
18@@ -43,8 +46,29 @@ func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.Public
19 return nil, fmt.Errorf("username is not set")
20 }
21
22+ // impersonation
23+ impID := user.ID
24+ adminPrefix := "admin__"
25+ usr := conn.User()
26+ if strings.HasPrefix(usr, adminPrefix) {
27+ ff, err := r.DB.FindFeature(user.ID, "admin")
28+ if err != nil {
29+ return nil, fmt.Errorf("only admins can impersonate a user: %w", err)
30+ }
31+ if !ff.IsValid() {
32+ return nil, fmt.Errorf("expired admin feature flag, cannot impersonate a user")
33+ }
34+
35+ impersonate := strings.Replace(usr, adminPrefix, "", 1)
36+ user, err = r.DB.FindUserByName(impersonate)
37+ if err != nil {
38+ return nil, err
39+ }
40+ }
41+
42 return &ssh.Permissions{
43 Extensions: map[string]string{
44+ "imp_id": impID,
45 "user_id": user.ID,
46 "pubkey": pubkey,
47 },
48@@ -52,7 +76,7 @@ func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.Public
49 }
50
51 func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag {
52- ff, _ := dbpool.FindFeatureForUser(userID, "plus")
53+ ff, _ := dbpool.FindFeature(userID, "plus")
54 // we have free tiers so users might not have a feature flag
55 // in which case we set sane defaults
56 if ff == nil {
+1,
-1
1@@ -230,7 +230,7 @@ func (m *TunsPage) HandleEvent(ev vaxis.Event, ph vxfw.EventPhase) (vxfw.Command
2 switch msg := ev.(type) {
3 case PageIn:
4 m.loading = true
5- ff, _ := m.shared.Dbpool.FindFeatureForUser(m.shared.User.ID, "admin")
6+ ff, _ := m.shared.Dbpool.FindFeature(m.shared.User.ID, "admin")
7 if ff != nil {
8 m.isAdmin = true
9 }
+2,
-2
1@@ -296,7 +296,7 @@ func FindUser(shrd *SharedModel) (*db.User, error) {
2 return nil, fmt.Errorf("only admins can impersonate a user")
3 }
4 impersonate := strings.Replace(usr, adminPrefix, "", 1)
5- user, err = shrd.Dbpool.FindUserForName(impersonate)
6+ user, err = shrd.Dbpool.FindUserByName(impersonate)
7 if err != nil {
8 return nil, err
9 }
10@@ -311,7 +311,7 @@ func FindFeatureFlag(shrd *SharedModel, name string) (*db.FeatureFlag, error) {
11 return nil, nil
12 }
13
14- ff, err := shrd.Dbpool.FindFeatureForUser(shrd.User.ID, name)
15+ ff, err := shrd.Dbpool.FindFeature(shrd.User.ID, name)
16 if err != nil {
17 return nil, err
18 }