repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

commit
461e887
parent
9d5a8dd
author
Eric Bower
date
2025-03-28 11:53:44 -0400 EDT
feat(pgs): admins can impersonate

This change will let admins impersonate any user for the pgs cli
14 files changed,  +64, -52
M pkg/apps/auth/api.go
+3, -3
 1@@ -303,7 +303,7 @@ func userHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
 2 			"publicKey", data.PublicKey,
 3 		)
 4 
 5-		user, err := apiConfig.Dbpool.FindUserForName(data.Username)
 6+		user, err := apiConfig.Dbpool.FindUserByName(data.Username)
 7 		if err != nil {
 8 			apiConfig.Cfg.Logger.Error(err.Error())
 9 			http.Error(w, err.Error(), http.StatusNotFound)
10@@ -461,7 +461,7 @@ func paymentWebhookHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
11 		status := event.Data.Attr.Status
12 		txID := fmt.Sprint(event.Data.Attr.OrderNumber)
13 
14-		user, err := apiConfig.Dbpool.FindUserForName(username)
15+		user, err := apiConfig.Dbpool.FindUserByName(username)
16 		if err != nil {
17 			logger.Error("no user found with username", "username", username)
18 			w.WriteHeader(http.StatusOK)
19@@ -624,7 +624,7 @@ func deserializeCaddyAccessLog(dbpool db.DB, access *AccessLog) (*db.AnalyticsVi
20 	}
21 
22 	// get user ID
23-	user, err := dbpool.FindUserForName(props.Username)
24+	user, err := dbpool.FindUserByName(props.Username)
25 	if err != nil {
26 		return nil, fmt.Errorf("could not find user for name %s: %w", props.Username, err)
27 	}
M pkg/apps/auth/api_test.go
+2, -2
 1@@ -220,7 +220,7 @@ func (a *AuthDb) AddPicoPlusUser(username, email, from, txid string) error {
 2 	return nil
 3 }
 4 
 5-func (a *AuthDb) FindUserForName(username string) (*db.User, error) {
 6+func (a *AuthDb) FindUserByName(username string) (*db.User, error) {
 7 	return &db.User{ID: testUserID, Name: username}, nil
 8 }
 9 
10@@ -243,7 +243,7 @@ func (a *AuthDb) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
11 	return []*db.PublicKey{{ID: "1", UserID: user.ID, Name: "my-key", Key: "nice-pubkey", CreatedAt: &time.Time{}}}, nil
12 }
13 
14-func (a *AuthDb) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
15+func (a *AuthDb) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
16 	now := time.Date(2021, 8, 15, 14, 30, 45, 100, time.UTC)
17 	oneDayWarning := now.AddDate(0, 0, 1)
18 	return &db.FeatureFlag{ID: "2", UserID: userID, Name: "plus", ExpiresAt: &oneDayWarning, CreatedAt: &now}, nil
M pkg/apps/pastes/api.go
+3, -3
 1@@ -78,7 +78,7 @@ func blogHandler(w http.ResponseWriter, r *http.Request) {
 2 	logger := blogger.With("user", username)
 3 	cfg := shared.GetCfg(r)
 4 
 5-	user, err := dbpool.FindUserForName(username)
 6+	user, err := dbpool.FindUserByName(username)
 7 	if err != nil {
 8 		logger.Info("user not found")
 9 		http.Error(w, "user not found", http.StatusNotFound)
10@@ -170,7 +170,7 @@ func postHandler(w http.ResponseWriter, r *http.Request) {
11 	blogger := shared.GetLogger(r)
12 	logger := blogger.With("slug", slug, "user", username)
13 
14-	user, err := dbpool.FindUserForName(username)
15+	user, err := dbpool.FindUserByName(username)
16 	if err != nil {
17 		logger.Info("paste not found")
18 		http.Error(w, "paste not found", http.StatusNotFound)
19@@ -271,7 +271,7 @@ func postHandlerRaw(w http.ResponseWriter, r *http.Request) {
20 	blogger := shared.GetLogger(r)
21 	logger := blogger.With("user", username, "slug", slug)
22 
23-	user, err := dbpool.FindUserForName(username)
24+	user, err := dbpool.FindUserByName(username)
25 	if err != nil {
26 		logger.Info("user not found")
27 		http.Error(w, "user not found", http.StatusNotFound)
M pkg/apps/pgs/cli_middleware.go
+4, -16
 1@@ -10,26 +10,14 @@ import (
 2 	"github.com/picosh/pico/pkg/db"
 3 	"github.com/picosh/pico/pkg/pssh"
 4 	sendutils "github.com/picosh/pico/pkg/send/utils"
 5-	"github.com/picosh/utils"
 6 )
 7 
 8 func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
 9-	if s.PublicKey() == nil {
10-		return nil, fmt.Errorf("key not found")
11+	userID, ok := s.Conn.Permissions.Extensions["user_id"]
12+	if !ok {
13+		return nil, fmt.Errorf("`user_id` extension not found")
14 	}
15-
16-	key := utils.KeyForKeyText(s.PublicKey())
17-
18-	user, err := dbpool.FindUserByPubkey(key)
19-	if err != nil {
20-		return nil, err
21-	}
22-
23-	if user.Name == "" {
24-		return nil, fmt.Errorf("must have username set")
25-	}
26-
27-	return user, nil
28+	return dbpool.FindUser(userID)
29 }
30 
31 type arrayFlags []string
M pkg/apps/pico/cli.go
+2, -2
 1@@ -117,10 +117,10 @@ func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
 2 					return err
 3 				}
 4 
 5-				ff, err := dbpool.FindFeatureForUser(user.ID, "plus")
 6+				ff, err := dbpool.FindFeature(user.ID, "plus")
 7 				if err != nil {
 8 					handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
 9-					ff, err = dbpool.FindFeatureForUser(user.ID, "bouncer")
10+					ff, err = dbpool.FindFeature(user.ID, "bouncer")
11 					if err != nil {
12 						handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
13 						sesh.Fatal(err)
M pkg/apps/prose/api.go
+6, -6
 1@@ -125,7 +125,7 @@ func blogStyleHandler(w http.ResponseWriter, r *http.Request) {
 2 	logger := shared.GetLogger(r)
 3 	cfg := shared.GetCfg(r)
 4 
 5-	user, err := dbpool.FindUserForName(username)
 6+	user, err := dbpool.FindUserByName(username)
 7 	if err != nil {
 8 		logger.Info("blog not found", "user", username)
 9 		http.Error(w, "blog not found", http.StatusNotFound)
10@@ -155,7 +155,7 @@ func blogHandler(w http.ResponseWriter, r *http.Request) {
11 	logger := shared.GetLogger(r)
12 	cfg := shared.GetCfg(r)
13 
14-	user, err := dbpool.FindUserForName(username)
15+	user, err := dbpool.FindUserByName(username)
16 	if err != nil {
17 		logger.Info("blog not found", "user", username)
18 		http.Error(w, "blog not found", http.StatusNotFound)
19@@ -301,7 +301,7 @@ func postRawHandler(w http.ResponseWriter, r *http.Request) {
20 	logger := shared.GetLogger(r)
21 	logger = logger.With("slug", slug)
22 
23-	user, err := dbpool.FindUserForName(username)
24+	user, err := dbpool.FindUserByName(username)
25 	if err != nil {
26 		logger.Info("blog not found", "user", username)
27 		http.Error(w, "blog not found", http.StatusNotFound)
28@@ -341,7 +341,7 @@ func postHandler(w http.ResponseWriter, r *http.Request) {
29 	dbpool := shared.GetDB(r)
30 	logger := shared.GetLogger(r)
31 
32-	user, err := dbpool.FindUserForName(username)
33+	user, err := dbpool.FindUserByName(username)
34 	if err != nil {
35 		logger.Info("blog not found", "user", username)
36 		http.Error(w, "blog not found", http.StatusNotFound)
37@@ -589,7 +589,7 @@ func rssBlogHandler(w http.ResponseWriter, r *http.Request) {
38 	logger := shared.GetLogger(r)
39 	cfg := shared.GetCfg(r)
40 
41-	user, err := dbpool.FindUserForName(username)
42+	user, err := dbpool.FindUserByName(username)
43 	if err != nil {
44 		logger.Info("rss feed not found", "user", username)
45 		http.Error(w, "rss feed not found", http.StatusNotFound)
46@@ -852,7 +852,7 @@ func imgRequest(w http.ResponseWriter, r *http.Request) {
47 	logger := shared.GetLogger(r)
48 	dbpool := shared.GetDB(r)
49 	username := shared.GetUsernameFromRequest(r)
50-	user, err := dbpool.FindUserForName(username)
51+	user, err := dbpool.FindUserByName(username)
52 	if err != nil {
53 		logger.Error("could not find user", "username", username)
54 		http.Error(w, "could find user", http.StatusNotFound)
M pkg/db/db.go
+3, -3
 1@@ -212,7 +212,7 @@ type Token struct {
 2 type FeatureFlag struct {
 3 	ID               string          `json:"id" db:"id"`
 4 	UserID           string          `json:"user_id" db:"user_id"`
 5-	PaymentHistoryID string          `json:"payment_history_id" db:"payment_history_id"`
 6+	PaymentHistoryID sql.NullString  `json:"payment_history_id" db:"payment_history_id"`
 7 	Name             string          `json:"name" db:"name"`
 8 	CreatedAt        *time.Time      `json:"created_at" db:"created_at"`
 9 	ExpiresAt        *time.Time      `json:"expires_at" db:"expires_at"`
10@@ -370,7 +370,7 @@ type DB interface {
11 	RemoveKeys(pubkeyIDs []string) error
12 
13 	FindUsers() ([]*User, error)
14-	FindUserForName(name string) (*User, error)
15+	FindUserByName(name string) (*User, error)
16 	FindUserForNameAndKey(name string, pubkey string) (*User, error)
17 	FindUserForKey(name string, pubkey string) (*User, error)
18 	FindUserByPubkey(pubkey string) (*User, error)
19@@ -414,7 +414,7 @@ type DB interface {
20 	FindVisitSiteList(opts *SummaryOpts) ([]*VisitUrl, error)
21 
22 	AddPicoPlusUser(username, email, paymentType, txId string) error
23-	FindFeatureForUser(userID string, feature string) (*FeatureFlag, error)
24+	FindFeature(userID string, feature string) (*FeatureFlag, error)
25 	FindFeaturesForUser(userID string) ([]*FeatureFlag, error)
26 	HasFeatureForUser(userID string, feature string) bool
27 	FindTotalSizeForUser(userID string) (int, error)
M pkg/db/postgres/storage.go
+9, -9
 1@@ -617,14 +617,14 @@ func (me *PsqlDB) ValidateName(name string) (bool, error) {
 2 	if !v {
 3 		return false, fmt.Errorf("%s is invalid: %w", lower, db.ErrNameInvalid)
 4 	}
 5-	user, _ := me.FindUserForName(lower)
 6+	user, _ := me.FindUserByName(lower)
 7 	if user == nil {
 8 		return true, nil
 9 	}
10 	return false, fmt.Errorf("%s already taken: %w", lower, db.ErrNameTaken)
11 }
12 
13-func (me *PsqlDB) FindUserForName(name string) (*db.User, error) {
14+func (me *PsqlDB) FindUserByName(name string) (*db.User, error) {
15 	user := &db.User{}
16 	r := me.Db.QueryRow(sqlSelectUserForName, strings.ToLower(name))
17 	err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
18@@ -1457,7 +1457,7 @@ func (me *PsqlDB) FindTagsForPost(postID string) ([]string, error) {
19 	return tags, nil
20 }
21 
22-func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
23+func (me *PsqlDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
24 	ff := &db.FeatureFlag{}
25 	// payment history is allowed to be null
26 	// https://devtidbits.com/2020/08/03/go-sql-error-converting-null-to-string-is-unsupported/
27@@ -1475,7 +1475,7 @@ func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.Feature
28 		return nil, err
29 	}
30 
31-	ff.PaymentHistoryID = paymentHistoryID.String
32+	ff.PaymentHistoryID = paymentHistoryID
33 
34 	return ff, nil
35 }
36@@ -1507,7 +1507,7 @@ func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error)
37 		if err != nil {
38 			return features, err
39 		}
40-		ff.PaymentHistoryID = paymentHistoryID.String
41+		ff.PaymentHistoryID = paymentHistoryID
42 
43 		features = append(features, ff)
44 	}
45@@ -1518,7 +1518,7 @@ func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error)
46 }
47 
48 func (me *PsqlDB) HasFeatureForUser(userID string, feature string) bool {
49-	ff, err := me.FindFeatureForUser(userID, feature)
50+	ff, err := me.FindFeature(userID, feature)
51 	if err != nil {
52 		return false
53 	}
54@@ -1695,7 +1695,7 @@ func (me *PsqlDB) InsertFeature(userID, name string, expiresAt time.Time) (*db.F
55 		return nil, err
56 	}
57 
58-	feature, err := me.FindFeatureForUser(userID, name)
59+	feature, err := me.FindFeature(userID, name)
60 	if err != nil {
61 		return nil, err
62 	}
63@@ -1709,7 +1709,7 @@ func (me *PsqlDB) RemoveFeature(userID string, name string) error {
64 }
65 
66 func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
67-	ff, _ := me.FindFeatureForUser(userID, name)
68+	ff, _ := me.FindFeature(userID, name)
69 	if ff == nil {
70 		t := time.Now()
71 		return t.AddDate(1, 0, 0)
72@@ -1718,7 +1718,7 @@ func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
73 }
74 
75 func (me *PsqlDB) AddPicoPlusUser(username, email, paymentType, txId string) error {
76-	user, err := me.FindUserForName(username)
77+	user, err := me.FindUserByName(username)
78 	if err != nil {
79 		return err
80 	}
M pkg/db/stub/stub.go
+2, -2
 1@@ -77,7 +77,7 @@ func (me *StubDB) ValidateName(name string) (bool, error) {
 2 	return false, notImpl
 3 }
 4 
 5-func (me *StubDB) FindUserForName(name string) (*db.User, error) {
 6+func (me *StubDB) FindUserByName(name string) (*db.User, error) {
 7 	return nil, notImpl
 8 }
 9 
10@@ -189,7 +189,7 @@ func (me *StubDB) FindTagsForPost(postID string) ([]string, error) {
11 	return []string{}, notImpl
12 }
13 
14-func (me *StubDB) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
15+func (me *StubDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
16 	return nil, notImpl
17 }
18 
M pkg/shared/api.go
+1, -1
1@@ -76,7 +76,7 @@ func CheckHandler(w http.ResponseWriter, r *http.Request) {
2 		if !strings.Contains(hostDomain, appDomain) {
3 			subdomain := GetCustomDomain(hostDomain, cfg.Space)
4 			if subdomain != "" {
5-				u, err := dbpool.FindUserForName(subdomain)
6+				u, err := dbpool.FindUserByName(subdomain)
7 				if u != nil && err == nil {
8 					w.WriteHeader(http.StatusOK)
9 					return
M pkg/shared/feed.go
+1, -1
1@@ -90,7 +90,7 @@ func UserFeed(me db.DB, user *db.User, token string) (*feeds.Feed, error) {
2 	var feedItems []*feeds.Item
3 
4 	now := time.Now()
5-	ff, err := me.FindFeatureForUser(user.ID, "plus")
6+	ff, err := me.FindFeature(user.ID, "plus")
7 	if err != nil {
8 		// still want to send an empty feed
9 	} else {
M pkg/shared/ssh.go
+25, -1
 1@@ -3,6 +3,7 @@ package shared
 2 import (
 3 	"fmt"
 4 	"log/slog"
 5+	"strings"
 6 
 7 	"github.com/picosh/pico/pkg/db"
 8 	"github.com/picosh/utils"
 9@@ -16,6 +17,8 @@ type SshAuthHandler struct {
10 
11 type AuthFindUser interface {
12 	FindUserByPubkey(key string) (*db.User, error)
13+	FindUserByName(name string) (*db.User, error)
14+	FindFeature(userID, name string) (*db.FeatureFlag, error)
15 }
16 
17 func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger) *SshAuthHandler {
18@@ -43,8 +46,29 @@ func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.Public
19 		return nil, fmt.Errorf("username is not set")
20 	}
21 
22+	// impersonation
23+	impID := user.ID
24+	adminPrefix := "admin__"
25+	usr := conn.User()
26+	if strings.HasPrefix(usr, adminPrefix) {
27+		ff, err := r.DB.FindFeature(user.ID, "admin")
28+		if err != nil {
29+			return nil, fmt.Errorf("only admins can impersonate a user: %w", err)
30+		}
31+		if !ff.IsValid() {
32+			return nil, fmt.Errorf("expired admin feature flag, cannot impersonate a user")
33+		}
34+
35+		impersonate := strings.Replace(usr, adminPrefix, "", 1)
36+		user, err = r.DB.FindUserByName(impersonate)
37+		if err != nil {
38+			return nil, err
39+		}
40+	}
41+
42 	return &ssh.Permissions{
43 		Extensions: map[string]string{
44+			"imp_id":  impID,
45 			"user_id": user.ID,
46 			"pubkey":  pubkey,
47 		},
48@@ -52,7 +76,7 @@ func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.Public
49 }
50 
51 func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag {
52-	ff, _ := dbpool.FindFeatureForUser(userID, "plus")
53+	ff, _ := dbpool.FindFeature(userID, "plus")
54 	// we have free tiers so users might not have a feature flag
55 	// in which case we set sane defaults
56 	if ff == nil {
M pkg/tui/tuns.go
+1, -1
1@@ -230,7 +230,7 @@ func (m *TunsPage) HandleEvent(ev vaxis.Event, ph vxfw.EventPhase) (vxfw.Command
2 	switch msg := ev.(type) {
3 	case PageIn:
4 		m.loading = true
5-		ff, _ := m.shared.Dbpool.FindFeatureForUser(m.shared.User.ID, "admin")
6+		ff, _ := m.shared.Dbpool.FindFeature(m.shared.User.ID, "admin")
7 		if ff != nil {
8 			m.isAdmin = true
9 		}
M pkg/tui/ui.go
+2, -2
 1@@ -296,7 +296,7 @@ func FindUser(shrd *SharedModel) (*db.User, error) {
 2 			return nil, fmt.Errorf("only admins can impersonate a user")
 3 		}
 4 		impersonate := strings.Replace(usr, adminPrefix, "", 1)
 5-		user, err = shrd.Dbpool.FindUserForName(impersonate)
 6+		user, err = shrd.Dbpool.FindUserByName(impersonate)
 7 		if err != nil {
 8 			return nil, err
 9 		}
10@@ -311,7 +311,7 @@ func FindFeatureFlag(shrd *SharedModel, name string) (*db.FeatureFlag, error) {
11 		return nil, nil
12 	}
13 
14-	ff, err := shrd.Dbpool.FindFeatureForUser(shrd.User.ID, name)
15+	ff, err := shrd.Dbpool.FindFeature(shrd.User.ID, name)
16 	if err != nil {
17 		return nil, err
18 	}