// storage.go - last modified 2025-08-07 by Eric Bower.

1package postgres
2
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"math"
	"net/url"
	"slices"
	"strings"
	"time"

	_ "github.com/lib/pq"
	"github.com/picosh/pico/pkg/db"
	"github.com/picosh/utils"
)
19
// PAGER_SIZE is presumably the default page size for paginated post lists;
// it is not referenced in the visible part of this file.
//
// NOTE(review): the ALL_CAPS name is kept for compatibility with callers
// elsewhere; idiomatic Go would be PagerSize.
var PAGER_SIZE = 15

// SelectPost is the column list shared by the post queries below. Its order
// must stay in lockstep with the Scan destinations in CreatePostFromRow and
// CreatePostWithTagsFromRow.
var SelectPost = `
	posts.id, user_id, app_users.name, filename, slug, title, text, description,
	posts.created_at, publish_at, posts.updated_at, hidden, file_size, mime_type, shasum, data, expires_at, views`

var (
	// sqlSelectPosts selects every post together with its author's name.
	sqlSelectPosts = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id`, SelectPost)

	// sqlSelectPostsBeforeDate: $1 cutoff date (inclusive, day granularity), $2 space.
	sqlSelectPostsBeforeDate = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE publish_at::date <= $1 AND cur_space = $2`, SelectPost)

	// sqlSelectPostWithFilename: $1 filename, $2 user id, $3 space. The post's
	// tags are aggregated into one comma-separated column after SelectPost.
	sqlSelectPostWithFilename = fmt.Sprintf(`
	SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE filename = $1 AND user_id = $2 AND cur_space = $3
	GROUP BY %s`, SelectPost, SelectPost)

	// sqlSelectPostWithSlug: $1 slug, $2 user id, $3 space; tags aggregated
	// exactly like sqlSelectPostWithFilename.
	sqlSelectPostWithSlug = fmt.Sprintf(`
	SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE slug = $1 AND user_id = $2 AND cur_space = $3
	GROUP BY %s`, SelectPost, SelectPost)

	// sqlSelectPost: $1 post id.
	sqlSelectPost = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE posts.id = $1`, SelectPost)

	// sqlSelectUpdatedPostsForUser: $1 user id, $2 space; published posts,
	// most recently updated first.
	sqlSelectUpdatedPostsForUser = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE user_id = $1 AND publish_at::date <= CURRENT_DATE AND cur_space = $2
	ORDER BY posts.updated_at DESC`, SelectPost)

	// sqlSelectExpiredPosts: $1 space; posts whose expires_at has passed.
	sqlSelectExpiredPosts = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE
		cur_space = $1 AND
		expires_at <= now();
	`, SelectPost)

	// sqlSelectPostsForUser: $1 user id, $2 space, $3 limit, $4 offset; visible,
	// published posts with aggregated tags, newest first.
	sqlSelectPostsForUser = fmt.Sprintf(`
	SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE
		hidden = FALSE AND
		user_id = $1 AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $2
	GROUP BY %s
	ORDER BY publish_at DESC, slug DESC
	LIMIT $3 OFFSET $4`, SelectPost, SelectPost)

	// sqlSelectAllPostsForUser: $1 user id, $2 space; includes hidden and
	// future-dated posts.
	sqlSelectAllPostsForUser = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE
		user_id = $1 AND
		cur_space = $2
	ORDER BY publish_at DESC`, SelectPost)

	// sqlSelectPostsByTag: $1 limit, $2 offset, $3 tag, $4 space. The column
	// list is the 11-column shape consumed by postPager, not SelectPost.
	sqlSelectPostsByTag = `
	SELECT
		posts.id,
		user_id,
		filename,
		slug,
		title,
		text,
		description,
		publish_at,
		app_users.name as username,
		posts.updated_at,
		posts.mime_type
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE
		post_tags.name = $3 AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $4
	ORDER BY publish_at DESC
	LIMIT $1 OFFSET $2`

	// sqlSelectUserPostsByTag: $1 user id, $2 tag, $3 space, $4 limit,
	// $5 offset. Only visible (hidden = FALSE) posts carrying tag $2 match;
	// an "OR hidden = true" alternative could never hold alongside
	// hidden = FALSE, so the tag filter is stated directly.
	sqlSelectUserPostsByTag = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE
		hidden = FALSE AND
		user_id = $1 AND
		post_tags.name = $2 AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $3
	ORDER BY publish_at DESC
	LIMIT $4 OFFSET $5`, SelectPost)
)
136
// User and public-key queries.
const (
	sqlSelectPublicKey         = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE public_key = $1`
	sqlSelectPublicKeys        = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE user_id = $1 ORDER BY created_at ASC`
	sqlInsertPublicKey         = `INSERT INTO public_keys (user_id, public_key) VALUES ($1, $2)`
	sqlRemoveKeys              = `DELETE FROM public_keys WHERE id = ANY($1::uuid[])`
	sqlSelectUser              = `SELECT id, name, created_at FROM app_users WHERE id = $1`
	sqlSelectUserForName       = `SELECT id, name, created_at FROM app_users WHERE name = $1`
	sqlSelectUserForNameAndKey = `SELECT app_users.id, app_users.name, app_users.created_at, public_keys.id as pk_id, public_keys.public_key, public_keys.created_at as pk_created_at FROM app_users LEFT JOIN public_keys ON public_keys.user_id = app_users.id WHERE app_users.name = $1 AND public_keys.public_key = $2`
	sqlSelectUsers             = `SELECT id, name, created_at FROM app_users ORDER BY name ASC`
	sqlInsertUser              = `INSERT INTO app_users (name) VALUES($1) returning id`
	sqlUpdateUserName          = `UPDATE app_users SET name = $1 WHERE id = $2`
	sqlRemoveUsers             = `DELETE FROM app_users WHERE id = ANY($1::uuid[])`
)

// Token, feature-flag and quota queries.
const (
	// sqlSelectUserForToken resolves a token that has not yet expired to its owner.
	sqlSelectUserForToken = `
	SELECT app_users.id, app_users.name, app_users.created_at
	FROM app_users
	LEFT JOIN tokens ON tokens.user_id = app_users.id
	WHERE tokens.token = $1 AND tokens.expires_at > NOW()`
	sqlInsertToken              = `INSERT INTO tokens (user_id, name) VALUES($1, $2) RETURNING token;`
	sqlRemoveToken              = `DELETE FROM tokens WHERE id = $1`
	sqlSelectTokensForUser      = `SELECT id, user_id, name, created_at, expires_at FROM tokens WHERE user_id = $1`
	sqlSelectTokenByNameForUser = `SELECT token FROM tokens WHERE user_id = $1 AND name = $2`

	// sqlSelectFeatureForUser returns the flag row with the latest expiry.
	sqlSelectFeatureForUser = `SELECT id, user_id, payment_history_id, name, data, created_at, expires_at FROM feature_flags WHERE user_id = $1 AND name = $2 ORDER BY expires_at DESC LIMIT 1`
	// sqlSelectSizeForUser totals file_size over all of a user's posts (0 when none).
	sqlSelectSizeForUser = `SELECT COALESCE(sum(file_size), 0) FROM posts WHERE user_id = $1`
)

// Post, tag, alias and feed-item queries.
const (
	sqlSelectPostIdByAliasSlug = `SELECT post_id FROM post_aliases WHERE slug = $1`

	// sqlSelectTagPostCount: $1 space, $2 tag; counts visible posts with the tag.
	sqlSelectTagPostCount = `
	SELECT count(posts.id)
	FROM posts
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE hidden = FALSE AND cur_space=$1 and post_tags.name = $2`
	// sqlSelectPostCount: $1 space; counts every visible post in the space.
	sqlSelectPostCount = `SELECT count(id) FROM posts WHERE hidden = FALSE AND cur_space=$1`

	// sqlSelectAllUpdatedPosts: $1 limit, $2 offset, $3 space. The 11 columns
	// match what postPager scans.
	sqlSelectAllUpdatedPosts = `
	SELECT
		posts.id,
		user_id,
		filename,
		slug,
		title,
		text,
		description,
		publish_at,
		app_users.name as username,
		posts.updated_at,
		posts.mime_type
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE hidden = FALSE AND publish_at::date <= CURRENT_DATE AND cur_space = $3
	ORDER BY updated_at DESC
	LIMIT $1 OFFSET $2`

	// sqlSelectPostsByRank: $1 limit, $2 offset, $3 space. DISTINCT ON
	// (posts.user_id), ordered by user then publish_at DESC, keeps only each
	// user's newest published post. A single account that publishes many posts
	// therefore occupies one slot in the feed. The outer query re-sorts those
	// per-user picks newest first. Same 11-column shape as
	// sqlSelectAllUpdatedPosts.
	sqlSelectPostsByRank = `
	SELECT *
	FROM (
		SELECT DISTINCT ON (posts.user_id)
			posts.id,
			posts.user_id,
			posts.filename,
			posts.slug,
			posts.title,
			posts.text,
			posts.description,
			posts.publish_at,
			app_users.name AS username,
			posts.updated_at,
			posts.mime_type
		FROM posts
		LEFT JOIN app_users ON app_users.id = posts.user_id
		WHERE
			hidden = FALSE
			AND publish_at::date <= CURRENT_DATE
			AND cur_space = $3
		ORDER BY posts.user_id, publish_at DESC
	) AS latest_posts
	ORDER BY publish_at DESC
	LIMIT $1 OFFSET $2`

	// sqlSelectPopularTags: $1 space; the five most used tag names.
	sqlSelectPopularTags = `
	SELECT name, count(post_id) as "tally"
	FROM post_tags
	LEFT JOIN posts ON posts.id = post_id
	WHERE posts.cur_space = $1
	GROUP BY name
	ORDER BY tally DESC
	LIMIT 5`
	// sqlSelectTagsForUser: $1 user id, $2 space; distinct tag names.
	sqlSelectTagsForUser = `
	SELECT name
	FROM post_tags
	LEFT JOIN posts ON posts.id = post_id
	WHERE posts.user_id = $1 AND posts.cur_space = $2
	GROUP BY name`
	sqlSelectTagsForPost     = `SELECT name FROM post_tags WHERE post_id=$1`
	sqlSelectFeedItemsByPost = `SELECT id, post_id, guid, data, created_at FROM feed_items WHERE post_id=$1`

	// sqlInsertPost binds 15 values in column order and returns the new id.
	sqlInsertPost = `
	INSERT INTO posts
		(user_id, filename, slug, title, text, description, publish_at, hidden, cur_space,
		file_size, mime_type, shasum, data, expires_at, updated_at)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
	RETURNING id`
	sqlInsertTag       = `INSERT INTO post_tags (post_id, name) VALUES($1, $2) RETURNING id;`
	sqlInsertAliases   = `INSERT INTO post_aliases (post_id, slug) VALUES($1, $2) RETURNING id;`
	sqlInsertFeedItems = `INSERT INTO feed_items (post_id, guid, data) VALUES ($1, $2, $3) RETURNING id;`

	// sqlUpdatePost: note the placeholder order — the row id is $10 while
	// hidden and expires_at come after it as $11 and $12.
	sqlUpdatePost = `
	UPDATE posts
	SET slug = $1, title = $2, text = $3, description = $4, updated_at = $5, publish_at = $6,
		file_size = $7, shasum = $8, data = $9, hidden = $11, expires_at = $12
	WHERE id = $10`
	sqlIncrementViews = `UPDATE posts SET views = views + 1 WHERE id = $1 RETURNING views`

	sqlRemoveAliasesByPost = `DELETE FROM post_aliases WHERE post_id = $1`
	sqlRemoveTagsByPost    = `DELETE FROM post_tags WHERE post_id = $1`
	sqlRemovePosts         = `DELETE FROM posts WHERE id = ANY($1::uuid[])`
)

// Project queries. A project whose name equals its project_dir is a
// "real" project; one whose name differs is a link pointing at project_dir.
const (
	sqlInsertProject        = `INSERT INTO projects (user_id, name, project_dir) VALUES ($1, $2, $3) RETURNING id;`
	sqlUpdateProject        = `UPDATE projects SET updated_at = $3 WHERE user_id = $1 AND name = $2;`
	sqlFindProjectByName    = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name = $2;`
	sqlSelectProjectCount   = `SELECT count(id) FROM projects`
	sqlFindProjectsByUser   = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 ORDER BY name ASC, updated_at DESC;`
	sqlFindProjectsByPrefix = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name = project_dir AND name ILIKE $2 ORDER BY updated_at ASC, name ASC;`
	sqlFindProjectLinks     = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name != project_dir AND project_dir = $2 ORDER BY name ASC;`
	sqlLinkToProject        = `UPDATE projects SET project_dir = $1, updated_at = $2 WHERE id = $3;`
	sqlRemoveProject        = `DELETE FROM projects WHERE id = $1;`
)
264
// PsqlDB is the Postgres-backed storage handle: a logger plus the shared
// *sql.DB connection pool that every method in this file queries.
type PsqlDB struct {
	// Logger receives connection and lookup diagnostics.
	Logger *slog.Logger
	// Db is the pool opened by NewDB; it may be nil if sql.Open failed.
	Db *sql.DB
}

// RowScanner is the Scan method shared by *sql.Row and *sql.Rows, letting
// the post decoders accept either a single-row or a multi-row result.
type RowScanner interface {
	Scan(dest ...any) error
}
273
274func CreatePostFromRow(r RowScanner) (*db.Post, error) {
275 post := &db.Post{}
276 err := r.Scan(
277 &post.ID,
278 &post.UserID,
279 &post.Username,
280 &post.Filename,
281 &post.Slug,
282 &post.Title,
283 &post.Text,
284 &post.Description,
285 &post.CreatedAt,
286 &post.PublishAt,
287 &post.UpdatedAt,
288 &post.Hidden,
289 &post.FileSize,
290 &post.MimeType,
291 &post.Shasum,
292 &post.Data,
293 &post.ExpiresAt,
294 &post.Views,
295 )
296 if err != nil {
297 return nil, err
298 }
299 return post, nil
300}
301
302func CreatePostWithTagsFromRow(r RowScanner) (*db.Post, error) {
303 post := &db.Post{}
304 tagStr := ""
305 err := r.Scan(
306 &post.ID,
307 &post.UserID,
308 &post.Username,
309 &post.Filename,
310 &post.Slug,
311 &post.Title,
312 &post.Text,
313 &post.Description,
314 &post.CreatedAt,
315 &post.PublishAt,
316 &post.UpdatedAt,
317 &post.Hidden,
318 &post.FileSize,
319 &post.MimeType,
320 &post.Shasum,
321 &post.Data,
322 &post.ExpiresAt,
323 &post.Views,
324 &tagStr,
325 )
326 if err != nil {
327 return nil, err
328 }
329
330 tags := strings.Split(tagStr, ",")
331 for _, tag := range tags {
332 tg := strings.TrimSpace(tag)
333 if tg == "" {
334 continue
335 }
336 post.Tags = append(post.Tags, tg)
337 }
338
339 return post, nil
340}
341
342func NewDB(databaseUrl string, logger *slog.Logger) *PsqlDB {
343 var err error
344 d := &PsqlDB{
345 Logger: logger,
346 }
347 d.Logger.Info("Connecting to postgres", "databaseUrl", databaseUrl)
348
349 db, err := sql.Open("postgres", databaseUrl)
350 if err != nil {
351 d.Logger.Error(err.Error())
352 }
353 d.Db = db
354 return d
355}
356
357func (me *PsqlDB) RegisterUser(username, pubkey, comment string) (*db.User, error) {
358 lowerName := strings.ToLower(username)
359 valid, err := me.ValidateName(lowerName)
360 if !valid {
361 return nil, err
362 }
363
364 ctx := context.Background()
365 tx, err := me.Db.BeginTx(ctx, nil)
366 if err != nil {
367 return nil, err
368 }
369 defer func() {
370 err = tx.Rollback()
371 }()
372
373 stmt, err := tx.Prepare(sqlInsertUser)
374 if err != nil {
375 return nil, err
376 }
377 defer func() {
378 _ = stmt.Close()
379 }()
380
381 var id string
382 err = stmt.QueryRow(lowerName).Scan(&id)
383 if err != nil {
384 return nil, err
385 }
386
387 err = me.InsertPublicKey(id, pubkey, comment, tx)
388 if err != nil {
389 return nil, err
390 }
391
392 err = tx.Commit()
393 if err != nil {
394 return nil, err
395 }
396
397 return me.FindUserForKey(username, pubkey)
398}
399
400func (me *PsqlDB) RemoveUsers(userIDs []string) error {
401 param := "{" + strings.Join(userIDs, ",") + "}"
402 _, err := me.Db.Exec(sqlRemoveUsers, param)
403 return err
404}
405
406func (me *PsqlDB) InsertPublicKey(userID, key, name string, tx *sql.Tx) error {
407 pk, _ := me.FindPublicKeyForKey(key)
408 if pk != nil {
409 return db.ErrPublicKeyTaken
410 }
411 query := `INSERT INTO public_keys (user_id, public_key, name) VALUES ($1, $2, $3)`
412 var err error
413 if tx != nil {
414 _, err = tx.Exec(query, userID, key, name)
415 } else {
416 _, err = me.Db.Exec(query, userID, key, name)
417 }
418 if err != nil {
419 return err
420 }
421
422 return nil
423}
424
425func (me *PsqlDB) UpdatePublicKey(pubkeyID, name string) (*db.PublicKey, error) {
426 pk, err := me.FindPublicKey(pubkeyID)
427 if err != nil {
428 return nil, err
429 }
430
431 query := `UPDATE public_keys SET name=$1 WHERE id=$2;`
432 _, err = me.Db.Exec(query, name, pk.ID)
433 if err != nil {
434 return nil, err
435 }
436
437 pk, err = me.FindPublicKey(pubkeyID)
438 if err != nil {
439 return nil, err
440 }
441 return pk, nil
442}
443
444func (me *PsqlDB) FindPublicKeyForKey(key string) (*db.PublicKey, error) {
445 var keys []*db.PublicKey
446 rs, err := me.Db.Query(sqlSelectPublicKey, key)
447 if err != nil {
448 return nil, err
449 }
450
451 for rs.Next() {
452 pk := &db.PublicKey{}
453 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
454 if err != nil {
455 return nil, err
456 }
457
458 keys = append(keys, pk)
459 }
460
461 if rs.Err() != nil {
462 return nil, rs.Err()
463 }
464
465 if len(keys) == 0 {
466 return nil, fmt.Errorf("pubkey not found in our database: [%s]", key)
467 }
468
469 // When we run PublicKeyForKey and there are multiple public keys returned from the database
470 // that should mean that we don't have the correct username for this public key.
471 // When that happens we need to reject the authentication and ask the user to provide the correct
472 // username when using ssh. So instead of `ssh <domain>` it should be `ssh user@<domain>`
473 if len(keys) > 1 {
474 return nil, &db.ErrMultiplePublicKeys{}
475 }
476
477 return keys[0], nil
478}
479
480func (me *PsqlDB) FindPublicKey(pubkeyID string) (*db.PublicKey, error) {
481 var keys []*db.PublicKey
482 rs, err := me.Db.Query(`SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE id = $1`, pubkeyID)
483 if err != nil {
484 return nil, err
485 }
486
487 for rs.Next() {
488 pk := &db.PublicKey{}
489 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
490 if err != nil {
491 return nil, err
492 }
493
494 keys = append(keys, pk)
495 }
496
497 if rs.Err() != nil {
498 return nil, rs.Err()
499 }
500
501 if len(keys) == 0 {
502 return nil, errors.New("no public keys found for key provided")
503 }
504
505 return keys[0], nil
506}
507
508func (me *PsqlDB) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
509 var keys []*db.PublicKey
510 rs, err := me.Db.Query(sqlSelectPublicKeys, user.ID)
511 if err != nil {
512 return keys, err
513 }
514 for rs.Next() {
515 pk := &db.PublicKey{}
516 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
517 if err != nil {
518 return keys, err
519 }
520
521 keys = append(keys, pk)
522 }
523 if rs.Err() != nil {
524 return keys, rs.Err()
525 }
526 return keys, nil
527}
528
529func (me *PsqlDB) RemoveKeys(keyIDs []string) error {
530 param := "{" + strings.Join(keyIDs, ",") + "}"
531 _, err := me.Db.Exec(sqlRemoveKeys, param)
532 return err
533}
534
535func (me *PsqlDB) FindPostsBeforeDate(date *time.Time, space string) ([]*db.Post, error) {
536 // now := time.Now()
537 // expired := now.AddDate(0, 0, -3)
538 var posts []*db.Post
539 rs, err := me.Db.Query(sqlSelectPostsBeforeDate, date, space)
540 if err != nil {
541 return posts, err
542 }
543 for rs.Next() {
544 post, err := CreatePostFromRow(rs)
545 if err != nil {
546 return nil, err
547 }
548
549 posts = append(posts, post)
550 }
551 if rs.Err() != nil {
552 return posts, rs.Err()
553 }
554 return posts, nil
555}
556
557func (me *PsqlDB) FindUsersWithPost(space string) ([]*db.User, error) {
558 var users []*db.User
559 rs, err := me.Db.Query(
560 `SELECT u.id, u.name, u.created_at
561 FROM app_users u
562 INNER JOIN posts ON u.id=posts.user_id
563 WHERE cur_space='feeds'
564 GROUP BY u.id, u.name, u.created_at
565 ORDER BY name ASC`,
566 )
567 if err != nil {
568 return users, err
569 }
570 for rs.Next() {
571 var name sql.NullString
572 user := &db.User{}
573 err := rs.Scan(
574 &user.ID,
575 &name,
576 &user.CreatedAt,
577 )
578 if err != nil {
579 return users, err
580 }
581 user.Name = name.String
582
583 users = append(users, user)
584 }
585 if rs.Err() != nil {
586 return users, rs.Err()
587 }
588 return users, nil
589}
590
591func (me *PsqlDB) FindUserForKey(username string, key string) (*db.User, error) {
592 me.Logger.Info("attempting to find user with only public key", "key", key)
593 pk, err := me.FindPublicKeyForKey(key)
594 if err == nil {
595 me.Logger.Info("found pubkey, looking for user", "key", key, "userId", pk.UserID)
596 user, err := me.FindUser(pk.UserID)
597 if err != nil {
598 return nil, err
599 }
600 user.PublicKey = pk
601 return user, nil
602 }
603
604 if errors.Is(err, &db.ErrMultiplePublicKeys{}) {
605 me.Logger.Info("detected multiple users with same public key", "user", username)
606 user, err := me.FindUserForNameAndKey(username, key)
607 if err != nil {
608 me.Logger.Info("could not find user by username and public key", "user", username, "key", key)
609 // this is a little hacky but if we cannot find a user by name and public key
610 // then we return the multiple keys detected error so the user knows to specify their
611 // when logging in
612 return nil, &db.ErrMultiplePublicKeys{}
613 }
614 return user, nil
615 }
616
617 return nil, err
618}
619
620func (me *PsqlDB) FindUserByPubkey(key string) (*db.User, error) {
621 me.Logger.Info("attempting to find user with only public key", "key", key)
622 pk, err := me.FindPublicKeyForKey(key)
623 if err != nil {
624 return nil, err
625 }
626
627 me.Logger.Info("found pubkey, looking for user", "key", key, "userId", pk.UserID)
628 user, err := me.FindUser(pk.UserID)
629 if err != nil {
630 return nil, err
631 }
632 user.PublicKey = pk
633 return user, nil
634}
635
636func (me *PsqlDB) FindUser(userID string) (*db.User, error) {
637 user := &db.User{}
638 var un sql.NullString
639 r := me.Db.QueryRow(sqlSelectUser, userID)
640 err := r.Scan(&user.ID, &un, &user.CreatedAt)
641 if err != nil {
642 return nil, err
643 }
644 if un.Valid {
645 user.Name = un.String
646 }
647 return user, nil
648}
649
650func (me *PsqlDB) ValidateName(name string) (bool, error) {
651 lower := strings.ToLower(name)
652 if slices.Contains(db.DenyList, lower) {
653 return false, fmt.Errorf("%s is on deny list: %w", lower, db.ErrNameDenied)
654 }
655 v := db.NameValidator.MatchString(lower)
656 if !v {
657 return false, fmt.Errorf("%s is invalid: %w", lower, db.ErrNameInvalid)
658 }
659 user, _ := me.FindUserByName(lower)
660 if user == nil {
661 return true, nil
662 }
663 return false, fmt.Errorf("%s already taken: %w", lower, db.ErrNameTaken)
664}
665
666func (me *PsqlDB) FindUserByName(name string) (*db.User, error) {
667 user := &db.User{}
668 r := me.Db.QueryRow(sqlSelectUserForName, strings.ToLower(name))
669 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
670 if err != nil {
671 return nil, err
672 }
673 return user, nil
674}
675
676func (me *PsqlDB) FindUserForNameAndKey(name string, key string) (*db.User, error) {
677 user := &db.User{}
678 pk := &db.PublicKey{}
679
680 r := me.Db.QueryRow(sqlSelectUserForNameAndKey, strings.ToLower(name), key)
681 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt, &pk.ID, &pk.Key, &pk.CreatedAt)
682 if err != nil {
683 return nil, err
684 }
685
686 user.PublicKey = pk
687 return user, nil
688}
689
690func (me *PsqlDB) FindUserForToken(token string) (*db.User, error) {
691 user := &db.User{}
692
693 r := me.Db.QueryRow(sqlSelectUserForToken, token)
694 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
695 if err != nil {
696 return nil, err
697 }
698
699 return user, nil
700}
701
702func (me *PsqlDB) SetUserName(userID string, name string) error {
703 lowerName := strings.ToLower(name)
704 valid, err := me.ValidateName(lowerName)
705 if !valid {
706 return err
707 }
708
709 _, err = me.Db.Exec(sqlUpdateUserName, lowerName, userID)
710 return err
711}
712
713func (me *PsqlDB) FindPostWithFilename(filename string, persona_id string, space string) (*db.Post, error) {
714 r := me.Db.QueryRow(sqlSelectPostWithFilename, filename, persona_id, space)
715 post, err := CreatePostWithTagsFromRow(r)
716 if err != nil {
717 return nil, err
718 }
719
720 return post, nil
721}
722
723func (me *PsqlDB) FindPostWithSlug(slug string, user_id string, space string) (*db.Post, error) {
724 r := me.Db.QueryRow(sqlSelectPostWithSlug, slug, user_id, space)
725 post, err := CreatePostWithTagsFromRow(r)
726 if err != nil {
727 // attempt to find post inside post_aliases
728 alias := me.Db.QueryRow(sqlSelectPostIdByAliasSlug, slug)
729 postID := ""
730 err := alias.Scan(&postID)
731 if err != nil {
732 return nil, err
733 }
734
735 return me.FindPost(postID)
736 }
737
738 return post, nil
739}
740
741func (me *PsqlDB) FindPost(postID string) (*db.Post, error) {
742 r := me.Db.QueryRow(sqlSelectPost, postID)
743 post, err := CreatePostFromRow(r)
744 if err != nil {
745 return nil, err
746 }
747
748 return post, nil
749}
750
751func (me *PsqlDB) postPager(rs *sql.Rows, pageNum int, space string, tag string) (*db.Paginate[*db.Post], error) {
752 var posts []*db.Post
753 for rs.Next() {
754 post := &db.Post{}
755 err := rs.Scan(
756 &post.ID,
757 &post.UserID,
758 &post.Filename,
759 &post.Slug,
760 &post.Title,
761 &post.Text,
762 &post.Description,
763 &post.PublishAt,
764 &post.Username,
765 &post.UpdatedAt,
766 &post.MimeType,
767 )
768 if err != nil {
769 return nil, err
770 }
771
772 posts = append(posts, post)
773 }
774 if rs.Err() != nil {
775 return nil, rs.Err()
776 }
777
778 var count int
779 var err error
780 if tag == "" {
781 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
782 } else {
783 err = me.Db.QueryRow(sqlSelectTagPostCount, space, tag).Scan(&count)
784 }
785 if err != nil {
786 return nil, err
787 }
788
789 pager := &db.Paginate[*db.Post]{
790 Data: posts,
791 Total: int(math.Ceil(float64(count) / float64(pageNum))),
792 }
793
794 return pager, nil
795}
796
797func (me *PsqlDB) FindPostsForFeed(page *db.Pager, space string) (*db.Paginate[*db.Post], error) {
798 rs, err := me.Db.Query(sqlSelectPostsByRank, page.Num, page.Num*page.Page, space)
799 if err != nil {
800 return nil, err
801 }
802 return me.postPager(rs, page.Num, space, "")
803}
804
805func (me *PsqlDB) FindAllUpdatedPosts(page *db.Pager, space string) (*db.Paginate[*db.Post], error) {
806 rs, err := me.Db.Query(sqlSelectAllUpdatedPosts, page.Num, page.Num*page.Page, space)
807 if err != nil {
808 return nil, err
809 }
810 return me.postPager(rs, page.Num, space, "")
811}
812
813func (me *PsqlDB) InsertPost(post *db.Post) (*db.Post, error) {
814 var id string
815 err := me.Db.QueryRow(
816 sqlInsertPost,
817 post.UserID,
818 post.Filename,
819 post.Slug,
820 post.Title,
821 post.Text,
822 post.Description,
823 post.PublishAt,
824 post.Hidden,
825 post.Space,
826 post.FileSize,
827 post.MimeType,
828 post.Shasum,
829 post.Data,
830 post.ExpiresAt,
831 post.UpdatedAt,
832 ).Scan(&id)
833 if err != nil {
834 return nil, err
835 }
836
837 return me.FindPost(id)
838}
839
840func (me *PsqlDB) UpdatePost(post *db.Post) (*db.Post, error) {
841 _, err := me.Db.Exec(
842 sqlUpdatePost,
843 post.Slug,
844 post.Title,
845 post.Text,
846 post.Description,
847 post.UpdatedAt,
848 post.PublishAt,
849 post.FileSize,
850 post.Shasum,
851 post.Data,
852 post.ID,
853 post.Hidden,
854 post.ExpiresAt,
855 )
856 if err != nil {
857 return nil, err
858 }
859
860 return me.FindPost(post.ID)
861}
862
863func (me *PsqlDB) RemovePosts(postIDs []string) error {
864 param := "{" + strings.Join(postIDs, ",") + "}"
865 _, err := me.Db.Exec(sqlRemovePosts, param)
866 return err
867}
868
869func (me *PsqlDB) FindPostsForUser(page *db.Pager, userID string, space string) (*db.Paginate[*db.Post], error) {
870 var posts []*db.Post
871 rs, err := me.Db.Query(
872 sqlSelectPostsForUser,
873 userID,
874 space,
875 page.Num,
876 page.Num*page.Page,
877 )
878 if err != nil {
879 return nil, err
880 }
881 for rs.Next() {
882 post, err := CreatePostWithTagsFromRow(rs)
883 if err != nil {
884 return nil, err
885 }
886
887 posts = append(posts, post)
888 }
889
890 if rs.Err() != nil {
891 return nil, rs.Err()
892 }
893
894 var count int
895 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
896 if err != nil {
897 return nil, err
898 }
899
900 pager := &db.Paginate[*db.Post]{
901 Data: posts,
902 Total: int(math.Ceil(float64(count) / float64(page.Num))),
903 }
904 return pager, nil
905}
906
907func (me *PsqlDB) FindAllPostsForUser(userID string, space string) ([]*db.Post, error) {
908 var posts []*db.Post
909 rs, err := me.Db.Query(sqlSelectAllPostsForUser, userID, space)
910 if err != nil {
911 return posts, err
912 }
913 for rs.Next() {
914 post, err := CreatePostFromRow(rs)
915 if err != nil {
916 return nil, err
917 }
918
919 posts = append(posts, post)
920 }
921 if rs.Err() != nil {
922 return posts, rs.Err()
923 }
924 return posts, nil
925}
926
927func (me *PsqlDB) FindPosts() ([]*db.Post, error) {
928 var posts []*db.Post
929 rs, err := me.Db.Query(sqlSelectPosts)
930 if err != nil {
931 return posts, err
932 }
933 for rs.Next() {
934 post, err := CreatePostFromRow(rs)
935 if err != nil {
936 return nil, err
937 }
938
939 posts = append(posts, post)
940 }
941 if rs.Err() != nil {
942 return posts, rs.Err()
943 }
944 return posts, nil
945}
946
947func (me *PsqlDB) FindExpiredPosts(space string) ([]*db.Post, error) {
948 var posts []*db.Post
949 rs, err := me.Db.Query(sqlSelectExpiredPosts, space)
950 if err != nil {
951 return posts, err
952 }
953 for rs.Next() {
954 post, err := CreatePostFromRow(rs)
955 if err != nil {
956 return nil, err
957 }
958
959 posts = append(posts, post)
960 }
961 if rs.Err() != nil {
962 return posts, rs.Err()
963 }
964 return posts, nil
965}
966
967func (me *PsqlDB) FindUpdatedPostsForUser(userID string, space string) ([]*db.Post, error) {
968 var posts []*db.Post
969 rs, err := me.Db.Query(sqlSelectUpdatedPostsForUser, userID, space)
970 if err != nil {
971 return posts, err
972 }
973 for rs.Next() {
974 post, err := CreatePostFromRow(rs)
975 if err != nil {
976 return nil, err
977 }
978
979 posts = append(posts, post)
980 }
981 if rs.Err() != nil {
982 return posts, rs.Err()
983 }
984 return posts, nil
985}
986
987func (me *PsqlDB) Close() error {
988 me.Logger.Info("Closing db")
989 return me.Db.Close()
990}
991
// newNullString maps "" to SQL NULL and any other string to a valid
// NullString, so optional text columns store NULL instead of an empty string.
func newNullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}
1001
1002func (me *PsqlDB) InsertVisit(visit *db.AnalyticsVisits) error {
1003 _, err := me.Db.Exec(
1004 `INSERT INTO analytics_visits (user_id, project_id, post_id, namespace, host, path, ip_address, user_agent, referer, status, content_type) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);`,
1005 visit.UserID,
1006 newNullString(visit.ProjectID),
1007 newNullString(visit.PostID),
1008 newNullString(visit.Namespace),
1009 visit.Host,
1010 visit.Path,
1011 visit.IpAddress,
1012 visit.UserAgent,
1013 visit.Referer,
1014 visit.Status,
1015 visit.ContentType,
1016 )
1017 return err
1018}
1019
1020func visitFilterBy(opts *db.SummaryOpts) (string, string) {
1021 where := ""
1022 val := ""
1023 if opts.Host != "" {
1024 where = "host"
1025 val = opts.Host
1026 } else if opts.Path != "" {
1027 where = "path"
1028 val = opts.Path
1029 }
1030
1031 return where, val
1032}
1033
1034func (me *PsqlDB) visitUnique(opts *db.SummaryOpts) ([]*db.VisitInterval, error) {
1035 where, with := visitFilterBy(opts)
1036 uniqueVisitors := fmt.Sprintf(`SELECT
1037 date_trunc('%s', created_at) as interval_start,
1038 count(DISTINCT ip_address) as unique_visitors
1039 FROM analytics_visits
1040 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND status <> 404
1041 GROUP BY interval_start`, opts.Interval, where)
1042
1043 intervals := []*db.VisitInterval{}
1044 rs, err := me.Db.Query(uniqueVisitors, opts.Origin, with, opts.UserID)
1045 if err != nil {
1046 return nil, err
1047 }
1048
1049 for rs.Next() {
1050 interval := &db.VisitInterval{}
1051 err := rs.Scan(
1052 &interval.Interval,
1053 &interval.Visitors,
1054 )
1055 if err != nil {
1056 return nil, err
1057 }
1058
1059 intervals = append(intervals, interval)
1060 }
1061 if rs.Err() != nil {
1062 return nil, rs.Err()
1063 }
1064 return intervals, nil
1065}
1066
1067func (me *PsqlDB) visitReferer(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1068 where, with := visitFilterBy(opts)
1069 topUrls := fmt.Sprintf(`SELECT
1070 referer,
1071 count(DISTINCT ip_address) as referer_count
1072 FROM analytics_visits
1073 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND referer <> '' AND status <> 404
1074 GROUP BY referer
1075 ORDER BY referer_count DESC
1076 LIMIT 10`, where)
1077
1078 intervals := []*db.VisitUrl{}
1079 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1080 if err != nil {
1081 return nil, err
1082 }
1083
1084 for rs.Next() {
1085 interval := &db.VisitUrl{}
1086 err := rs.Scan(
1087 &interval.Url,
1088 &interval.Count,
1089 )
1090 if err != nil {
1091 return nil, err
1092 }
1093
1094 intervals = append(intervals, interval)
1095 }
1096 if rs.Err() != nil {
1097 return nil, rs.Err()
1098 }
1099 return intervals, nil
1100}
1101
1102func (me *PsqlDB) visitUrl(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1103 where, with := visitFilterBy(opts)
1104 topUrls := fmt.Sprintf(`SELECT
1105 path,
1106 count(DISTINCT ip_address) as path_count
1107 FROM analytics_visits
1108 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND path <> '' AND status <> 404
1109 GROUP BY path
1110 ORDER BY path_count DESC
1111 LIMIT 10`, where)
1112
1113 intervals := []*db.VisitUrl{}
1114 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1115 if err != nil {
1116 return nil, err
1117 }
1118
1119 for rs.Next() {
1120 interval := &db.VisitUrl{}
1121 err := rs.Scan(
1122 &interval.Url,
1123 &interval.Count,
1124 )
1125 if err != nil {
1126 return nil, err
1127 }
1128
1129 intervals = append(intervals, interval)
1130 }
1131 if rs.Err() != nil {
1132 return nil, rs.Err()
1133 }
1134 return intervals, nil
1135}
1136
1137func (me *PsqlDB) VisitUrlNotFound(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1138 limit := opts.Limit
1139 if limit == 0 {
1140 limit = 10
1141 }
1142 where, with := visitFilterBy(opts)
1143 topUrls := fmt.Sprintf(`SELECT
1144 path,
1145 count(DISTINCT ip_address) as path_count
1146 FROM analytics_visits
1147 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND path <> '' AND status = 404
1148 GROUP BY path
1149 ORDER BY path_count DESC
1150 LIMIT %d`, where, limit)
1151
1152 intervals := []*db.VisitUrl{}
1153 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1154 if err != nil {
1155 return nil, err
1156 }
1157
1158 for rs.Next() {
1159 interval := &db.VisitUrl{}
1160 err := rs.Scan(
1161 &interval.Url,
1162 &interval.Count,
1163 )
1164 if err != nil {
1165 return nil, err
1166 }
1167
1168 intervals = append(intervals, interval)
1169 }
1170 if rs.Err() != nil {
1171 return nil, rs.Err()
1172 }
1173 return intervals, nil
1174}
1175
1176func (me *PsqlDB) visitHost(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1177 topUrls := `SELECT
1178 host,
1179 count(DISTINCT ip_address) as host_count
1180 FROM analytics_visits
1181 WHERE user_id = $1 AND host <> ''
1182 GROUP BY host
1183 ORDER BY host_count DESC`
1184
1185 intervals := []*db.VisitUrl{}
1186 rs, err := me.Db.Query(topUrls, opts.UserID)
1187 if err != nil {
1188 return nil, err
1189 }
1190
1191 for rs.Next() {
1192 interval := &db.VisitUrl{}
1193 err := rs.Scan(
1194 &interval.Url,
1195 &interval.Count,
1196 )
1197 if err != nil {
1198 return nil, err
1199 }
1200
1201 intervals = append(intervals, interval)
1202 }
1203 if rs.Err() != nil {
1204 return nil, rs.Err()
1205 }
1206 return intervals, nil
1207}
1208
1209func (me *PsqlDB) VisitSummary(opts *db.SummaryOpts) (*db.SummaryVisits, error) {
1210 visitors, err := me.visitUnique(opts)
1211 if err != nil {
1212 return nil, err
1213 }
1214
1215 urls, err := me.visitUrl(opts)
1216 if err != nil {
1217 return nil, err
1218 }
1219
1220 notFound, err := me.VisitUrlNotFound(opts)
1221 if err != nil {
1222 return nil, err
1223 }
1224
1225 refs, err := me.visitReferer(opts)
1226 if err != nil {
1227 return nil, err
1228 }
1229
1230 return &db.SummaryVisits{
1231 Intervals: visitors,
1232 TopUrls: urls,
1233 TopReferers: refs,
1234 NotFoundUrls: notFound,
1235 }, nil
1236}
1237
1238func (me *PsqlDB) FindVisitSiteList(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1239 return me.visitHost(opts)
1240}
1241
1242func (me *PsqlDB) FindUsers() ([]*db.User, error) {
1243 var users []*db.User
1244 rs, err := me.Db.Query(sqlSelectUsers)
1245 if err != nil {
1246 return users, err
1247 }
1248 for rs.Next() {
1249 var name sql.NullString
1250 user := &db.User{}
1251 err := rs.Scan(
1252 &user.ID,
1253 &name,
1254 &user.CreatedAt,
1255 )
1256 if err != nil {
1257 return users, err
1258 }
1259 user.Name = name.String
1260
1261 users = append(users, user)
1262 }
1263 if rs.Err() != nil {
1264 return users, rs.Err()
1265 }
1266 return users, nil
1267}
1268
1269func (me *PsqlDB) removeTagsForPost(tx *sql.Tx, postID string) error {
1270 _, err := tx.Exec(sqlRemoveTagsByPost, postID)
1271 return err
1272}
1273
1274func (me *PsqlDB) insertTagsForPost(tx *sql.Tx, tags []string, postID string) ([]string, error) {
1275 ids := make([]string, 0)
1276 for _, tag := range tags {
1277 id := ""
1278 err := tx.QueryRow(sqlInsertTag, postID, tag).Scan(&id)
1279 if err != nil {
1280 return nil, err
1281 }
1282 ids = append(ids, id)
1283 }
1284
1285 return ids, nil
1286}
1287
1288func (me *PsqlDB) ReplaceTagsForPost(tags []string, postID string) error {
1289 ctx := context.Background()
1290 tx, err := me.Db.BeginTx(ctx, nil)
1291 if err != nil {
1292 return err
1293 }
1294 defer func() {
1295 err = tx.Rollback()
1296 }()
1297
1298 err = me.removeTagsForPost(tx, postID)
1299 if err != nil {
1300 return err
1301 }
1302
1303 _, err = me.insertTagsForPost(tx, tags, postID)
1304 if err != nil {
1305 return err
1306 }
1307
1308 err = tx.Commit()
1309 return err
1310}
1311
1312func (me *PsqlDB) removeAliasesForPost(tx *sql.Tx, postID string) error {
1313 _, err := tx.Exec(sqlRemoveAliasesByPost, postID)
1314 return err
1315}
1316
1317func (me *PsqlDB) insertAliasesForPost(tx *sql.Tx, aliases []string, postID string) ([]string, error) {
1318 // hardcoded
1319 denyList := []string{
1320 "rss",
1321 "rss.xml",
1322 "rss.atom",
1323 "atom.xml",
1324 "feed.xml",
1325 "smol.css",
1326 "main.css",
1327 "syntax.css",
1328 "card.png",
1329 "favicon-16x16.png",
1330 "favicon-32x32.png",
1331 "apple-touch-icon.png",
1332 "favicon.ico",
1333 "robots.txt",
1334 "atom",
1335 "blog/index.xml",
1336 }
1337
1338 ids := make([]string, 0)
1339 for _, alias := range aliases {
1340 if slices.Contains(denyList, alias) {
1341 me.Logger.Info(
1342 "name is in the deny list for aliases because it conflicts with a static route, skipping",
1343 "alias", alias,
1344 )
1345 continue
1346 }
1347 id := ""
1348 err := tx.QueryRow(sqlInsertAliases, postID, alias).Scan(&id)
1349 if err != nil {
1350 return nil, err
1351 }
1352 ids = append(ids, id)
1353 }
1354
1355 return ids, nil
1356}
1357
1358func (me *PsqlDB) ReplaceAliasesForPost(aliases []string, postID string) error {
1359 ctx := context.Background()
1360 tx, err := me.Db.BeginTx(ctx, nil)
1361 if err != nil {
1362 return err
1363 }
1364 defer func() {
1365 err = tx.Rollback()
1366 }()
1367
1368 err = me.removeAliasesForPost(tx, postID)
1369 if err != nil {
1370 return err
1371 }
1372
1373 _, err = me.insertAliasesForPost(tx, aliases, postID)
1374 if err != nil {
1375 return err
1376 }
1377
1378 err = tx.Commit()
1379 return err
1380}
1381
1382func (me *PsqlDB) FindUserPostsByTag(page *db.Pager, tag, userID, space string) (*db.Paginate[*db.Post], error) {
1383 var posts []*db.Post
1384 rs, err := me.Db.Query(
1385 sqlSelectUserPostsByTag,
1386 userID,
1387 tag,
1388 space,
1389 page.Num,
1390 page.Num*page.Page,
1391 )
1392 if err != nil {
1393 return nil, err
1394 }
1395 for rs.Next() {
1396 post, err := CreatePostFromRow(rs)
1397 if err != nil {
1398 return nil, err
1399 }
1400
1401 posts = append(posts, post)
1402 }
1403
1404 if rs.Err() != nil {
1405 return nil, rs.Err()
1406 }
1407
1408 var count int
1409 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
1410 if err != nil {
1411 return nil, err
1412 }
1413
1414 pager := &db.Paginate[*db.Post]{
1415 Data: posts,
1416 Total: int(math.Ceil(float64(count) / float64(page.Num))),
1417 }
1418 return pager, nil
1419}
1420
1421func (me *PsqlDB) FindPostsByTag(pager *db.Pager, tag, space string) (*db.Paginate[*db.Post], error) {
1422 rs, err := me.Db.Query(
1423 sqlSelectPostsByTag,
1424 pager.Num,
1425 pager.Num*pager.Page,
1426 tag,
1427 space,
1428 )
1429 if err != nil {
1430 return nil, err
1431 }
1432
1433 return me.postPager(rs, pager.Num, space, tag)
1434}
1435
1436func (me *PsqlDB) FindPopularTags(space string) ([]string, error) {
1437 tags := make([]string, 0)
1438 rs, err := me.Db.Query(sqlSelectPopularTags, space)
1439 if err != nil {
1440 return tags, err
1441 }
1442 for rs.Next() {
1443 name := ""
1444 tally := 0
1445 err := rs.Scan(&name, &tally)
1446 if err != nil {
1447 return tags, err
1448 }
1449
1450 tags = append(tags, name)
1451 }
1452 if rs.Err() != nil {
1453 return tags, rs.Err()
1454 }
1455 return tags, nil
1456}
1457
1458func (me *PsqlDB) FindTagsForUser(userID string, space string) ([]string, error) {
1459 tags := []string{}
1460 rs, err := me.Db.Query(sqlSelectTagsForUser, userID, space)
1461 if err != nil {
1462 return tags, err
1463 }
1464 for rs.Next() {
1465 name := ""
1466 err := rs.Scan(&name)
1467 if err != nil {
1468 return tags, err
1469 }
1470
1471 tags = append(tags, name)
1472 }
1473 if rs.Err() != nil {
1474 return tags, rs.Err()
1475 }
1476 return tags, nil
1477}
1478
1479func (me *PsqlDB) FindTagsForPost(postID string) ([]string, error) {
1480 tags := make([]string, 0)
1481 rs, err := me.Db.Query(sqlSelectTagsForPost, postID)
1482 if err != nil {
1483 return tags, err
1484 }
1485
1486 for rs.Next() {
1487 name := ""
1488 err := rs.Scan(&name)
1489 if err != nil {
1490 return tags, err
1491 }
1492
1493 tags = append(tags, name)
1494 }
1495
1496 if rs.Err() != nil {
1497 return tags, rs.Err()
1498 }
1499
1500 return tags, nil
1501}
1502
1503func (me *PsqlDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
1504 ff := &db.FeatureFlag{}
1505 // payment history is allowed to be null
1506 // https://devtidbits.com/2020/08/03/go-sql-error-converting-null-to-string-is-unsupported/
1507 var paymentHistoryID sql.NullString
1508 err := me.Db.QueryRow(sqlSelectFeatureForUser, userID, feature).Scan(
1509 &ff.ID,
1510 &ff.UserID,
1511 &paymentHistoryID,
1512 &ff.Name,
1513 &ff.Data,
1514 &ff.CreatedAt,
1515 &ff.ExpiresAt,
1516 )
1517 if err != nil {
1518 return nil, err
1519 }
1520
1521 ff.PaymentHistoryID = paymentHistoryID
1522
1523 return ff, nil
1524}
1525
1526func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error) {
1527 var features []*db.FeatureFlag
1528 // https://stackoverflow.com/a/16920077
1529 query := `SELECT DISTINCT ON (name)
1530 id, user_id, payment_history_id, name, data, created_at, expires_at
1531 FROM feature_flags
1532 WHERE user_id=$1
1533 ORDER BY name, expires_at DESC;`
1534 rs, err := me.Db.Query(query, userID)
1535 if err != nil {
1536 return features, err
1537 }
1538 for rs.Next() {
1539 var paymentHistoryID sql.NullString
1540 ff := &db.FeatureFlag{}
1541 err := rs.Scan(
1542 &ff.ID,
1543 &ff.UserID,
1544 &paymentHistoryID,
1545 &ff.Name,
1546 &ff.Data,
1547 &ff.CreatedAt,
1548 &ff.ExpiresAt,
1549 )
1550 if err != nil {
1551 return features, err
1552 }
1553 ff.PaymentHistoryID = paymentHistoryID
1554
1555 features = append(features, ff)
1556 }
1557 if rs.Err() != nil {
1558 return features, rs.Err()
1559 }
1560 return features, nil
1561}
1562
1563func (me *PsqlDB) HasFeatureForUser(userID string, feature string) bool {
1564 ff, err := me.FindFeature(userID, feature)
1565 if err != nil {
1566 return false
1567 }
1568 return ff.IsValid()
1569}
1570
1571func (me *PsqlDB) FindTotalSizeForUser(userID string) (int, error) {
1572 var fileSize int
1573 err := me.Db.QueryRow(sqlSelectSizeForUser, userID).Scan(&fileSize)
1574 if err != nil {
1575 return 0, err
1576 }
1577 return fileSize, nil
1578}
1579
1580func (me *PsqlDB) InsertFeedItems(postID string, items []*db.FeedItem) error {
1581 ctx := context.Background()
1582 tx, err := me.Db.BeginTx(ctx, nil)
1583 if err != nil {
1584 return err
1585 }
1586 defer func() {
1587 err = tx.Rollback()
1588 }()
1589
1590 for _, item := range items {
1591 _, err := tx.Exec(
1592 sqlInsertFeedItems,
1593 item.PostID,
1594 item.GUID,
1595 item.Data,
1596 )
1597 if err != nil {
1598 return fmt.Errorf(
1599 "post id:%s, link:%s, guid:%s, err:%w",
1600 item.PostID, item.Data.Link, item.GUID, err,
1601 )
1602 }
1603 }
1604
1605 err = tx.Commit()
1606 return err
1607}
1608
1609func (me *PsqlDB) FindFeedItemsByPostID(postID string) ([]*db.FeedItem, error) {
1610 // sqlSelectFeedItemsByPost
1611 items := make([]*db.FeedItem, 0)
1612 rs, err := me.Db.Query(sqlSelectFeedItemsByPost, postID)
1613 if err != nil {
1614 return items, err
1615 }
1616
1617 for rs.Next() {
1618 item := &db.FeedItem{}
1619 err := rs.Scan(
1620 &item.ID,
1621 &item.PostID,
1622 &item.GUID,
1623 &item.Data,
1624 &item.CreatedAt,
1625 )
1626 if err != nil {
1627 return items, err
1628 }
1629
1630 items = append(items, item)
1631 }
1632
1633 if rs.Err() != nil {
1634 return items, rs.Err()
1635 }
1636
1637 return items, nil
1638}
1639
1640func (me *PsqlDB) InsertProject(userID, name, projectDir string) (string, error) {
1641 if !utils.IsValidSubdomain(name) {
1642 return "", fmt.Errorf("'%s' is not a valid project name, must match /^[a-z0-9-]+$/", name)
1643 }
1644
1645 var id string
1646 err := me.Db.QueryRow(sqlInsertProject, userID, name, projectDir).Scan(&id)
1647 if err != nil {
1648 return "", err
1649 }
1650 return id, nil
1651}
1652
1653func (me *PsqlDB) UpdateProject(userID, name string) error {
1654 _, err := me.Db.Exec(sqlUpdateProject, userID, name, time.Now())
1655 return err
1656}
1657
1658func (me *PsqlDB) FindProjectByName(userID, name string) (*db.Project, error) {
1659 project := &db.Project{}
1660 r := me.Db.QueryRow(sqlFindProjectByName, userID, name)
1661 err := r.Scan(
1662 &project.ID,
1663 &project.UserID,
1664 &project.Name,
1665 &project.ProjectDir,
1666 &project.Acl,
1667 &project.Blocked,
1668 &project.CreatedAt,
1669 &project.UpdatedAt,
1670 )
1671 if err != nil {
1672 return nil, err
1673 }
1674
1675 return project, nil
1676}
1677
1678func (me *PsqlDB) InsertToken(userID, name string) (string, error) {
1679 var token string
1680 err := me.Db.QueryRow(sqlInsertToken, userID, name).Scan(&token)
1681 if err != nil {
1682 return "", err
1683 }
1684 return token, nil
1685}
1686
1687func (me *PsqlDB) UpsertToken(userID, name string) (string, error) {
1688 token, _ := me.FindTokenByName(userID, name)
1689 if token != "" {
1690 return token, nil
1691 }
1692
1693 token, err := me.InsertToken(userID, name)
1694 return token, err
1695}
1696
1697func (me *PsqlDB) FindTokenByName(userID, name string) (string, error) {
1698 var token string
1699 err := me.Db.QueryRow(sqlSelectTokenByNameForUser, userID, name).Scan(&token)
1700 if err != nil {
1701 return "", err
1702 }
1703 return token, nil
1704}
1705
1706func (me *PsqlDB) RemoveToken(tokenID string) error {
1707 _, err := me.Db.Exec(sqlRemoveToken, tokenID)
1708 return err
1709}
1710
1711func (me *PsqlDB) FindTokensForUser(userID string) ([]*db.Token, error) {
1712 var keys []*db.Token
1713 rs, err := me.Db.Query(sqlSelectTokensForUser, userID)
1714 if err != nil {
1715 return keys, err
1716 }
1717 for rs.Next() {
1718 pk := &db.Token{}
1719 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.CreatedAt, &pk.ExpiresAt)
1720 if err != nil {
1721 return keys, err
1722 }
1723
1724 keys = append(keys, pk)
1725 }
1726 if rs.Err() != nil {
1727 return keys, rs.Err()
1728 }
1729 return keys, nil
1730}
1731
1732func (me *PsqlDB) InsertFeature(userID, name string, expiresAt time.Time) (*db.FeatureFlag, error) {
1733 var featureID string
1734 err := me.Db.QueryRow(
1735 `INSERT INTO feature_flags (user_id, name, expires_at) VALUES ($1, $2, $3) RETURNING id;`,
1736 userID,
1737 name,
1738 expiresAt,
1739 ).Scan(&featureID)
1740 if err != nil {
1741 return nil, err
1742 }
1743
1744 feature, err := me.FindFeature(userID, name)
1745 if err != nil {
1746 return nil, err
1747 }
1748
1749 return feature, nil
1750}
1751
1752func (me *PsqlDB) RemoveFeature(userID string, name string) error {
1753 _, err := me.Db.Exec(`DELETE FROM feature_flags WHERE user_id = $1 AND name = $2`, userID, name)
1754 return err
1755}
1756
1757func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
1758 ff, _ := me.FindFeature(userID, name)
1759 if ff == nil {
1760 t := time.Now()
1761 return t.AddDate(1, 0, 0)
1762 }
1763 return ff.ExpiresAt.AddDate(1, 0, 0)
1764}
1765
1766func (me *PsqlDB) AddPicoPlusUser(username, email, paymentType, txId string) error {
1767 user, err := me.FindUserByName(username)
1768 if err != nil {
1769 return err
1770 }
1771
1772 ctx := context.Background()
1773 tx, err := me.Db.BeginTx(ctx, nil)
1774 if err != nil {
1775 return err
1776 }
1777 defer func() {
1778 err = tx.Rollback()
1779 }()
1780
1781 var paymentHistoryId sql.NullString
1782 if paymentType != "" {
1783 data := db.PaymentHistoryData{
1784 Notes: "",
1785 TxID: txId,
1786 }
1787
1788 err := tx.QueryRow(
1789 `INSERT INTO payment_history (user_id, payment_type, amount, data) VALUES ($1, $2, 24 * 1000000, $3) RETURNING id;`,
1790 user.ID,
1791 paymentType,
1792 data,
1793 ).Scan(&paymentHistoryId)
1794 if err != nil {
1795 return err
1796 }
1797 }
1798
1799 plus := me.createFeatureExpiresAt(user.ID, "plus")
1800 plusQuery := fmt.Sprintf(`INSERT INTO feature_flags (user_id, name, data, expires_at, payment_history_id)
1801 VALUES ($1, 'plus', '{"storage_max":10000000000, "file_max":50000000, "email": "%s"}'::jsonb, $2, $3);`, email)
1802 _, err = tx.Exec(plusQuery, user.ID, plus, paymentHistoryId)
1803 if err != nil {
1804 return err
1805 }
1806
1807 return tx.Commit()
1808}
1809
1810func (me *PsqlDB) UpsertProject(userID, projectName, projectDir string) (*db.Project, error) {
1811 project, err := me.FindProjectByName(userID, projectName)
1812 if err == nil {
1813 // this just updates the `createdAt` timestamp, useful for book-keeping
1814 err = me.UpdateProject(userID, projectName)
1815 if err != nil {
1816 me.Logger.Error("could not update project", "err", err)
1817 return nil, err
1818 }
1819 return project, nil
1820 }
1821
1822 _, err = me.InsertProject(userID, projectName, projectName)
1823 if err != nil {
1824 me.Logger.Error("could not create project", "err", err)
1825 return nil, err
1826 }
1827 return me.FindProjectByName(userID, projectName)
1828}
1829
1830func (me *PsqlDB) findPagesStats(userID string) (*db.UserServiceStats, error) {
1831 stats := db.UserServiceStats{
1832 Service: "pgs",
1833 }
1834 err := me.Db.QueryRow(
1835 `SELECT count(id), min(created_at), max(created_at), max(updated_at) FROM projects WHERE user_id=$1`,
1836 userID,
1837 ).Scan(&stats.Num, &stats.FirstCreatedAt, &stats.LastestCreatedAt, &stats.LatestUpdatedAt)
1838 if err != nil {
1839 return nil, err
1840 }
1841
1842 return &stats, nil
1843}
1844
1845func (me *PsqlDB) InsertTunsEventLog(log *db.TunsEventLog) error {
1846 _, err := me.Db.Exec(
1847 `INSERT INTO tuns_event_logs
1848 (user_id, server_id, remote_addr, event_type, tunnel_type, connection_type, tunnel_id)
1849 VALUES
1850 ($1, $2, $3, $4, $5, $6, $7)`,
1851 log.UserId, log.ServerID, log.RemoteAddr, log.EventType, log.TunnelType,
1852 log.ConnectionType, log.TunnelID,
1853 )
1854 return err
1855}
1856
1857func (me *PsqlDB) FindTunsEventLogsByAddr(userID, addr string) ([]*db.TunsEventLog, error) {
1858 logs := []*db.TunsEventLog{}
1859 rs, err := me.Db.Query(
1860 `SELECT id, user_id, server_id, remote_addr, event_type, tunnel_type, connection_type, tunnel_id, created_at
1861 FROM tuns_event_logs WHERE user_id=$1 AND tunnel_id=$2 ORDER BY created_at DESC`, userID, addr)
1862 if err != nil {
1863 return nil, err
1864 }
1865
1866 for rs.Next() {
1867 log := db.TunsEventLog{}
1868 err := rs.Scan(
1869 &log.ID, &log.UserId, &log.ServerID, &log.RemoteAddr,
1870 &log.EventType, &log.TunnelType, &log.ConnectionType,
1871 &log.TunnelID, &log.CreatedAt,
1872 )
1873 if err != nil {
1874 return nil, err
1875 }
1876 logs = append(logs, &log)
1877 }
1878
1879 if rs.Err() != nil {
1880 return nil, rs.Err()
1881 }
1882
1883 return logs, nil
1884}
1885
1886func (me *PsqlDB) FindTunsEventLogs(userID string) ([]*db.TunsEventLog, error) {
1887 logs := []*db.TunsEventLog{}
1888 rs, err := me.Db.Query(
1889 `SELECT id, user_id, server_id, remote_addr, event_type, tunnel_type, connection_type, tunnel_id, created_at
1890 FROM tuns_event_logs WHERE user_id=$1 ORDER BY created_at DESC`, userID)
1891 if err != nil {
1892 return nil, err
1893 }
1894
1895 for rs.Next() {
1896 log := db.TunsEventLog{}
1897 err := rs.Scan(
1898 &log.ID, &log.UserId, &log.ServerID, &log.RemoteAddr,
1899 &log.EventType, &log.TunnelType, &log.ConnectionType,
1900 &log.TunnelID, &log.CreatedAt,
1901 )
1902 if err != nil {
1903 return nil, err
1904 }
1905 logs = append(logs, &log)
1906 }
1907
1908 if rs.Err() != nil {
1909 return nil, rs.Err()
1910 }
1911
1912 return logs, nil
1913}
1914
1915func (me *PsqlDB) FindUserStats(userID string) (*db.UserStats, error) {
1916 stats := db.UserStats{}
1917 rs, err := me.Db.Query(`SELECT cur_space, count(id), min(created_at), max(created_at), max(updated_at) FROM posts WHERE user_id=$1 GROUP BY cur_space`, userID)
1918 if err != nil {
1919 return nil, err
1920 }
1921
1922 for rs.Next() {
1923 stat := db.UserServiceStats{}
1924 err := rs.Scan(&stat.Service, &stat.Num, &stat.FirstCreatedAt, &stat.LastestCreatedAt, &stat.LatestUpdatedAt)
1925 if err != nil {
1926 return nil, err
1927 }
1928 switch stat.Service {
1929 case "prose":
1930 stats.Prose = stat
1931 case "pastes":
1932 stats.Pastes = stat
1933 case "feeds":
1934 stats.Feeds = stat
1935 }
1936 }
1937
1938 if rs.Err() != nil {
1939 return nil, rs.Err()
1940 }
1941
1942 pgs, err := me.findPagesStats(userID)
1943 if err != nil {
1944 return nil, err
1945 }
1946 stats.Pages = *pgs
1947 return &stats, err
1948}