Eric Bower
·
2025-03-28
api_test.go
1package auth
2
3import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12 "time"
13
14 "github.com/gkampitakis/go-snaps/snaps"
15 "github.com/picosh/pico/pkg/db"
16 "github.com/picosh/pico/pkg/db/stub"
17 "github.com/picosh/pico/pkg/shared"
18)
19
// Identity of the fixture user handed out by the AuthDb stub lookups and used
// as the subject of every request in this file.
var (
	testUserID   = "user-1"
	testUsername = "user-a"
)
22
23func TestPaymentWebhook(t *testing.T) {
24 apiConfig := setupTest()
25
26 event := OrderEvent{
27 Meta: &OrderEventMeta{
28 EventName: "order_created",
29 CustomData: &CustomDataMeta{
30 PicoUsername: testUsername,
31 },
32 },
33 Data: &OrderEventData{
34 Attr: &OrderEventDataAttr{
35 UserEmail: "auth@pico.test",
36 CreatedAt: time.Now(),
37 Status: "paid",
38 OrderNumber: 1337,
39 },
40 },
41 }
42 jso, err := json.Marshal(event)
43 bail(err)
44 hash := shared.HmacString(apiConfig.Cfg.SecretWebhook, string(jso))
45 body := bytes.NewReader(jso)
46
47 request := httptest.NewRequest("POST", mkpath("/webhook"), body)
48 request.Header.Add("X-signature", hash)
49 responseRecorder := httptest.NewRecorder()
50
51 mux := authMux(apiConfig)
52 mux.ServeHTTP(responseRecorder, request)
53
54 testResponse(t, responseRecorder, 200, "text/plain")
55
56 posts, err := apiConfig.Dbpool.FindPostsForUser(&db.Pager{Num: 1000, Page: 0}, testUserID, "feeds")
57 if err != nil {
58 t.Error("could not find posts for user")
59 }
60 for _, post := range posts.Data {
61 if post.Filename != "pico-plus" {
62 continue
63 }
64 expectedText := `=: email auth@pico.test
65=: digest_interval 10min
66=: inline_content true
67=> https://auth.pico.sh/rss/123
68=> https://blog.pico.sh/rss`
69 if post.Text != expectedText {
70 t.Errorf("Want pico plus feed file %s, got %s", expectedText, post.Text)
71 }
72 }
73}
74
75func TestUser(t *testing.T) {
76 apiConfig := setupTest()
77
78 data := sishData{
79 Username: testUsername,
80 }
81 jso, err := json.Marshal(data)
82 bail(err)
83 body := bytes.NewReader(jso)
84
85 request := httptest.NewRequest("POST", mkpath("/user"), body)
86 request.Header.Add("Authorization", "Bearer 123")
87 responseRecorder := httptest.NewRecorder()
88
89 mux := authMux(apiConfig)
90 mux.ServeHTTP(responseRecorder, request)
91
92 testResponse(t, responseRecorder, 200, "application/json")
93}
94
95func TestKey(t *testing.T) {
96 apiConfig := setupTest()
97
98 data := sishData{
99 Username: testUsername,
100 PublicKey: "zzz",
101 }
102 jso, err := json.Marshal(data)
103 bail(err)
104 body := bytes.NewReader(jso)
105
106 request := httptest.NewRequest("POST", mkpath("/key"), body)
107 request.Header.Add("Authorization", "Bearer 123")
108 responseRecorder := httptest.NewRecorder()
109
110 mux := authMux(apiConfig)
111 mux.ServeHTTP(responseRecorder, request)
112
113 testResponse(t, responseRecorder, 200, "application/json")
114}
115
116func TestCheckout(t *testing.T) {
117 apiConfig := setupTest()
118
119 request := httptest.NewRequest("GET", mkpath("/checkout/"+testUsername), strings.NewReader(""))
120 request.Header.Add("Authorization", "Bearer 123")
121 responseRecorder := httptest.NewRecorder()
122
123 mux := authMux(apiConfig)
124 mux.ServeHTTP(responseRecorder, request)
125
126 loc := responseRecorder.Header().Get("Location")
127 if loc != "https://checkout.pico.sh/buy/73c26cf9-3fac-44c3-b744-298b3032a96b?discount=0&checkout[custom][username]=user-a" {
128 t.Errorf("Have Location %s, want checkout", loc)
129 }
130 if responseRecorder.Code != http.StatusMovedPermanently {
131 t.Errorf("Want status '%d', got '%d'", http.StatusMovedPermanently, responseRecorder.Code)
132 return
133 }
134}
135
136func TestIntrospect(t *testing.T) {
137 apiConfig := setupTest()
138
139 request := httptest.NewRequest("POST", mkpath("/introspect?token=123"), strings.NewReader(""))
140 responseRecorder := httptest.NewRecorder()
141
142 mux := authMux(apiConfig)
143 mux.ServeHTTP(responseRecorder, request)
144
145 testResponse(t, responseRecorder, 200, "application/json")
146}
147
148func TestToken(t *testing.T) {
149 apiConfig := setupTest()
150
151 request := httptest.NewRequest("POST", mkpath("/token?code=123"), strings.NewReader(""))
152 responseRecorder := httptest.NewRecorder()
153
154 mux := authMux(apiConfig)
155 mux.ServeHTTP(responseRecorder, request)
156
157 testResponse(t, responseRecorder, 200, "application/json")
158}
159
160func TestAuthApi(t *testing.T) {
161 apiConfig := setupTest()
162 tt := []*ApiExample{
163 {
164 name: "authorize",
165 path: "/authorize?response_type=json&client_id=333&redirect_uri=pico.test&scope=admin",
166 status: http.StatusOK,
167 contentType: "text/html; charset=utf-8",
168 dbpool: apiConfig.Dbpool,
169 },
170 {
171 name: "rss",
172 path: "/rss/123",
173 status: http.StatusOK,
174 contentType: "application/atom+xml",
175 dbpool: apiConfig.Dbpool,
176 },
177 {
178 name: "fileserver",
179 path: "/robots.txt",
180 status: http.StatusOK,
181 contentType: "text/plain; charset=utf-8",
182 dbpool: apiConfig.Dbpool,
183 },
184 {
185 name: "well-known",
186 path: "/.well-known/oauth-authorization-server",
187 status: http.StatusOK,
188 contentType: "application/json",
189 dbpool: apiConfig.Dbpool,
190 },
191 }
192
193 for _, tc := range tt {
194 t.Run(tc.name, func(t *testing.T) {
195 request := httptest.NewRequest("GET", mkpath(tc.path), strings.NewReader(""))
196 responseRecorder := httptest.NewRecorder()
197
198 mux := authMux(apiConfig)
199 mux.ServeHTTP(responseRecorder, request)
200
201 testResponse(t, responseRecorder, tc.status, tc.contentType)
202 })
203 }
204}
205
206type ApiExample struct {
207 name string
208 path string
209 status int
210 contentType string
211 dbpool db.DB
212}
213
214type AuthDb struct {
215 *stub.StubDB
216 Posts []*db.Post
217}
218
219func (a *AuthDb) AddPicoPlusUser(username, email, from, txid string) error {
220 return nil
221}
222
223func (a *AuthDb) FindUserByName(username string) (*db.User, error) {
224 return &db.User{ID: testUserID, Name: username}, nil
225}
226
227func (a *AuthDb) FindUserForKey(username string, pubkey string) (*db.User, error) {
228 return &db.User{ID: testUserID, Name: username}, nil
229}
230
231func (a *AuthDb) FindUserForToken(token string) (*db.User, error) {
232 if token != "123" {
233 return nil, fmt.Errorf("invalid token")
234 }
235 return &db.User{ID: testUserID, Name: testUsername}, nil
236}
237
238func (a *AuthDb) HasFeatureForUser(userID string, feature string) bool {
239 return true
240}
241
242func (a *AuthDb) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
243 return []*db.PublicKey{{ID: "1", UserID: user.ID, Name: "my-key", Key: "nice-pubkey", CreatedAt: &time.Time{}}}, nil
244}
245
246func (a *AuthDb) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
247 now := time.Date(2021, 8, 15, 14, 30, 45, 100, time.UTC)
248 oneDayWarning := now.AddDate(0, 0, 1)
249 return &db.FeatureFlag{ID: "2", UserID: userID, Name: "plus", ExpiresAt: &oneDayWarning, CreatedAt: &now}, nil
250}
251
252func (a *AuthDb) InsertPost(post *db.Post) (*db.Post, error) {
253 a.Posts = append(a.Posts, post)
254 return post, nil
255}
256
257func (a *AuthDb) FindPostsForUser(pager *db.Pager, userID, space string) (*db.Paginate[*db.Post], error) {
258 return &db.Paginate[*db.Post]{
259 Data: a.Posts,
260 }, nil
261}
262
263func (a *AuthDb) UpsertToken(string, string) (string, error) {
264 return "123", nil
265}
266
267func NewAuthDb(logger *slog.Logger) *AuthDb {
268 sb := stub.NewStubDB(logger)
269 return &AuthDb{
270 StubDB: sb,
271 }
272}
273
// mkpath turns a route such as "/token?code=123" into an absolute URL on the
// test host, ready for httptest.NewRequest. path is appended verbatim.
func mkpath(path string) string {
	const origin = "https://auth.pico.test"
	return fmt.Sprintf("%s%s", origin, path)
}
277
278func setupTest() *shared.ApiConfig {
279 logger := shared.CreateLogger("auth-test")
280 cfg := &shared.ConfigSite{
281 Issuer: "auth.pico.test",
282 Domain: "http://0.0.0.0:3000",
283 Port: "3000",
284 Secret: "",
285 SecretWebhook: "my-secret",
286 }
287 cfg.Logger = logger
288 db := NewAuthDb(cfg.Logger)
289 apiConfig := &shared.ApiConfig{
290 Cfg: cfg,
291 Dbpool: db,
292 }
293
294 return apiConfig
295}
296
297func testResponse(t *testing.T, responseRecorder *httptest.ResponseRecorder, status int, contentType string) {
298 if responseRecorder.Code != status {
299 t.Errorf("Want status '%d', got '%d'", status, responseRecorder.Code)
300 return
301 }
302
303 ct := responseRecorder.Header().Get("content-type")
304 if ct != contentType {
305 t.Errorf("Want content type '%s', got '%s'", contentType, ct)
306 return
307 }
308
309 body := strings.TrimSpace(responseRecorder.Body.String())
310 snaps.MatchSnapshot(t, body)
311}
312
// bail panics with err when it is non-nil and does nothing otherwise. It
// guards setup steps that are not expected to fail, such as json.Marshal.
// The panic value is the error itself, so the failure message shows the
// real cause.
func bail(err error) {
	if err != nil {
		// Previously panic(bail), which panicked with this function's value
		// and discarded the error.
		panic(err)
	}
}