Eric Bower
·
2025-04-10
storage.go
1package postgres
2
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"math"
	"net/url"
	"slices"
	"strings"
	"time"

	_ "github.com/lib/pq"
	"github.com/picosh/pico/pkg/db"
	"github.com/picosh/utils"
)
19
var (
	// PAGER_SIZE is a package-level page size. It is not referenced in this
	// part of the file — presumably it is the default number of posts per
	// page; confirm against its callers.
	PAGER_SIZE = 15

	// SelectPost is the post column list shared by several post queries in
	// this file. Its order must match the Scan destinations in
	// CreatePostFromRow and CreatePostWithTagsFromRow column-for-column.
	SelectPost = `
	posts.id, user_id, app_users.name, filename, slug, title, text, description,
	posts.created_at, publish_at, posts.updated_at, hidden, file_size, mime_type, shasum, data, expires_at, views`
)
25
26var (
27 sqlSelectPosts = fmt.Sprintf(`
28 SELECT %s
29 FROM posts
30 LEFT JOIN app_users ON app_users.id = posts.user_id`, SelectPost)
31
32 sqlSelectPostsBeforeDate = fmt.Sprintf(`
33 SELECT %s
34 FROM posts
35 LEFT JOIN app_users ON app_users.id = posts.user_id
36 WHERE publish_at::date <= $1 AND cur_space = $2`, SelectPost)
37
38 sqlSelectPostWithFilename = fmt.Sprintf(`
39 SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
40 FROM posts
41 LEFT JOIN app_users ON app_users.id = posts.user_id
42 LEFT JOIN post_tags ON post_tags.post_id = posts.id
43 WHERE filename = $1 AND user_id = $2 AND cur_space = $3
44 GROUP BY %s`, SelectPost, SelectPost)
45
46 sqlSelectPostWithSlug = fmt.Sprintf(`
47 SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
48 FROM posts
49 LEFT JOIN app_users ON app_users.id = posts.user_id
50 LEFT JOIN post_tags ON post_tags.post_id = posts.id
51 WHERE slug = $1 AND user_id = $2 AND cur_space = $3
52 GROUP BY %s`, SelectPost, SelectPost)
53
54 sqlSelectPost = fmt.Sprintf(`
55 SELECT %s
56 FROM posts
57 LEFT JOIN app_users ON app_users.id = posts.user_id
58 WHERE posts.id = $1`, SelectPost)
59
60 sqlSelectUpdatedPostsForUser = fmt.Sprintf(`
61 SELECT %s
62 FROM posts
63 LEFT JOIN app_users ON app_users.id = posts.user_id
64 WHERE user_id = $1 AND publish_at::date <= CURRENT_DATE AND cur_space = $2
65 ORDER BY posts.updated_at DESC`, SelectPost)
66
67 sqlSelectExpiredPosts = fmt.Sprintf(`
68 SELECT %s
69 FROM posts
70 LEFT JOIN app_users ON app_users.id = posts.user_id
71 WHERE
72 cur_space = $1 AND
73 expires_at <= now();
74 `, SelectPost)
75
76 sqlSelectPostsForUser = fmt.Sprintf(`
77 SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
78 FROM posts
79 LEFT JOIN app_users ON app_users.id = posts.user_id
80 LEFT JOIN post_tags ON post_tags.post_id = posts.id
81 WHERE
82 hidden = FALSE AND
83 user_id = $1 AND
84 publish_at::date <= CURRENT_DATE AND
85 cur_space = $2
86 GROUP BY %s
87 ORDER BY publish_at DESC, slug DESC
88 LIMIT $3 OFFSET $4`, SelectPost, SelectPost)
89
90 sqlSelectAllPostsForUser = fmt.Sprintf(`
91 SELECT %s
92 FROM posts
93 LEFT JOIN app_users ON app_users.id = posts.user_id
94 WHERE
95 user_id = $1 AND
96 cur_space = $2
97 ORDER BY publish_at DESC`, SelectPost)
98
99 sqlSelectPostsByTag = `
100 SELECT
101 posts.id,
102 user_id,
103 filename,
104 slug,
105 title,
106 text,
107 description,
108 publish_at,
109 app_users.name as username,
110 posts.updated_at,
111 posts.mime_type
112 FROM posts
113 LEFT JOIN app_users ON app_users.id = posts.user_id
114 LEFT JOIN post_tags ON post_tags.post_id = posts.id
115 WHERE
116 post_tags.name = $3 AND
117 publish_at::date <= CURRENT_DATE AND
118 cur_space = $4
119 ORDER BY publish_at DESC
120 LIMIT $1 OFFSET $2`
121
122 sqlSelectUserPostsByTag = fmt.Sprintf(`
123 SELECT %s
124 FROM posts
125 LEFT JOIN app_users ON app_users.id = posts.user_id
126 LEFT JOIN post_tags ON post_tags.post_id = posts.id
127 WHERE
128 hidden = FALSE AND
129 user_id = $1 AND
130 (post_tags.name = $2 OR hidden = true) AND
131 publish_at::date <= CURRENT_DATE AND
132 cur_space = $3
133 ORDER BY publish_at DESC
134 LIMIT $4 OFFSET $5`, SelectPost)
135)
136
// Static SQL statements used by PsqlDB. Placeholders are positional ($1, $2,
// ...), and the argument order at each call site must follow this numbering.
// The numbering is not always left-to-right in the SQL text: see
// sqlSelectAllUpdatedPosts, sqlSelectPostsByRank and sqlUpdatePost.
const (
	// Public keys and users.
	sqlSelectPublicKey  = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE public_key = $1`
	sqlSelectPublicKeys = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE user_id = $1 ORDER BY created_at ASC`
	sqlSelectUser        = `SELECT id, name, created_at FROM app_users WHERE id = $1`
	sqlSelectUserForName = `SELECT id, name, created_at FROM app_users WHERE name = $1`
	// $1 = user name, $2 = public key; joins the user to one matching key row.
	sqlSelectUserForNameAndKey = `SELECT app_users.id, app_users.name, app_users.created_at, public_keys.id as pk_id, public_keys.public_key, public_keys.created_at as pk_created_at FROM app_users LEFT JOIN public_keys ON public_keys.user_id = app_users.id WHERE app_users.name = $1 AND public_keys.public_key = $2`
	sqlSelectUsers             = `SELECT id, name, created_at FROM app_users ORDER BY name ASC`

	// Tokens. sqlSelectUserForToken only matches tokens whose expires_at is
	// still in the future.
	sqlSelectUserForToken = `
	SELECT app_users.id, app_users.name, app_users.created_at
	FROM app_users
	LEFT JOIN tokens ON tokens.user_id = app_users.id
	WHERE tokens.token = $1 AND tokens.expires_at > NOW()`
	// Returns the token column of the new row (presumably filled by a column
	// default — confirm in the schema).
	sqlInsertToken              = `INSERT INTO tokens (user_id, name) VALUES($1, $2) RETURNING token;`
	sqlRemoveToken              = `DELETE FROM tokens WHERE id = $1`
	sqlSelectTokensForUser      = `SELECT id, user_id, name, created_at, expires_at FROM tokens WHERE user_id = $1`
	sqlSelectTokenByNameForUser = `SELECT token FROM tokens WHERE user_id = $1 AND name = $2`

	// Feature flags: the flag with the latest expires_at for the user and name.
	sqlSelectFeatureForUser = `SELECT id, user_id, payment_history_id, name, data, created_at, expires_at FROM feature_flags WHERE user_id = $1 AND name = $2 ORDER BY expires_at DESC LIMIT 1`
	// Total file_size of all the user's posts, 0 when they have none.
	sqlSelectSizeForUser = `SELECT COALESCE(sum(file_size), 0) FROM posts WHERE user_id = $1`

	// NOTE(review): the alias lookup is not scoped to a user or space.
	sqlSelectPostIdByAliasSlug = `SELECT post_id FROM post_aliases WHERE slug = $1`
	// $1 = space, $2 = tag; counts non-hidden posts carrying the tag.
	sqlSelectTagPostCount = `
	SELECT count(posts.id)
	FROM posts
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE hidden = FALSE AND cur_space=$1 and post_tags.name = $2`
	// $1 = space; counts non-hidden posts.
	sqlSelectPostCount = `SELECT count(id) FROM posts WHERE hidden = FALSE AND cur_space=$1`
	// Feed of non-hidden, already-published posts, most recently updated
	// first: $1 = limit, $2 = offset, $3 = space.
	sqlSelectAllUpdatedPosts = `
	SELECT
		posts.id,
		user_id,
		filename,
		slug,
		title,
		text,
		description,
		publish_at,
		app_users.name as username,
		posts.updated_at,
		posts.mime_type
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE hidden = FALSE AND publish_at::date <= CURRENT_DATE AND cur_space = $3
	ORDER BY updated_at DESC
	LIMIT $1 OFFSET $2`
	// Feed of non-hidden, already-published posts, newest publish_at first:
	// $1 = limit, $2 = offset, $3 = space.
	// add some users to deny list since they are robogenerating a bunch of posts
	// per day and are creating a lot of noise.
	sqlSelectPostsByRank = `
	SELECT
		posts.id,
		user_id,
		filename,
		slug,
		title,
		text,
		description,
		publish_at,
		app_users.name as username,
		posts.updated_at,
		posts.mime_type
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE
		hidden = FALSE AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $3 AND
		app_users.name NOT IN ('algiegray', 'mrrccc')
	ORDER BY publish_at DESC
	LIMIT $1 OFFSET $2`

	// The five most-used tag names in space $1 (hidden posts are not
	// filtered out).
	sqlSelectPopularTags = `
	SELECT name, count(post_id) as "tally"
	FROM post_tags
	LEFT JOIN posts ON posts.id = post_id
	WHERE posts.cur_space = $1
	GROUP BY name
	ORDER BY tally DESC
	LIMIT 5`
	// Distinct tag names used by user $1 in space $2.
	sqlSelectTagsForUser = `
	SELECT name
	FROM post_tags
	LEFT JOIN posts ON posts.id = post_id
	WHERE posts.user_id = $1 AND posts.cur_space = $2
	GROUP BY name`
	sqlSelectTagsForPost     = `SELECT name FROM post_tags WHERE post_id=$1`
	sqlSelectFeedItemsByPost = `SELECT id, post_id, guid, data, created_at FROM feed_items WHERE post_id=$1`

	// Inserts. sqlInsertPost's $1..$15 follow its column list.
	sqlInsertPublicKey = `INSERT INTO public_keys (user_id, public_key) VALUES ($1, $2)`
	sqlInsertPost      = `
	INSERT INTO posts
		(user_id, filename, slug, title, text, description, publish_at, hidden, cur_space,
		file_size, mime_type, shasum, data, expires_at, updated_at)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
	RETURNING id`
	sqlInsertUser      = `INSERT INTO app_users (name) VALUES($1) returning id`
	sqlInsertTag       = `INSERT INTO post_tags (post_id, name) VALUES($1, $2) RETURNING id;`
	sqlInsertAliases   = `INSERT INTO post_aliases (post_id, slug) VALUES($1, $2) RETURNING id;`
	sqlInsertFeedItems = `INSERT INTO feed_items (post_id, guid, data) VALUES ($1, $2, $3) RETURNING id;`

	// Updates. In sqlUpdatePost the post id is $10, while hidden and
	// expires_at come after it as $11 and $12.
	sqlUpdatePost = `
	UPDATE posts
	SET slug = $1, title = $2, text = $3, description = $4, updated_at = $5, publish_at = $6,
		file_size = $7, shasum = $8, data = $9, hidden = $11, expires_at = $12
	WHERE id = $10`
	sqlUpdateUserName = `UPDATE app_users SET name = $1 WHERE id = $2`
	// Increments the view counter and returns the new value.
	sqlIncrementViews = `UPDATE posts SET views = views + 1 WHERE id = $1 RETURNING views`

	// Deletes. The bulk variants take a Postgres array literal cast to uuid[].
	sqlRemoveAliasesByPost = `DELETE FROM post_aliases WHERE post_id = $1`
	sqlRemoveTagsByPost    = `DELETE FROM post_tags WHERE post_id = $1`
	sqlRemovePosts         = `DELETE FROM posts WHERE id = ANY($1::uuid[])`
	sqlRemoveKeys          = `DELETE FROM public_keys WHERE id = ANY($1::uuid[])`
	sqlRemoveUsers         = `DELETE FROM app_users WHERE id = ANY($1::uuid[])`

	// Projects. A project whose name differs from its project_dir points at
	// another project's directory (see sqlFindProjectLinks / sqlLinkToProject).
	sqlInsertProject      = `INSERT INTO projects (user_id, name, project_dir) VALUES ($1, $2, $3) RETURNING id;`
	sqlUpdateProject      = `UPDATE projects SET updated_at = $3 WHERE user_id = $1 AND name = $2;`
	sqlFindProjectByName  = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name = $2;`
	sqlSelectProjectCount = `SELECT count(id) FROM projects`
	sqlFindProjectsByUser = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 ORDER BY name ASC, updated_at DESC;`
	// Projects that are their own directory (name = project_dir) whose name
	// matches the case-insensitive ILIKE pattern $2.
	sqlFindProjectsByPrefix = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name = project_dir AND name ILIKE $2 ORDER BY updated_at ASC, name ASC;`
	// Projects of user $1 that point at project directory $2.
	sqlFindProjectLinks = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name != project_dir AND project_dir = $2 ORDER BY name ASC;`
	sqlLinkToProject    = `UPDATE projects SET project_dir = $1, updated_at = $2 WHERE id = $3;`
	sqlRemoveProject    = `DELETE FROM projects WHERE id = $1;`
)
261
// PsqlDB is the Postgres-backed storage implementation.
//
// Db is the database/sql handle every query in this file runs on, and
// Logger receives the diagnostic messages the methods emit.
type PsqlDB struct {
	// Logger receives diagnostic output.
	Logger *slog.Logger
	// Db is the connection pool opened by NewDB.
	Db *sql.DB
}
266
// RowScanner abstracts over *sql.Row and *sql.Rows for the row-mapping
// helpers: anything that can copy the current row's columns into dest.
type RowScanner interface {
	// Scan copies the current row's columns into the values pointed at by dest.
	Scan(dest ...any) error
}
270
271func CreatePostFromRow(r RowScanner) (*db.Post, error) {
272 post := &db.Post{}
273 err := r.Scan(
274 &post.ID,
275 &post.UserID,
276 &post.Username,
277 &post.Filename,
278 &post.Slug,
279 &post.Title,
280 &post.Text,
281 &post.Description,
282 &post.CreatedAt,
283 &post.PublishAt,
284 &post.UpdatedAt,
285 &post.Hidden,
286 &post.FileSize,
287 &post.MimeType,
288 &post.Shasum,
289 &post.Data,
290 &post.ExpiresAt,
291 &post.Views,
292 )
293 if err != nil {
294 return nil, err
295 }
296 return post, nil
297}
298
299func CreatePostWithTagsFromRow(r RowScanner) (*db.Post, error) {
300 post := &db.Post{}
301 tagStr := ""
302 err := r.Scan(
303 &post.ID,
304 &post.UserID,
305 &post.Username,
306 &post.Filename,
307 &post.Slug,
308 &post.Title,
309 &post.Text,
310 &post.Description,
311 &post.CreatedAt,
312 &post.PublishAt,
313 &post.UpdatedAt,
314 &post.Hidden,
315 &post.FileSize,
316 &post.MimeType,
317 &post.Shasum,
318 &post.Data,
319 &post.ExpiresAt,
320 &post.Views,
321 &tagStr,
322 )
323 if err != nil {
324 return nil, err
325 }
326
327 tags := strings.Split(tagStr, ",")
328 for _, tag := range tags {
329 tg := strings.TrimSpace(tag)
330 if tg == "" {
331 continue
332 }
333 post.Tags = append(post.Tags, tg)
334 }
335
336 return post, nil
337}
338
339func NewDB(databaseUrl string, logger *slog.Logger) *PsqlDB {
340 var err error
341 d := &PsqlDB{
342 Logger: logger,
343 }
344 d.Logger.Info("Connecting to postgres", "databaseUrl", databaseUrl)
345
346 db, err := sql.Open("postgres", databaseUrl)
347 if err != nil {
348 d.Logger.Error(err.Error())
349 }
350 d.Db = db
351 return d
352}
353
354func (me *PsqlDB) RegisterUser(username, pubkey, comment string) (*db.User, error) {
355 lowerName := strings.ToLower(username)
356 valid, err := me.ValidateName(lowerName)
357 if !valid {
358 return nil, err
359 }
360
361 ctx := context.Background()
362 tx, err := me.Db.BeginTx(ctx, nil)
363 if err != nil {
364 return nil, err
365 }
366 defer func() {
367 err = tx.Rollback()
368 }()
369
370 stmt, err := tx.Prepare(sqlInsertUser)
371 if err != nil {
372 return nil, err
373 }
374 defer stmt.Close()
375
376 var id string
377 err = stmt.QueryRow(lowerName).Scan(&id)
378 if err != nil {
379 return nil, err
380 }
381
382 err = me.InsertPublicKey(id, pubkey, comment, tx)
383 if err != nil {
384 return nil, err
385 }
386
387 err = tx.Commit()
388 if err != nil {
389 return nil, err
390 }
391
392 return me.FindUserForKey(username, pubkey)
393}
394
395func (me *PsqlDB) RemoveUsers(userIDs []string) error {
396 param := "{" + strings.Join(userIDs, ",") + "}"
397 _, err := me.Db.Exec(sqlRemoveUsers, param)
398 return err
399}
400
401func (me *PsqlDB) InsertPublicKey(userID, key, name string, tx *sql.Tx) error {
402 pk, _ := me.FindPublicKeyForKey(key)
403 if pk != nil {
404 return db.ErrPublicKeyTaken
405 }
406 query := `INSERT INTO public_keys (user_id, public_key, name) VALUES ($1, $2, $3)`
407 var err error
408 if tx != nil {
409 _, err = tx.Exec(query, userID, key, name)
410 } else {
411 _, err = me.Db.Exec(query, userID, key, name)
412 }
413 if err != nil {
414 return err
415 }
416
417 return nil
418}
419
420func (me *PsqlDB) UpdatePublicKey(pubkeyID, name string) (*db.PublicKey, error) {
421 pk, err := me.FindPublicKey(pubkeyID)
422 if err != nil {
423 return nil, err
424 }
425
426 query := `UPDATE public_keys SET name=$1 WHERE id=$2;`
427 _, err = me.Db.Exec(query, name, pk.ID)
428 if err != nil {
429 return nil, err
430 }
431
432 pk, err = me.FindPublicKey(pubkeyID)
433 if err != nil {
434 return nil, err
435 }
436 return pk, nil
437}
438
439func (me *PsqlDB) FindPublicKeyForKey(key string) (*db.PublicKey, error) {
440 var keys []*db.PublicKey
441 rs, err := me.Db.Query(sqlSelectPublicKey, key)
442 if err != nil {
443 return nil, err
444 }
445
446 for rs.Next() {
447 pk := &db.PublicKey{}
448 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
449 if err != nil {
450 return nil, err
451 }
452
453 keys = append(keys, pk)
454 }
455
456 if rs.Err() != nil {
457 return nil, rs.Err()
458 }
459
460 if len(keys) == 0 {
461 return nil, fmt.Errorf("pubkey not found in our database: [%s]", key)
462 }
463
464 // When we run PublicKeyForKey and there are multiple public keys returned from the database
465 // that should mean that we don't have the correct username for this public key.
466 // When that happens we need to reject the authentication and ask the user to provide the correct
467 // username when using ssh. So instead of `ssh <domain>` it should be `ssh user@<domain>`
468 if len(keys) > 1 {
469 return nil, &db.ErrMultiplePublicKeys{}
470 }
471
472 return keys[0], nil
473}
474
475func (me *PsqlDB) FindPublicKey(pubkeyID string) (*db.PublicKey, error) {
476 var keys []*db.PublicKey
477 rs, err := me.Db.Query(`SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE id = $1`, pubkeyID)
478 if err != nil {
479 return nil, err
480 }
481
482 for rs.Next() {
483 pk := &db.PublicKey{}
484 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
485 if err != nil {
486 return nil, err
487 }
488
489 keys = append(keys, pk)
490 }
491
492 if rs.Err() != nil {
493 return nil, rs.Err()
494 }
495
496 if len(keys) == 0 {
497 return nil, errors.New("no public keys found for key provided")
498 }
499
500 return keys[0], nil
501}
502
503func (me *PsqlDB) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
504 var keys []*db.PublicKey
505 rs, err := me.Db.Query(sqlSelectPublicKeys, user.ID)
506 if err != nil {
507 return keys, err
508 }
509 for rs.Next() {
510 pk := &db.PublicKey{}
511 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
512 if err != nil {
513 return keys, err
514 }
515
516 keys = append(keys, pk)
517 }
518 if rs.Err() != nil {
519 return keys, rs.Err()
520 }
521 return keys, nil
522}
523
524func (me *PsqlDB) RemoveKeys(keyIDs []string) error {
525 param := "{" + strings.Join(keyIDs, ",") + "}"
526 _, err := me.Db.Exec(sqlRemoveKeys, param)
527 return err
528}
529
530func (me *PsqlDB) FindPostsBeforeDate(date *time.Time, space string) ([]*db.Post, error) {
531 // now := time.Now()
532 // expired := now.AddDate(0, 0, -3)
533 var posts []*db.Post
534 rs, err := me.Db.Query(sqlSelectPostsBeforeDate, date, space)
535 if err != nil {
536 return posts, err
537 }
538 for rs.Next() {
539 post, err := CreatePostFromRow(rs)
540 if err != nil {
541 return nil, err
542 }
543
544 posts = append(posts, post)
545 }
546 if rs.Err() != nil {
547 return posts, rs.Err()
548 }
549 return posts, nil
550}
551
552func (me *PsqlDB) FindUserForKey(username string, key string) (*db.User, error) {
553 me.Logger.Info("attempting to find user with only public key", "key", key)
554 pk, err := me.FindPublicKeyForKey(key)
555 if err == nil {
556 me.Logger.Info("found pubkey, looking for user", "key", key, "userId", pk.UserID)
557 user, err := me.FindUser(pk.UserID)
558 if err != nil {
559 return nil, err
560 }
561 user.PublicKey = pk
562 return user, nil
563 }
564
565 if errors.Is(err, &db.ErrMultiplePublicKeys{}) {
566 me.Logger.Info("detected multiple users with same public key", "user", username)
567 user, err := me.FindUserForNameAndKey(username, key)
568 if err != nil {
569 me.Logger.Info("could not find user by username and public key", "user", username, "key", key)
570 // this is a little hacky but if we cannot find a user by name and public key
571 // then we return the multiple keys detected error so the user knows to specify their
572 // when logging in
573 return nil, &db.ErrMultiplePublicKeys{}
574 }
575 return user, nil
576 }
577
578 return nil, err
579}
580
581func (me *PsqlDB) FindUserByPubkey(key string) (*db.User, error) {
582 me.Logger.Info("attempting to find user with only public key", "key", key)
583 pk, err := me.FindPublicKeyForKey(key)
584 if err != nil {
585 return nil, err
586 }
587
588 me.Logger.Info("found pubkey, looking for user", "key", key, "userId", pk.UserID)
589 user, err := me.FindUser(pk.UserID)
590 if err != nil {
591 return nil, err
592 }
593 user.PublicKey = pk
594 return user, nil
595}
596
597func (me *PsqlDB) FindUser(userID string) (*db.User, error) {
598 user := &db.User{}
599 var un sql.NullString
600 r := me.Db.QueryRow(sqlSelectUser, userID)
601 err := r.Scan(&user.ID, &un, &user.CreatedAt)
602 if err != nil {
603 return nil, err
604 }
605 if un.Valid {
606 user.Name = un.String
607 }
608 return user, nil
609}
610
611func (me *PsqlDB) ValidateName(name string) (bool, error) {
612 lower := strings.ToLower(name)
613 if slices.Contains(db.DenyList, lower) {
614 return false, fmt.Errorf("%s is on deny list: %w", lower, db.ErrNameDenied)
615 }
616 v := db.NameValidator.MatchString(lower)
617 if !v {
618 return false, fmt.Errorf("%s is invalid: %w", lower, db.ErrNameInvalid)
619 }
620 user, _ := me.FindUserByName(lower)
621 if user == nil {
622 return true, nil
623 }
624 return false, fmt.Errorf("%s already taken: %w", lower, db.ErrNameTaken)
625}
626
627func (me *PsqlDB) FindUserByName(name string) (*db.User, error) {
628 user := &db.User{}
629 r := me.Db.QueryRow(sqlSelectUserForName, strings.ToLower(name))
630 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
631 if err != nil {
632 return nil, err
633 }
634 return user, nil
635}
636
637func (me *PsqlDB) FindUserForNameAndKey(name string, key string) (*db.User, error) {
638 user := &db.User{}
639 pk := &db.PublicKey{}
640
641 r := me.Db.QueryRow(sqlSelectUserForNameAndKey, strings.ToLower(name), key)
642 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt, &pk.ID, &pk.Key, &pk.CreatedAt)
643 if err != nil {
644 return nil, err
645 }
646
647 user.PublicKey = pk
648 return user, nil
649}
650
651func (me *PsqlDB) FindUserForToken(token string) (*db.User, error) {
652 user := &db.User{}
653
654 r := me.Db.QueryRow(sqlSelectUserForToken, token)
655 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
656 if err != nil {
657 return nil, err
658 }
659
660 return user, nil
661}
662
663func (me *PsqlDB) SetUserName(userID string, name string) error {
664 lowerName := strings.ToLower(name)
665 valid, err := me.ValidateName(lowerName)
666 if !valid {
667 return err
668 }
669
670 _, err = me.Db.Exec(sqlUpdateUserName, lowerName, userID)
671 return err
672}
673
674func (me *PsqlDB) FindPostWithFilename(filename string, persona_id string, space string) (*db.Post, error) {
675 r := me.Db.QueryRow(sqlSelectPostWithFilename, filename, persona_id, space)
676 post, err := CreatePostWithTagsFromRow(r)
677 if err != nil {
678 return nil, err
679 }
680
681 return post, nil
682}
683
684func (me *PsqlDB) FindPostWithSlug(slug string, user_id string, space string) (*db.Post, error) {
685 r := me.Db.QueryRow(sqlSelectPostWithSlug, slug, user_id, space)
686 post, err := CreatePostWithTagsFromRow(r)
687 if err != nil {
688 // attempt to find post inside post_aliases
689 alias := me.Db.QueryRow(sqlSelectPostIdByAliasSlug, slug)
690 postID := ""
691 err := alias.Scan(&postID)
692 if err != nil {
693 return nil, err
694 }
695
696 return me.FindPost(postID)
697 }
698
699 return post, nil
700}
701
702func (me *PsqlDB) FindPost(postID string) (*db.Post, error) {
703 r := me.Db.QueryRow(sqlSelectPost, postID)
704 post, err := CreatePostFromRow(r)
705 if err != nil {
706 return nil, err
707 }
708
709 return post, nil
710}
711
712func (me *PsqlDB) postPager(rs *sql.Rows, pageNum int, space string, tag string) (*db.Paginate[*db.Post], error) {
713 var posts []*db.Post
714 for rs.Next() {
715 post := &db.Post{}
716 err := rs.Scan(
717 &post.ID,
718 &post.UserID,
719 &post.Filename,
720 &post.Slug,
721 &post.Title,
722 &post.Text,
723 &post.Description,
724 &post.PublishAt,
725 &post.Username,
726 &post.UpdatedAt,
727 &post.MimeType,
728 )
729 if err != nil {
730 return nil, err
731 }
732
733 posts = append(posts, post)
734 }
735 if rs.Err() != nil {
736 return nil, rs.Err()
737 }
738
739 var count int
740 var err error
741 if tag == "" {
742 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
743 } else {
744 err = me.Db.QueryRow(sqlSelectTagPostCount, space, tag).Scan(&count)
745 }
746 if err != nil {
747 return nil, err
748 }
749
750 pager := &db.Paginate[*db.Post]{
751 Data: posts,
752 Total: int(math.Ceil(float64(count) / float64(pageNum))),
753 }
754
755 return pager, nil
756}
757
758func (me *PsqlDB) FindAllPosts(page *db.Pager, space string) (*db.Paginate[*db.Post], error) {
759 rs, err := me.Db.Query(sqlSelectPostsByRank, page.Num, page.Num*page.Page, space)
760 if err != nil {
761 return nil, err
762 }
763 return me.postPager(rs, page.Num, space, "")
764}
765
766func (me *PsqlDB) FindAllUpdatedPosts(page *db.Pager, space string) (*db.Paginate[*db.Post], error) {
767 rs, err := me.Db.Query(sqlSelectAllUpdatedPosts, page.Num, page.Num*page.Page, space)
768 if err != nil {
769 return nil, err
770 }
771 return me.postPager(rs, page.Num, space, "")
772}
773
774func (me *PsqlDB) InsertPost(post *db.Post) (*db.Post, error) {
775 var id string
776 err := me.Db.QueryRow(
777 sqlInsertPost,
778 post.UserID,
779 post.Filename,
780 post.Slug,
781 post.Title,
782 post.Text,
783 post.Description,
784 post.PublishAt,
785 post.Hidden,
786 post.Space,
787 post.FileSize,
788 post.MimeType,
789 post.Shasum,
790 post.Data,
791 post.ExpiresAt,
792 post.UpdatedAt,
793 ).Scan(&id)
794 if err != nil {
795 return nil, err
796 }
797
798 return me.FindPost(id)
799}
800
801func (me *PsqlDB) UpdatePost(post *db.Post) (*db.Post, error) {
802 _, err := me.Db.Exec(
803 sqlUpdatePost,
804 post.Slug,
805 post.Title,
806 post.Text,
807 post.Description,
808 post.UpdatedAt,
809 post.PublishAt,
810 post.FileSize,
811 post.Shasum,
812 post.Data,
813 post.ID,
814 post.Hidden,
815 post.ExpiresAt,
816 )
817 if err != nil {
818 return nil, err
819 }
820
821 return me.FindPost(post.ID)
822}
823
824func (me *PsqlDB) RemovePosts(postIDs []string) error {
825 param := "{" + strings.Join(postIDs, ",") + "}"
826 _, err := me.Db.Exec(sqlRemovePosts, param)
827 return err
828}
829
830func (me *PsqlDB) FindPostsForUser(page *db.Pager, userID string, space string) (*db.Paginate[*db.Post], error) {
831 var posts []*db.Post
832 rs, err := me.Db.Query(
833 sqlSelectPostsForUser,
834 userID,
835 space,
836 page.Num,
837 page.Num*page.Page,
838 )
839 if err != nil {
840 return nil, err
841 }
842 for rs.Next() {
843 post, err := CreatePostWithTagsFromRow(rs)
844 if err != nil {
845 return nil, err
846 }
847
848 posts = append(posts, post)
849 }
850
851 if rs.Err() != nil {
852 return nil, rs.Err()
853 }
854
855 var count int
856 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
857 if err != nil {
858 return nil, err
859 }
860
861 pager := &db.Paginate[*db.Post]{
862 Data: posts,
863 Total: int(math.Ceil(float64(count) / float64(page.Num))),
864 }
865 return pager, nil
866}
867
868func (me *PsqlDB) FindAllPostsForUser(userID string, space string) ([]*db.Post, error) {
869 var posts []*db.Post
870 rs, err := me.Db.Query(sqlSelectAllPostsForUser, userID, space)
871 if err != nil {
872 return posts, err
873 }
874 for rs.Next() {
875 post, err := CreatePostFromRow(rs)
876 if err != nil {
877 return nil, err
878 }
879
880 posts = append(posts, post)
881 }
882 if rs.Err() != nil {
883 return posts, rs.Err()
884 }
885 return posts, nil
886}
887
888func (me *PsqlDB) FindPosts() ([]*db.Post, error) {
889 var posts []*db.Post
890 rs, err := me.Db.Query(sqlSelectPosts)
891 if err != nil {
892 return posts, err
893 }
894 for rs.Next() {
895 post, err := CreatePostFromRow(rs)
896 if err != nil {
897 return nil, err
898 }
899
900 posts = append(posts, post)
901 }
902 if rs.Err() != nil {
903 return posts, rs.Err()
904 }
905 return posts, nil
906}
907
908func (me *PsqlDB) FindExpiredPosts(space string) ([]*db.Post, error) {
909 var posts []*db.Post
910 rs, err := me.Db.Query(sqlSelectExpiredPosts, space)
911 if err != nil {
912 return posts, err
913 }
914 for rs.Next() {
915 post, err := CreatePostFromRow(rs)
916 if err != nil {
917 return nil, err
918 }
919
920 posts = append(posts, post)
921 }
922 if rs.Err() != nil {
923 return posts, rs.Err()
924 }
925 return posts, nil
926}
927
928func (me *PsqlDB) FindUpdatedPostsForUser(userID string, space string) ([]*db.Post, error) {
929 var posts []*db.Post
930 rs, err := me.Db.Query(sqlSelectUpdatedPostsForUser, userID, space)
931 if err != nil {
932 return posts, err
933 }
934 for rs.Next() {
935 post, err := CreatePostFromRow(rs)
936 if err != nil {
937 return nil, err
938 }
939
940 posts = append(posts, post)
941 }
942 if rs.Err() != nil {
943 return posts, rs.Err()
944 }
945 return posts, nil
946}
947
948func (me *PsqlDB) Close() error {
949 me.Logger.Info("Closing db")
950 return me.Db.Close()
951}
952
// newNullString converts s to a sql.NullString, mapping the empty string to
// SQL NULL (Valid=false) and any other value to a valid string.
func newNullString(s string) sql.NullString {
	if s != "" {
		return sql.NullString{String: s, Valid: true}
	}
	return sql.NullString{}
}
962
963func (me *PsqlDB) InsertVisit(visit *db.AnalyticsVisits) error {
964 _, err := me.Db.Exec(
965 `INSERT INTO analytics_visits (user_id, project_id, post_id, namespace, host, path, ip_address, user_agent, referer, status, content_type) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);`,
966 visit.UserID,
967 newNullString(visit.ProjectID),
968 newNullString(visit.PostID),
969 newNullString(visit.Namespace),
970 visit.Host,
971 visit.Path,
972 visit.IpAddress,
973 visit.UserAgent,
974 visit.Referer,
975 visit.Status,
976 visit.ContentType,
977 )
978 return err
979}
980
981func visitFilterBy(opts *db.SummaryOpts) (string, string) {
982 where := ""
983 val := ""
984 if opts.Host != "" {
985 where = "host"
986 val = opts.Host
987 } else if opts.Path != "" {
988 where = "path"
989 val = opts.Path
990 }
991
992 return where, val
993}
994
995func (me *PsqlDB) visitUnique(opts *db.SummaryOpts) ([]*db.VisitInterval, error) {
996 where, with := visitFilterBy(opts)
997 uniqueVisitors := fmt.Sprintf(`SELECT
998 date_trunc('%s', created_at) as interval_start,
999 count(DISTINCT ip_address) as unique_visitors
1000 FROM analytics_visits
1001 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND status <> 404
1002 GROUP BY interval_start`, opts.Interval, where)
1003
1004 intervals := []*db.VisitInterval{}
1005 rs, err := me.Db.Query(uniqueVisitors, opts.Origin, with, opts.UserID)
1006 if err != nil {
1007 return nil, err
1008 }
1009
1010 for rs.Next() {
1011 interval := &db.VisitInterval{}
1012 err := rs.Scan(
1013 &interval.Interval,
1014 &interval.Visitors,
1015 )
1016 if err != nil {
1017 return nil, err
1018 }
1019
1020 intervals = append(intervals, interval)
1021 }
1022 if rs.Err() != nil {
1023 return nil, rs.Err()
1024 }
1025 return intervals, nil
1026}
1027
1028func (me *PsqlDB) visitReferer(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1029 where, with := visitFilterBy(opts)
1030 topUrls := fmt.Sprintf(`SELECT
1031 referer,
1032 count(DISTINCT ip_address) as referer_count
1033 FROM analytics_visits
1034 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND referer <> '' AND status <> 404
1035 GROUP BY referer
1036 ORDER BY referer_count DESC
1037 LIMIT 10`, where)
1038
1039 intervals := []*db.VisitUrl{}
1040 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1041 if err != nil {
1042 return nil, err
1043 }
1044
1045 for rs.Next() {
1046 interval := &db.VisitUrl{}
1047 err := rs.Scan(
1048 &interval.Url,
1049 &interval.Count,
1050 )
1051 if err != nil {
1052 return nil, err
1053 }
1054
1055 intervals = append(intervals, interval)
1056 }
1057 if rs.Err() != nil {
1058 return nil, rs.Err()
1059 }
1060 return intervals, nil
1061}
1062
1063func (me *PsqlDB) visitUrl(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1064 where, with := visitFilterBy(opts)
1065 topUrls := fmt.Sprintf(`SELECT
1066 path,
1067 count(DISTINCT ip_address) as path_count
1068 FROM analytics_visits
1069 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND path <> '' AND status <> 404
1070 GROUP BY path
1071 ORDER BY path_count DESC
1072 LIMIT 10`, where)
1073
1074 intervals := []*db.VisitUrl{}
1075 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1076 if err != nil {
1077 return nil, err
1078 }
1079
1080 for rs.Next() {
1081 interval := &db.VisitUrl{}
1082 err := rs.Scan(
1083 &interval.Url,
1084 &interval.Count,
1085 )
1086 if err != nil {
1087 return nil, err
1088 }
1089
1090 intervals = append(intervals, interval)
1091 }
1092 if rs.Err() != nil {
1093 return nil, rs.Err()
1094 }
1095 return intervals, nil
1096}
1097
1098func (me *PsqlDB) visitUrlNotFound(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1099 where, with := visitFilterBy(opts)
1100 topUrls := fmt.Sprintf(`SELECT
1101 path,
1102 count(DISTINCT ip_address) as path_count
1103 FROM analytics_visits
1104 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND path <> '' AND status = 404
1105 GROUP BY path
1106 ORDER BY path_count DESC
1107 LIMIT 10`, where)
1108
1109 intervals := []*db.VisitUrl{}
1110 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1111 if err != nil {
1112 return nil, err
1113 }
1114
1115 for rs.Next() {
1116 interval := &db.VisitUrl{}
1117 err := rs.Scan(
1118 &interval.Url,
1119 &interval.Count,
1120 )
1121 if err != nil {
1122 return nil, err
1123 }
1124
1125 intervals = append(intervals, interval)
1126 }
1127 if rs.Err() != nil {
1128 return nil, rs.Err()
1129 }
1130 return intervals, nil
1131}
1132
1133func (me *PsqlDB) visitHost(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1134 topUrls := `SELECT
1135 host,
1136 count(DISTINCT ip_address) as host_count
1137 FROM analytics_visits
1138 WHERE user_id = $1 AND host <> ''
1139 GROUP BY host
1140 ORDER BY host_count DESC`
1141
1142 intervals := []*db.VisitUrl{}
1143 rs, err := me.Db.Query(topUrls, opts.UserID)
1144 if err != nil {
1145 return nil, err
1146 }
1147
1148 for rs.Next() {
1149 interval := &db.VisitUrl{}
1150 err := rs.Scan(
1151 &interval.Url,
1152 &interval.Count,
1153 )
1154 if err != nil {
1155 return nil, err
1156 }
1157
1158 intervals = append(intervals, interval)
1159 }
1160 if rs.Err() != nil {
1161 return nil, rs.Err()
1162 }
1163 return intervals, nil
1164}
1165
1166func (me *PsqlDB) VisitSummary(opts *db.SummaryOpts) (*db.SummaryVisits, error) {
1167 visitors, err := me.visitUnique(opts)
1168 if err != nil {
1169 return nil, err
1170 }
1171
1172 urls, err := me.visitUrl(opts)
1173 if err != nil {
1174 return nil, err
1175 }
1176
1177 notFound, err := me.visitUrlNotFound(opts)
1178 if err != nil {
1179 return nil, err
1180 }
1181
1182 refs, err := me.visitReferer(opts)
1183 if err != nil {
1184 return nil, err
1185 }
1186
1187 return &db.SummaryVisits{
1188 Intervals: visitors,
1189 TopUrls: urls,
1190 TopReferers: refs,
1191 NotFoundUrls: notFound,
1192 }, nil
1193}
1194
1195func (me *PsqlDB) FindVisitSiteList(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1196 return me.visitHost(opts)
1197}
1198
1199func (me *PsqlDB) FindUsers() ([]*db.User, error) {
1200 var users []*db.User
1201 rs, err := me.Db.Query(sqlSelectUsers)
1202 if err != nil {
1203 return users, err
1204 }
1205 for rs.Next() {
1206 var name sql.NullString
1207 user := &db.User{}
1208 err := rs.Scan(
1209 &user.ID,
1210 &name,
1211 &user.CreatedAt,
1212 )
1213 if err != nil {
1214 return users, err
1215 }
1216 user.Name = name.String
1217
1218 users = append(users, user)
1219 }
1220 if rs.Err() != nil {
1221 return users, rs.Err()
1222 }
1223 return users, nil
1224}
1225
1226func (me *PsqlDB) removeTagsForPost(tx *sql.Tx, postID string) error {
1227 _, err := tx.Exec(sqlRemoveTagsByPost, postID)
1228 return err
1229}
1230
1231func (me *PsqlDB) insertTagsForPost(tx *sql.Tx, tags []string, postID string) ([]string, error) {
1232 ids := make([]string, 0)
1233 for _, tag := range tags {
1234 id := ""
1235 err := tx.QueryRow(sqlInsertTag, postID, tag).Scan(&id)
1236 if err != nil {
1237 return nil, err
1238 }
1239 ids = append(ids, id)
1240 }
1241
1242 return ids, nil
1243}
1244
1245func (me *PsqlDB) ReplaceTagsForPost(tags []string, postID string) error {
1246 ctx := context.Background()
1247 tx, err := me.Db.BeginTx(ctx, nil)
1248 if err != nil {
1249 return err
1250 }
1251 defer func() {
1252 err = tx.Rollback()
1253 }()
1254
1255 err = me.removeTagsForPost(tx, postID)
1256 if err != nil {
1257 return err
1258 }
1259
1260 _, err = me.insertTagsForPost(tx, tags, postID)
1261 if err != nil {
1262 return err
1263 }
1264
1265 err = tx.Commit()
1266 return err
1267}
1268
1269func (me *PsqlDB) removeAliasesForPost(tx *sql.Tx, postID string) error {
1270 _, err := tx.Exec(sqlRemoveAliasesByPost, postID)
1271 return err
1272}
1273
1274func (me *PsqlDB) insertAliasesForPost(tx *sql.Tx, aliases []string, postID string) ([]string, error) {
1275 // hardcoded
1276 denyList := []string{
1277 "rss",
1278 "rss.xml",
1279 "rss.atom",
1280 "atom.xml",
1281 "feed.xml",
1282 "smol.css",
1283 "main.css",
1284 "syntax.css",
1285 "card.png",
1286 "favicon-16x16.png",
1287 "favicon-32x32.png",
1288 "apple-touch-icon.png",
1289 "favicon.ico",
1290 "robots.txt",
1291 "atom",
1292 "blog/index.xml",
1293 }
1294
1295 ids := make([]string, 0)
1296 for _, alias := range aliases {
1297 if slices.Contains(denyList, alias) {
1298 me.Logger.Info(
1299 "name is in the deny list for aliases because it conflicts with a static route, skipping",
1300 "alias", alias,
1301 )
1302 continue
1303 }
1304 id := ""
1305 err := tx.QueryRow(sqlInsertAliases, postID, alias).Scan(&id)
1306 if err != nil {
1307 return nil, err
1308 }
1309 ids = append(ids, id)
1310 }
1311
1312 return ids, nil
1313}
1314
1315func (me *PsqlDB) ReplaceAliasesForPost(aliases []string, postID string) error {
1316 ctx := context.Background()
1317 tx, err := me.Db.BeginTx(ctx, nil)
1318 if err != nil {
1319 return err
1320 }
1321 defer func() {
1322 err = tx.Rollback()
1323 }()
1324
1325 err = me.removeAliasesForPost(tx, postID)
1326 if err != nil {
1327 return err
1328 }
1329
1330 _, err = me.insertAliasesForPost(tx, aliases, postID)
1331 if err != nil {
1332 return err
1333 }
1334
1335 err = tx.Commit()
1336 return err
1337}
1338
1339func (me *PsqlDB) FindUserPostsByTag(page *db.Pager, tag, userID, space string) (*db.Paginate[*db.Post], error) {
1340 var posts []*db.Post
1341 rs, err := me.Db.Query(
1342 sqlSelectUserPostsByTag,
1343 userID,
1344 tag,
1345 space,
1346 page.Num,
1347 page.Num*page.Page,
1348 )
1349 if err != nil {
1350 return nil, err
1351 }
1352 for rs.Next() {
1353 post, err := CreatePostFromRow(rs)
1354 if err != nil {
1355 return nil, err
1356 }
1357
1358 posts = append(posts, post)
1359 }
1360
1361 if rs.Err() != nil {
1362 return nil, rs.Err()
1363 }
1364
1365 var count int
1366 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
1367 if err != nil {
1368 return nil, err
1369 }
1370
1371 pager := &db.Paginate[*db.Post]{
1372 Data: posts,
1373 Total: int(math.Ceil(float64(count) / float64(page.Num))),
1374 }
1375 return pager, nil
1376}
1377
1378func (me *PsqlDB) FindPostsByTag(pager *db.Pager, tag, space string) (*db.Paginate[*db.Post], error) {
1379 rs, err := me.Db.Query(
1380 sqlSelectPostsByTag,
1381 pager.Num,
1382 pager.Num*pager.Page,
1383 tag,
1384 space,
1385 )
1386 if err != nil {
1387 return nil, err
1388 }
1389
1390 return me.postPager(rs, pager.Num, space, tag)
1391}
1392
1393func (me *PsqlDB) FindPopularTags(space string) ([]string, error) {
1394 tags := make([]string, 0)
1395 rs, err := me.Db.Query(sqlSelectPopularTags, space)
1396 if err != nil {
1397 return tags, err
1398 }
1399 for rs.Next() {
1400 name := ""
1401 tally := 0
1402 err := rs.Scan(&name, &tally)
1403 if err != nil {
1404 return tags, err
1405 }
1406
1407 tags = append(tags, name)
1408 }
1409 if rs.Err() != nil {
1410 return tags, rs.Err()
1411 }
1412 return tags, nil
1413}
1414
1415func (me *PsqlDB) FindTagsForUser(userID string, space string) ([]string, error) {
1416 tags := []string{}
1417 rs, err := me.Db.Query(sqlSelectTagsForUser, userID, space)
1418 if err != nil {
1419 return tags, err
1420 }
1421 for rs.Next() {
1422 name := ""
1423 err := rs.Scan(&name)
1424 if err != nil {
1425 return tags, err
1426 }
1427
1428 tags = append(tags, name)
1429 }
1430 if rs.Err() != nil {
1431 return tags, rs.Err()
1432 }
1433 return tags, nil
1434}
1435
1436func (me *PsqlDB) FindTagsForPost(postID string) ([]string, error) {
1437 tags := make([]string, 0)
1438 rs, err := me.Db.Query(sqlSelectTagsForPost, postID)
1439 if err != nil {
1440 return tags, err
1441 }
1442
1443 for rs.Next() {
1444 name := ""
1445 err := rs.Scan(&name)
1446 if err != nil {
1447 return tags, err
1448 }
1449
1450 tags = append(tags, name)
1451 }
1452
1453 if rs.Err() != nil {
1454 return tags, rs.Err()
1455 }
1456
1457 return tags, nil
1458}
1459
1460func (me *PsqlDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
1461 ff := &db.FeatureFlag{}
1462 // payment history is allowed to be null
1463 // https://devtidbits.com/2020/08/03/go-sql-error-converting-null-to-string-is-unsupported/
1464 var paymentHistoryID sql.NullString
1465 err := me.Db.QueryRow(sqlSelectFeatureForUser, userID, feature).Scan(
1466 &ff.ID,
1467 &ff.UserID,
1468 &paymentHistoryID,
1469 &ff.Name,
1470 &ff.Data,
1471 &ff.CreatedAt,
1472 &ff.ExpiresAt,
1473 )
1474 if err != nil {
1475 return nil, err
1476 }
1477
1478 ff.PaymentHistoryID = paymentHistoryID
1479
1480 return ff, nil
1481}
1482
1483func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error) {
1484 var features []*db.FeatureFlag
1485 // https://stackoverflow.com/a/16920077
1486 query := `SELECT DISTINCT ON (name)
1487 id, user_id, payment_history_id, name, data, created_at, expires_at
1488 FROM feature_flags
1489 WHERE user_id=$1
1490 ORDER BY name, expires_at DESC;`
1491 rs, err := me.Db.Query(query, userID)
1492 if err != nil {
1493 return features, err
1494 }
1495 for rs.Next() {
1496 var paymentHistoryID sql.NullString
1497 ff := &db.FeatureFlag{}
1498 err := rs.Scan(
1499 &ff.ID,
1500 &ff.UserID,
1501 &paymentHistoryID,
1502 &ff.Name,
1503 &ff.Data,
1504 &ff.CreatedAt,
1505 &ff.ExpiresAt,
1506 )
1507 if err != nil {
1508 return features, err
1509 }
1510 ff.PaymentHistoryID = paymentHistoryID
1511
1512 features = append(features, ff)
1513 }
1514 if rs.Err() != nil {
1515 return features, rs.Err()
1516 }
1517 return features, nil
1518}
1519
1520func (me *PsqlDB) HasFeatureForUser(userID string, feature string) bool {
1521 ff, err := me.FindFeature(userID, feature)
1522 if err != nil {
1523 return false
1524 }
1525 return ff.IsValid()
1526}
1527
1528func (me *PsqlDB) FindTotalSizeForUser(userID string) (int, error) {
1529 var fileSize int
1530 err := me.Db.QueryRow(sqlSelectSizeForUser, userID).Scan(&fileSize)
1531 if err != nil {
1532 return 0, err
1533 }
1534 return fileSize, nil
1535}
1536
1537func (me *PsqlDB) InsertFeedItems(postID string, items []*db.FeedItem) error {
1538 ctx := context.Background()
1539 tx, err := me.Db.BeginTx(ctx, nil)
1540 if err != nil {
1541 return err
1542 }
1543 defer func() {
1544 err = tx.Rollback()
1545 }()
1546
1547 for _, item := range items {
1548 _, err := tx.Exec(
1549 sqlInsertFeedItems,
1550 item.PostID,
1551 item.GUID,
1552 item.Data,
1553 )
1554 if err != nil {
1555 return err
1556 }
1557 }
1558
1559 err = tx.Commit()
1560 return err
1561}
1562
1563func (me *PsqlDB) FindFeedItemsByPostID(postID string) ([]*db.FeedItem, error) {
1564 // sqlSelectFeedItemsByPost
1565 items := make([]*db.FeedItem, 0)
1566 rs, err := me.Db.Query(sqlSelectFeedItemsByPost, postID)
1567 if err != nil {
1568 return items, err
1569 }
1570
1571 for rs.Next() {
1572 item := &db.FeedItem{}
1573 err := rs.Scan(
1574 &item.ID,
1575 &item.PostID,
1576 &item.GUID,
1577 &item.Data,
1578 &item.CreatedAt,
1579 )
1580 if err != nil {
1581 return items, err
1582 }
1583
1584 items = append(items, item)
1585 }
1586
1587 if rs.Err() != nil {
1588 return items, rs.Err()
1589 }
1590
1591 return items, nil
1592}
1593
1594func (me *PsqlDB) InsertProject(userID, name, projectDir string) (string, error) {
1595 if !utils.IsValidSubdomain(name) {
1596 return "", fmt.Errorf("'%s' is not a valid project name, must match /^[a-z0-9-]+$/", name)
1597 }
1598
1599 var id string
1600 err := me.Db.QueryRow(sqlInsertProject, userID, name, projectDir).Scan(&id)
1601 if err != nil {
1602 return "", err
1603 }
1604 return id, nil
1605}
1606
1607func (me *PsqlDB) UpdateProject(userID, name string) error {
1608 _, err := me.Db.Exec(sqlUpdateProject, userID, name, time.Now())
1609 return err
1610}
1611
1612func (me *PsqlDB) FindProjectByName(userID, name string) (*db.Project, error) {
1613 project := &db.Project{}
1614 r := me.Db.QueryRow(sqlFindProjectByName, userID, name)
1615 err := r.Scan(
1616 &project.ID,
1617 &project.UserID,
1618 &project.Name,
1619 &project.ProjectDir,
1620 &project.Acl,
1621 &project.Blocked,
1622 &project.CreatedAt,
1623 &project.UpdatedAt,
1624 )
1625 if err != nil {
1626 return nil, err
1627 }
1628
1629 return project, nil
1630}
1631
1632func (me *PsqlDB) InsertToken(userID, name string) (string, error) {
1633 var token string
1634 err := me.Db.QueryRow(sqlInsertToken, userID, name).Scan(&token)
1635 if err != nil {
1636 return "", err
1637 }
1638 return token, nil
1639}
1640
1641func (me *PsqlDB) UpsertToken(userID, name string) (string, error) {
1642 token, _ := me.FindTokenByName(userID, name)
1643 if token != "" {
1644 return token, nil
1645 }
1646
1647 token, err := me.InsertToken(userID, name)
1648 return token, err
1649}
1650
1651func (me *PsqlDB) FindTokenByName(userID, name string) (string, error) {
1652 var token string
1653 err := me.Db.QueryRow(sqlSelectTokenByNameForUser, userID, name).Scan(&token)
1654 if err != nil {
1655 return "", err
1656 }
1657 return token, nil
1658}
1659
1660func (me *PsqlDB) RemoveToken(tokenID string) error {
1661 _, err := me.Db.Exec(sqlRemoveToken, tokenID)
1662 return err
1663}
1664
1665func (me *PsqlDB) FindTokensForUser(userID string) ([]*db.Token, error) {
1666 var keys []*db.Token
1667 rs, err := me.Db.Query(sqlSelectTokensForUser, userID)
1668 if err != nil {
1669 return keys, err
1670 }
1671 for rs.Next() {
1672 pk := &db.Token{}
1673 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.CreatedAt, &pk.ExpiresAt)
1674 if err != nil {
1675 return keys, err
1676 }
1677
1678 keys = append(keys, pk)
1679 }
1680 if rs.Err() != nil {
1681 return keys, rs.Err()
1682 }
1683 return keys, nil
1684}
1685
1686func (me *PsqlDB) InsertFeature(userID, name string, expiresAt time.Time) (*db.FeatureFlag, error) {
1687 var featureID string
1688 err := me.Db.QueryRow(
1689 `INSERT INTO feature_flags (user_id, name, expires_at) VALUES ($1, $2, $3) RETURNING id;`,
1690 userID,
1691 name,
1692 expiresAt,
1693 ).Scan(&featureID)
1694 if err != nil {
1695 return nil, err
1696 }
1697
1698 feature, err := me.FindFeature(userID, name)
1699 if err != nil {
1700 return nil, err
1701 }
1702
1703 return feature, nil
1704}
1705
1706func (me *PsqlDB) RemoveFeature(userID string, name string) error {
1707 _, err := me.Db.Exec(`DELETE FROM feature_flags WHERE user_id = $1 AND name = $2`, userID, name)
1708 return err
1709}
1710
1711func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
1712 ff, _ := me.FindFeature(userID, name)
1713 if ff == nil {
1714 t := time.Now()
1715 return t.AddDate(1, 0, 0)
1716 }
1717 return ff.ExpiresAt.AddDate(1, 0, 0)
1718}
1719
1720func (me *PsqlDB) AddPicoPlusUser(username, email, paymentType, txId string) error {
1721 user, err := me.FindUserByName(username)
1722 if err != nil {
1723 return err
1724 }
1725
1726 ctx := context.Background()
1727 tx, err := me.Db.BeginTx(ctx, nil)
1728 if err != nil {
1729 return err
1730 }
1731 defer func() {
1732 err = tx.Rollback()
1733 }()
1734
1735 var paymentHistoryId sql.NullString
1736 if paymentType != "" {
1737 data := db.PaymentHistoryData{
1738 Notes: "",
1739 TxID: txId,
1740 }
1741
1742 err := tx.QueryRow(
1743 `INSERT INTO payment_history (user_id, payment_type, amount, data) VALUES ($1, $2, 24 * 1000000, $3) RETURNING id;`,
1744 user.ID,
1745 paymentType,
1746 data,
1747 ).Scan(&paymentHistoryId)
1748 if err != nil {
1749 return err
1750 }
1751 }
1752
1753 plus := me.createFeatureExpiresAt(user.ID, "plus")
1754 plusQuery := fmt.Sprintf(`INSERT INTO feature_flags (user_id, name, data, expires_at, payment_history_id)
1755 VALUES ($1, 'plus', '{"storage_max":10000000000, "file_max":50000000, "email": "%s"}'::jsonb, $2, $3);`, email)
1756 _, err = tx.Exec(plusQuery, user.ID, plus, paymentHistoryId)
1757 if err != nil {
1758 return err
1759 }
1760
1761 return tx.Commit()
1762}
1763
1764func (me *PsqlDB) UpsertProject(userID, projectName, projectDir string) (*db.Project, error) {
1765 project, err := me.FindProjectByName(userID, projectName)
1766 if err == nil {
1767 // this just updates the `createdAt` timestamp, useful for book-keeping
1768 err = me.UpdateProject(userID, projectName)
1769 if err != nil {
1770 me.Logger.Error("could not update project", "err", err)
1771 return nil, err
1772 }
1773 return project, nil
1774 }
1775
1776 _, err = me.InsertProject(userID, projectName, projectName)
1777 if err != nil {
1778 me.Logger.Error("could not create project", "err", err)
1779 return nil, err
1780 }
1781 return me.FindProjectByName(userID, projectName)
1782}
1783
1784func (me *PsqlDB) findPagesStats(userID string) (*db.UserServiceStats, error) {
1785 stats := db.UserServiceStats{
1786 Service: "pgs",
1787 }
1788 err := me.Db.QueryRow(
1789 `SELECT count(id), min(created_at), max(created_at), max(updated_at) FROM projects WHERE user_id=$1`,
1790 userID,
1791 ).Scan(&stats.Num, &stats.FirstCreatedAt, &stats.LastestCreatedAt, &stats.LatestUpdatedAt)
1792 if err != nil {
1793 return nil, err
1794 }
1795
1796 return &stats, nil
1797}
1798
1799func (me *PsqlDB) InsertTunsEventLog(log *db.TunsEventLog) error {
1800 _, err := me.Db.Exec(
1801 `INSERT INTO tuns_event_logs
1802 (user_id, server_id, remote_addr, event_type, tunnel_type, connection_type, tunnel_id)
1803 VALUES
1804 ($1, $2, $3, $4, $5, $6, $7)`,
1805 log.UserId, log.ServerID, log.RemoteAddr, log.EventType, log.TunnelType,
1806 log.ConnectionType, log.TunnelID,
1807 )
1808 return err
1809}
1810
1811func (me *PsqlDB) FindTunsEventLogsByAddr(userID, addr string) ([]*db.TunsEventLog, error) {
1812 logs := []*db.TunsEventLog{}
1813 rs, err := me.Db.Query(
1814 `SELECT id, user_id, server_id, remote_addr, event_type, tunnel_type, connection_type, tunnel_id, created_at
1815 FROM tuns_event_logs WHERE user_id=$1 AND tunnel_id=$2 ORDER BY created_at DESC`, userID, addr)
1816 if err != nil {
1817 return nil, err
1818 }
1819
1820 for rs.Next() {
1821 log := db.TunsEventLog{}
1822 err := rs.Scan(
1823 &log.ID, &log.UserId, &log.ServerID, &log.RemoteAddr,
1824 &log.EventType, &log.TunnelType, &log.ConnectionType,
1825 &log.TunnelID, &log.CreatedAt,
1826 )
1827 if err != nil {
1828 return nil, err
1829 }
1830 logs = append(logs, &log)
1831 }
1832
1833 if rs.Err() != nil {
1834 return nil, rs.Err()
1835 }
1836
1837 return logs, nil
1838}
1839
1840func (me *PsqlDB) FindTunsEventLogs(userID string) ([]*db.TunsEventLog, error) {
1841 logs := []*db.TunsEventLog{}
1842 rs, err := me.Db.Query(
1843 `SELECT id, user_id, server_id, remote_addr, event_type, tunnel_type, connection_type, tunnel_id, created_at
1844 FROM tuns_event_logs WHERE user_id=$1 ORDER BY created_at DESC`, userID)
1845 if err != nil {
1846 return nil, err
1847 }
1848
1849 for rs.Next() {
1850 log := db.TunsEventLog{}
1851 err := rs.Scan(
1852 &log.ID, &log.UserId, &log.ServerID, &log.RemoteAddr,
1853 &log.EventType, &log.TunnelType, &log.ConnectionType,
1854 &log.TunnelID, &log.CreatedAt,
1855 )
1856 if err != nil {
1857 return nil, err
1858 }
1859 logs = append(logs, &log)
1860 }
1861
1862 if rs.Err() != nil {
1863 return nil, rs.Err()
1864 }
1865
1866 return logs, nil
1867}
1868
1869func (me *PsqlDB) FindUserStats(userID string) (*db.UserStats, error) {
1870 stats := db.UserStats{}
1871 rs, err := me.Db.Query(`SELECT cur_space, count(id), min(created_at), max(created_at), max(updated_at) FROM posts WHERE user_id=$1 GROUP BY cur_space`, userID)
1872 if err != nil {
1873 return nil, err
1874 }
1875
1876 for rs.Next() {
1877 stat := db.UserServiceStats{}
1878 err := rs.Scan(&stat.Service, &stat.Num, &stat.FirstCreatedAt, &stat.LastestCreatedAt, &stat.LatestUpdatedAt)
1879 if err != nil {
1880 return nil, err
1881 }
1882 switch stat.Service {
1883 case "prose":
1884 stats.Prose = stat
1885 case "pastes":
1886 stats.Pastes = stat
1887 case "feeds":
1888 stats.Feeds = stat
1889 }
1890 }
1891
1892 if rs.Err() != nil {
1893 return nil, rs.Err()
1894 }
1895
1896 pgs, err := me.findPagesStats(userID)
1897 if err != nil {
1898 return nil, err
1899 }
1900 stats.Pages = *pgs
1901 return &stats, err
1902}