Eric Bower
·
2026-01-25
api_test.go
1package auth
2
3import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12 "time"
13
14 "github.com/gkampitakis/go-snaps/snaps"
15 "github.com/picosh/pico/pkg/db"
16 "github.com/picosh/pico/pkg/db/stub"
17 "github.com/picosh/pico/pkg/shared"
18 "github.com/picosh/pico/pkg/shared/router"
19)
20
// Fixture identity shared by the stub database and the tests in this file.
var (
	// testUserID is the ID the stub database assigns to every user it returns.
	testUserID = "user-1"
	// testUsername is the username sent in request payloads and paths.
	testUsername = "user-a"
)
23
24func TestPaymentWebhook(t *testing.T) {
25 apiConfig := setupTest()
26
27 event := OrderEvent{
28 Meta: &OrderEventMeta{
29 EventName: "order_created",
30 CustomData: &CustomDataMeta{
31 PicoUsername: testUsername,
32 },
33 },
34 Data: &OrderEventData{
35 Attr: &OrderEventDataAttr{
36 UserEmail: "auth@pico.test",
37 CreatedAt: time.Now(),
38 Status: "paid",
39 OrderNumber: 1337,
40 },
41 },
42 }
43 jso, err := json.Marshal(event)
44 bail(err)
45 hash := router.HmacString(apiConfig.Cfg.SecretWebhook, string(jso))
46 body := bytes.NewReader(jso)
47
48 request := httptest.NewRequest("POST", mkpath("/webhook"), body)
49 request.Header.Add("X-signature", hash)
50 responseRecorder := httptest.NewRecorder()
51
52 mux := authMux(apiConfig)
53 mux.ServeHTTP(responseRecorder, request)
54
55 testResponse(t, responseRecorder, 200, "text/plain")
56
57 posts, err := apiConfig.Dbpool.FindPostsByUser(&db.Pager{Num: 1000, Page: 0}, testUserID, "feeds")
58 if err != nil {
59 t.Error("could not find posts for user")
60 }
61 for _, post := range posts.Data {
62 if post.Filename != "pico-plus" {
63 continue
64 }
65 expectedText := `=: email auth@pico.test
66=: cron */10 * * * *
67=: inline_content true
68=> https://auth.pico.sh/rss/123
69=> https://blog.pico.sh/rss`
70 if post.Text != expectedText {
71 t.Errorf("Want pico plus feed file %s, got %s", expectedText, post.Text)
72 }
73 }
74}
75
76func TestUser(t *testing.T) {
77 apiConfig := setupTest()
78
79 data := sishData{
80 Username: testUsername,
81 }
82 jso, err := json.Marshal(data)
83 bail(err)
84 body := bytes.NewReader(jso)
85
86 request := httptest.NewRequest("POST", mkpath("/user"), body)
87 request.Header.Add("Authorization", "Bearer 123")
88 responseRecorder := httptest.NewRecorder()
89
90 mux := authMux(apiConfig)
91 mux.ServeHTTP(responseRecorder, request)
92
93 testResponse(t, responseRecorder, 200, "application/json")
94}
95
96func TestKey(t *testing.T) {
97 apiConfig := setupTest()
98
99 data := sishData{
100 Username: testUsername,
101 PublicKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFxVPgEqtWOa5l0QHZV6TQKhV+l46SAXU07c9RuHlGka test@pico",
102 }
103 jso, err := json.Marshal(data)
104 bail(err)
105 body := bytes.NewReader(jso)
106
107 request := httptest.NewRequest("POST", mkpath("/key"), body)
108 request.Header.Add("Authorization", "Bearer 123")
109 responseRecorder := httptest.NewRecorder()
110
111 mux := authMux(apiConfig)
112 mux.ServeHTTP(responseRecorder, request)
113
114 testResponse(t, responseRecorder, 200, "application/json")
115}
116
117func TestCheckout(t *testing.T) {
118 apiConfig := setupTest()
119
120 request := httptest.NewRequest("GET", mkpath("/checkout/"+testUsername), strings.NewReader(""))
121 request.Header.Add("Authorization", "Bearer 123")
122 responseRecorder := httptest.NewRecorder()
123
124 mux := authMux(apiConfig)
125 mux.ServeHTTP(responseRecorder, request)
126
127 loc := responseRecorder.Header().Get("Location")
128 if loc != "https://picosh.lemonsqueezy.com/buy/73c26cf9-3fac-44c3-b744-298b3032a96b?discount=0&checkout[custom][username]=user-a" {
129 t.Errorf("Have Location %s, want checkout", loc)
130 }
131 if responseRecorder.Code != http.StatusMovedPermanently {
132 t.Errorf("Want status '%d', got '%d'", http.StatusMovedPermanently, responseRecorder.Code)
133 return
134 }
135}
136
137func TestIntrospect(t *testing.T) {
138 apiConfig := setupTest()
139
140 request := httptest.NewRequest("POST", mkpath("/introspect?token=123"), strings.NewReader(""))
141 responseRecorder := httptest.NewRecorder()
142
143 mux := authMux(apiConfig)
144 mux.ServeHTTP(responseRecorder, request)
145
146 testResponse(t, responseRecorder, 200, "application/json")
147}
148
149func TestToken(t *testing.T) {
150 apiConfig := setupTest()
151
152 request := httptest.NewRequest("POST", mkpath("/token?code=123"), strings.NewReader(""))
153 responseRecorder := httptest.NewRecorder()
154
155 mux := authMux(apiConfig)
156 mux.ServeHTTP(responseRecorder, request)
157
158 testResponse(t, responseRecorder, 200, "application/json")
159}
160
161func TestAuthApi(t *testing.T) {
162 apiConfig := setupTest()
163 tt := []*ApiExample{
164 {
165 name: "authorize",
166 path: "/authorize?response_type=json&client_id=333&redirect_uri=pico.test&scope=admin",
167 status: http.StatusOK,
168 contentType: "text/html; charset=utf-8",
169 dbpool: apiConfig.Dbpool,
170 },
171 {
172 name: "rss",
173 path: "/rss/123",
174 status: http.StatusOK,
175 contentType: "application/atom+xml",
176 dbpool: apiConfig.Dbpool,
177 },
178 {
179 name: "fileserver",
180 path: "/robots.txt",
181 status: http.StatusOK,
182 contentType: "text/plain; charset=utf-8",
183 dbpool: apiConfig.Dbpool,
184 },
185 {
186 name: "well-known",
187 path: "/.well-known/oauth-authorization-server",
188 status: http.StatusOK,
189 contentType: "application/json",
190 dbpool: apiConfig.Dbpool,
191 },
192 }
193
194 for _, tc := range tt {
195 t.Run(tc.name, func(t *testing.T) {
196 request := httptest.NewRequest("GET", mkpath(tc.path), strings.NewReader(""))
197 responseRecorder := httptest.NewRecorder()
198
199 mux := authMux(apiConfig)
200 mux.ServeHTTP(responseRecorder, request)
201
202 testResponse(t, responseRecorder, tc.status, tc.contentType)
203 })
204 }
205}
206
207type ApiExample struct {
208 name string
209 path string
210 status int
211 contentType string
212 dbpool db.DB
213}
214
215type AuthDb struct {
216 *stub.StubDB
217 Posts []*db.Post
218}
219
220func (a *AuthDb) AddPicoPlusUser(username, email, from, txid string) error {
221 return nil
222}
223
224func (a *AuthDb) FindUserByName(username string) (*db.User, error) {
225 return &db.User{ID: testUserID, Name: username}, nil
226}
227
228func (a *AuthDb) FindUserByKey(username string, pubkey string) (*db.User, error) {
229 return &db.User{ID: testUserID, Name: username}, nil
230}
231
232func (a *AuthDb) FindUserByToken(token string) (*db.User, error) {
233 if token != "123" {
234 return nil, fmt.Errorf("invalid token")
235 }
236 return &db.User{ID: testUserID, Name: testUsername}, nil
237}
238
239func (a *AuthDb) HasFeatureByUser(userID string, feature string) bool {
240 return true
241}
242
243func (a *AuthDb) FindKeysByUser(user *db.User) ([]*db.PublicKey, error) {
244 return []*db.PublicKey{{ID: "1", UserID: user.ID, Name: "my-key", Key: "nice-pubkey", CreatedAt: &time.Time{}}}, nil
245}
246
247func (a *AuthDb) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
248 now := time.Date(2021, 8, 15, 14, 30, 45, 100, time.UTC)
249 oneDayWarning := now.AddDate(0, 0, 1)
250 return &db.FeatureFlag{ID: "2", UserID: userID, Name: "plus", ExpiresAt: &oneDayWarning, CreatedAt: &now}, nil
251}
252
253func (a *AuthDb) InsertPost(post *db.Post) (*db.Post, error) {
254 a.Posts = append(a.Posts, post)
255 return post, nil
256}
257
258func (a *AuthDb) FindPostsByUser(pager *db.Pager, userID, space string) (*db.Paginate[*db.Post], error) {
259 return &db.Paginate[*db.Post]{
260 Data: a.Posts,
261 }, nil
262}
263
264func (a *AuthDb) UpsertToken(string, string) (string, error) {
265 return "123", nil
266}
267
268func NewAuthDb(logger *slog.Logger) *AuthDb {
269 sb := stub.NewStubDB(logger)
270 return &AuthDb{
271 StubDB: sb,
272 }
273}
274
// mkpath turns a request path (for example "/webhook") into an absolute URL
// on the test host.
func mkpath(path string) string {
	const origin = "https://auth.pico.test"
	return fmt.Sprintf("%s%s", origin, path)
}
278
279func setupTest() *router.ApiConfig {
280 logger := shared.CreateLogger("auth-test", false)
281 cfg := &shared.ConfigSite{
282 Issuer: "auth.pico.test",
283 Domain: "http://0.0.0.0:3000",
284 Port: "3000",
285 Secret: "",
286 SecretWebhook: "my-secret",
287 }
288 cfg.Logger = logger
289 db := NewAuthDb(cfg.Logger)
290 apiConfig := &router.ApiConfig{
291 Cfg: cfg,
292 Dbpool: db,
293 }
294
295 return apiConfig
296}
297
298func testResponse(t *testing.T, responseRecorder *httptest.ResponseRecorder, status int, contentType string) {
299 if responseRecorder.Code != status {
300 t.Errorf("Want status '%d', got '%d'", status, responseRecorder.Code)
301 return
302 }
303
304 ct := responseRecorder.Header().Get("content-type")
305 if ct != contentType {
306 t.Errorf("Want content type '%s', got '%s'", contentType, ct)
307 return
308 }
309
310 body := strings.TrimSpace(responseRecorder.Body.String())
311 snaps.MatchSnapshot(t, body)
312}
313
// bail aborts the test binary by panicking when err is non-nil and does
// nothing otherwise. The panic value is err itself, so the original failure
// shows up in the panic output.
func bail(err error) {
	if err != nil {
		panic(err)
	}
}