diff --git a/cmd/gochan-migration/internal/pre2021/bans.go b/cmd/gochan-migration/internal/pre2021/bans.go index a7cb25da..40d8bd5c 100644 --- a/cmd/gochan-migration/internal/pre2021/bans.go +++ b/cmd/gochan-migration/internal/pre2021/bans.go @@ -80,7 +80,7 @@ func (m *Pre2021Migrator) migrateBan(tx *sql.Tx, ban *migrationBan, boardID *int migratedBan.Message = ban.reason migratedBan.StaffID = ban.staffID migratedBan.StaffNote = ban.staffNote - if err := gcsql.NewIPBanTx(tx, migratedBan); err != nil { + if err := gcsql.NewIPBan(migratedBan, &gcsql.RequestOptions{Tx: tx}); err != nil { errEv.Err(err).Caller(). Int("oldID", ban.oldID).Msg("Failed to migrate ban") return err diff --git a/cmd/gochan-migration/internal/pre2021/boards.go b/cmd/gochan-migration/internal/pre2021/boards.go index 6b204ad8..8328bad1 100644 --- a/cmd/gochan-migration/internal/pre2021/boards.go +++ b/cmd/gochan-migration/internal/pre2021/boards.go @@ -37,7 +37,7 @@ func (m *Pre2021Migrator) migrateSections() error { } var sectionsToBeCreated []gcsql.Section - rows, err := m.db.QuerySQL(sectionsQuery) + rows, err := m.db.Query(nil, sectionsQuery) if err != nil { errEv.Err(err).Caller().Msg("Failed to query old database sections") return err @@ -126,7 +126,7 @@ func (m *Pre2021Migrator) MigrateBoards() error { } // get boards from old db - rows, err := m.db.QuerySQL(boardsQuery) + rows, err := m.db.Query(nil, boardsQuery) if err != nil { errEv.Err(err).Caller().Msg("Failed to query old database boards") return err @@ -162,7 +162,7 @@ func (m *Pre2021Migrator) MigrateBoards() error { Int("migratedBoardID", newBoard.ID). Msg("Board already exists in new db, updating values") // don't update other values in the array since they don't affect migrating threads or posts - if _, err = gcsql.ExecSQL(`UPDATE DBPREFIXboards + if _, err = gcsql.Exec(nil, `UPDATE DBPREFIXboards SET uri = ?, navbar_position = ?, title = ?, subtitle = ?, description = ?, max_file_size = ?, max_threads = ?, default_style = ?, locked = ?, anonymous_name = ?, force_anonymous = ?, autosage_after = ?, no_images_after = ?, max_message_length = ?, diff --git a/cmd/gochan-migration/internal/pre2021/posts.go b/cmd/gochan-migration/internal/pre2021/posts.go index bd4bfa56..3fa474fc 100644 --- a/cmd/gochan-migration/internal/pre2021/posts.go +++ b/cmd/gochan-migration/internal/pre2021/posts.go @@ -1,7 +1,6 @@ package pre2021 import ( - "context" "database/sql" "time" @@ -35,10 +34,10 @@ type migrationPost struct { func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *zerolog.Event) error { var err error - + opts := &gcsql.RequestOptions{Tx: tx} if post.oldParentID == 0 { // migrating post was a thread OP, create the row in the threads table - if post.ThreadID, err = gcsql.CreateThread(tx, post.boardID, false, post.stickied, post.autosage, false); err != nil { + if post.ThreadID, err = gcsql.CreateThread(opts, post.boardID, false, post.stickied, post.autosage, false); err != nil { errEv.Err(err).Caller(). Int("boardID", post.boardID). Msg("Failed to create thread") @@ -46,7 +45,7 @@ func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *ze } // insert thread top post - if err = post.InsertWithContext(context.Background(), tx, true, post.boardID, false, post.stickied, post.autosage, false); err != nil { + if err = post.Insert(true, post.boardID, false, post.stickied, post.autosage, false, opts); err != nil { errEv.Err(err).Caller(). Int("boardID", post.boardID). Int("threadID", post.ThreadID). @@ -54,7 +53,7 @@ func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *ze } if post.filename != "" { - if err = post.AttachFileTx(tx, &gcsql.Upload{ + if err = post.AttachFile(&gcsql.Upload{ PostID: post.ID, OriginalFilename: post.filenameOriginal, Filename: post.filename, @@ -64,7 +63,7 @@ func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *ze ThumbnailHeight: post.thumbH, Width: post.imageW, Height: post.imageH, - }); err != nil { + }, opts); err != nil { errEv.Err(err).Caller(). Int("oldPostID", post.oldID). Msg("Failed to attach upload to migrated post") @@ -85,7 +84,7 @@ func (m *Pre2021Migrator) MigratePosts() error { } defer tx.Rollback() - rows, err := m.db.QuerySQL(threadsQuery) + rows, err := m.db.Query(nil, threadsQuery) if err != nil { errEv.Err(err).Caller().Msg("Failed to get threads") return err @@ -126,7 +125,7 @@ func (m *Pre2021Migrator) MigratePosts() error { } // get and insert replies - replyRows, err := m.db.QuerySQL(postsQuery+" AND parentid = ?", thread.oldID) + replyRows, err := m.db.Query(nil, postsQuery+" AND parentid = ?", thread.oldID) if err != nil { errEv.Err(err).Caller(). Int("parentID", thread.oldID). @@ -156,7 +155,7 @@ func (m *Pre2021Migrator) MigratePosts() error { } if thread.locked { - if _, err = gcsql.ExecTxSQL(tx, "UPDATE DBPREFIXthreads SET locked = TRUE WHERE id = ?", thread.ThreadID); err != nil { + if _, err = gcsql.Exec(&gcsql.RequestOptions{Tx: tx}, "UPDATE DBPREFIXthreads SET locked = TRUE WHERE id = ?", thread.ThreadID); err != nil { errEv.Err(err).Caller(). Int("threadID", thread.ThreadID). Msg("Unable to re-lock migrated thread") diff --git a/pkg/gcsql/bans.go b/pkg/gcsql/bans.go index 2632fcd3..6bb1f086 100644 --- a/pkg/gcsql/bans.go +++ b/pkg/gcsql/bans.go @@ -25,38 +25,42 @@ type Ban interface { Deactivate(int) error } -func NewIPBanTx(tx *sql.Tx, ban *IPBan) error { +func NewIPBan(ban *IPBan, requestOpts ...*RequestOptions) error { const query = `INSERT INTO DBPREFIXip_ban (staff_id, board_id, banned_for_post_id, copy_post_text, is_thread_ban, is_active, range_start, range_end, appeal_at, expires_at, permanent, staff_note, message, can_appeal) VALUES(?, ?, ?, ?, ?, ?, PARAM_ATON, PARAM_ATON, ?, ?, ?, ?, ?, ?)` + opts := setupOptions(requestOpts...) + shouldCommit := opts.Tx == nil + var err error + if shouldCommit { + opts.Tx, err = BeginTx() + if err != nil { + return err + } + defer opts.Tx.Rollback() + } + if ban.ID > 0 { return ErrBanAlreadyInserted } - _, err := ExecTxSQL(tx, query, ban.StaffID, ban.BoardID, ban.BannedForPostID, ban.CopyPostText, + if _, err = Exec(opts, query, ban.StaffID, ban.BoardID, ban.BannedForPostID, ban.CopyPostText, ban.IsThreadBan, ban.IsActive, ban.RangeStart, ban.RangeEnd, ban.AppealAt, - ban.ExpiresAt, ban.Permanent, ban.StaffNote, ban.Message, ban.CanAppeal) + ban.ExpiresAt, ban.Permanent, ban.StaffNote, ban.Message, ban.CanAppeal, + ); err != nil { + return err + } + + ban.ID, err = getLatestID(opts, "DBPREFIXip_ban") if err != nil { return err } - - ban.ID, err = getLatestID("DBPREFIXip_ban", tx) - return err -} - -func NewIPBan(ban *IPBan) error { - tx, err := BeginTx() - if err != nil { - return err - } - defer tx.Rollback() - - if err = NewIPBanTx(tx, ban); err != nil { - return err + if shouldCommit { + return opts.Tx.Commit() } - return tx.Commit() + return nil } // CheckIPBan returns the latest active IP ban for the given IP, as well as any @@ -74,7 +78,7 @@ func CheckIPBan(ip string, boardID int) (*IPBan, error) { (expires_at > CURRENT_TIMESTAMP OR permanent) ORDER BY id DESC LIMIT 1` var ban IPBan - err := QueryRowSQL(query, []any{ip, ip, boardID}, []any{ + err := QueryRow(nil, query, []any{ip, ip, boardID}, []any{ &ban.ID, &ban.StaffID, &ban.BoardID, &ban.BannedForPostID, &ban.CopyPostText, &ban.IsThreadBan, &ban.IsActive, &ban.RangeStart, &ban.RangeEnd, &ban.IssuedAt, &ban.AppealAt, &ban.ExpiresAt, &ban.Permanent, &ban.StaffNote, &ban.Message, @@ -90,7 +94,7 @@ func CheckIPBan(ip string, boardID int) (*IPBan, error) { func GetIPBanByID(id int) (*IPBan, error) { const query = ipBanQueryBase + " WHERE id = ?" var ban IPBan - err := QueryRowSQL(query, []any{id}, []any{ + err := QueryRow(nil, query, []any{id}, []any{ &ban.ID, &ban.StaffID, &ban.BoardID, &ban.BannedForPostID, &ban.CopyPostText, &ban.IsThreadBan, &ban.IsActive, &ban.RangeStart, &ban.RangeEnd, &ban.IssuedAt, &ban.AppealAt, &ban.ExpiresAt, &ban.Permanent, &ban.StaffNote, &ban.Message, @@ -110,14 +114,15 @@ func GetIPBans(boardID int, limit int, onlyActive bool) ([]IPBan, error) { var rows *sql.Rows var err error if boardID > 0 { - rows, err = QuerySQL(query, boardID) + rows, err = Query(nil, query, boardID) } else { - rows, err = QuerySQL(query) + rows, err = Query(nil, query) } if err != nil { return nil, err } var bans []IPBan + defer rows.Close() for rows.Next() { var ban IPBan if err = rows.Scan( @@ -125,7 +130,6 @@ func GetIPBans(boardID int, limit int, onlyActive bool) ([]IPBan, error) { &ban.IsActive, &ban.RangeStart, &ban.RangeEnd, &ban.IssuedAt, &ban.AppealAt, &ban.ExpiresAt, &ban.Permanent, &ban.StaffNote, &ban.Message, &ban.CanAppeal, ); err != nil { - rows.Close() return nil, err } if onlyActive && !ban.IsActive { @@ -138,7 +142,7 @@ func GetIPBans(boardID int, limit int, onlyActive bool) ([]IPBan, error) { func (ipb *IPBan) Appeal(msg string) error { const query = `INSERT INTO DBPREFIXip_ban_appeals (ip_ban_id, appeal_text, is_denied) VALUES(?, ?, FALSE)` - _, err := ExecSQL(query, ipb.ID, msg) + _, err := Exec(nil, query, ipb.ID, msg) return err } @@ -163,10 +167,10 @@ func (ipb *IPBan) Deactivate(_ int) error { return err } defer tx.Rollback() - if _, err = ExecTxSQL(tx, deactivateQuery, ipb.ID); err != nil { + if _, err = Exec(&RequestOptions{Tx: tx}, deactivateQuery, ipb.ID); err != nil { return err } - if _, err = ExecTxSQL(tx, auditInsertQuery, ipb.ID); err != nil { + if _, err = Exec(&RequestOptions{Tx: tx}, auditInsertQuery, ipb.ID); err != nil { return err } return tx.Commit() diff --git a/pkg/gcsql/boards.go b/pkg/gcsql/boards.go index 744a29d2..87cbb281 100644 --- a/pkg/gcsql/boards.go +++ b/pkg/gcsql/boards.go @@ -459,7 +459,7 @@ func (board *Board) ModifyInDB() error { require_file = ?, enable_catalog = ? WHERE id = ?` - _, err := ExecSQL(query, + _, err := Exec(nil, query, board.SectionID, board.NavbarPosition, board.Title, board.Subtitle, board.Description, board.MaxFilesize, board.MaxThreads, board.DefaultStyle, board.Locked, board.AnonymousName, board.ForceAnonymous, board.AutosageAfter, board.NoImagesAfter, board.MaxMessageLength, diff --git a/pkg/gcsql/database.go b/pkg/gcsql/database.go index f9cdb51a..50d633f0 100644 --- a/pkg/gcsql/database.go +++ b/pkg/gcsql/database.go @@ -118,21 +118,16 @@ func (db *GCDB) PrepareContextSQL(ctx context.Context, query string, tx *sql.Tx) return stmt, sqlVersionError(err, db.driver, &prepared) } -/* -ExecSQL executes the given SQL statement with the given parameters -Example: - - var intVal int - var stringVal string - result, err := db.ExecSQL("INSERT INTO tablename (intval,stringval) VALUES(?,?)", intVal, stringVal) -*/ -func (db *GCDB) ExecSQL(query string, values ...any) (sql.Result, error) { - stmt, err := db.PrepareSQL(query, nil) +// Exec executes the given SQL statement with the given parameters, optionally with the given RequestOptions struct +// or a background context and transaction if nil +func (db *GCDB) Exec(opts *RequestOptions, query string, values ...any) (sql.Result, error) { + opts = setupOptions(opts) + stmt, err := db.PrepareContextSQL(opts.Context, query, opts.Tx) if err != nil { return nil, err } defer stmt.Close() - result, err := stmt.Exec(values...) + result, err := stmt.ExecContext(opts.Context, values...) if err != nil { return nil, err } @@ -140,7 +135,22 @@ func (db *GCDB) ExecSQL(query string, values ...any) (sql.Result, error) { } /* -ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil) +ExecSQL executes the given SQL statement with the given parameters. +Deprecated: Use Exec instead + +Example: + + var intVal int + var stringVal string + result, err := db.ExecSQL("INSERT INTO tablename (intval,stringval) VALUES(?,?)", intVal, stringVal) +*/ +func (db *GCDB) ExecSQL(query string, values ...any) (sql.Result, error) { + return db.Exec(nil, query, values...) +} + +/* +ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil). +Deprecated: Use Exec instead, with a RequestOptions struct for the context and transaction Example: @@ -152,21 +162,12 @@ Example: intVal, stringVal) */ func (db *GCDB) ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) { - stmt, err := db.PrepareContextSQL(ctx, sqlStr, tx) - if err != nil { - return nil, err - } - defer stmt.Close() - - result, err := stmt.ExecContext(ctx, values...) - if err != nil { - return nil, err - } - return result, stmt.Close() + return db.Exec(&RequestOptions{Context: ctx, Tx: tx}, sqlStr, values...) } /* -ExecTxSQL executes the given SQL statemtnt, optionally with the given transaction (if non-nil) +ExecTxSQL executes the given SQL statemtnt, optionally with the given transaction (if non-nil). +Deprecated: Use Exec instead, with a RequestOptions struct for the transaction Example: @@ -179,16 +180,7 @@ Example: intVal, stringVal) */ func (db *GCDB) ExecTxSQL(tx *sql.Tx, query string, values ...any) (sql.Result, error) { - stmt, err := db.PrepareSQL(query, tx) - if err != nil { - return nil, err - } - defer stmt.Close() - res, err := stmt.Exec(values...) - if err != nil { - return res, err - } - return res, stmt.Close() + return db.Exec(&RequestOptions{Tx: tx}, query, values...) } /* @@ -209,9 +201,25 @@ func (db *GCDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, erro return db.db.BeginTx(ctx, opts) } +// QueryRow gets a row from the db with the values in values[] and fills the respective pointers in out[], +// with an optional RequestOptions struct for the context and transaction +func (db *GCDB) QueryRow(opts *RequestOptions, query string, values []any, out []any) error { + opts = setupOptions(opts) + stmt, err := db.PrepareContextSQL(opts.Context, query, opts.Tx) + if err != nil { + return err + } + defer stmt.Close() + if err = stmt.QueryRowContext(opts.Context, values...).Scan(out...); err != nil { + return err + } + return stmt.Close() +} + /* -QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[] -Automatically escapes the given values and caches the query +QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]. +Deprecated: Use QueryRow instead + Example: id := 32 @@ -222,12 +230,7 @@ Example: []any{&intVal, &stringVal}) */ func (db *GCDB) QueryRowSQL(query string, values, out []any) error { - stmt, err := db.PrepareSQL(query, nil) - if err != nil { - return err - } - defer stmt.Close() - return stmt.QueryRow(values...).Scan(out...) + return db.QueryRow(nil, query, values, out) } /* @@ -244,16 +247,7 @@ Example: []any{id}, []any{&name}) */ func (db *GCDB) QueryRowContextSQL(ctx context.Context, tx *sql.Tx, query string, values, out []any) error { - stmt, err := db.PrepareContextSQL(ctx, query, tx) - if err != nil { - return err - } - defer stmt.Close() - - if err = stmt.QueryRowContext(ctx, values...).Scan(out...); err != nil { - return err - } - return stmt.Close() + return db.QueryRow(&RequestOptions{Context: ctx, Tx: tx}, query, values, out) } /* @@ -273,21 +267,29 @@ Example: []any{&intVal, &stringVal}) */ func (db *GCDB) QueryRowTxSQL(tx *sql.Tx, query string, values, out []any) error { - stmt, err := db.PrepareSQL(query, tx) + return db.QueryRow(&RequestOptions{Tx: tx}, query, values, out) +} + +// Query sends the query to the database with the given options (or a background context if nil), and the given parameters +func (db *GCDB) Query(opts *RequestOptions, query string, a ...any) (*sql.Rows, error) { + opts = setupOptions(opts) + stmt, err := db.PrepareContextSQL(opts.Context, query, opts.Tx) if err != nil { - return err + return nil, err } defer stmt.Close() - if err = stmt.QueryRow(values...).Scan(out...); err != nil { - return err + rows, err := stmt.QueryContext(opts.Context, a...) + if err != nil { + return rows, err } - return stmt.Close() + + return rows, stmt.Close() } /* -QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[] -Automatically escapes the given values and caches the query +QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[]. +Deprecated: Use Query instead Example: rows, err := db.QuerySQL("SELECT * FROM table") @@ -301,17 +303,13 @@ Example: } */ func (db *GCDB) QuerySQL(query string, a ...any) (*sql.Rows, error) { - stmt, err := db.PrepareSQL(query, nil) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.Query(a...) + return db.Query(nil, query, a...) } /* QueryContextSQL queries the database with a prepared statement and the given parameters, using the given context -for a deadline +for a deadline. +Deprecated: Use Query instead, with a RequestOptions struct for the context and transaction Example: @@ -320,22 +318,12 @@ Example: rows, err := db.QueryContextSQL(ctx, nil, "SELECT name from posts where NOT is_deleted") */ func (db *GCDB) QueryContextSQL(ctx context.Context, tx *sql.Tx, query string, a ...any) (*sql.Rows, error) { - stmt, err := db.PrepareContextSQL(ctx, query, tx) - if err != nil { - return nil, err - } - defer stmt.Close() - - rows, err := stmt.QueryContext(ctx, a...) - if err != nil { - return rows, err - } - return rows, stmt.Close() + return db.Query(&RequestOptions{Context: ctx, Tx: tx}, query, a...) } /* -QueryTxSQL gets all rows from the db with the values in values[] and fills the respective pointers in out[] -Automatically escapes the given values and caches the query +QueryTxSQL gets all rows from the db with the values in values[] and fills the respective pointers in out[]. +Deprecated: Use Query instead, with a RequestOptions struct for the transaction Example: tx, _ := db.Begin() @@ -350,16 +338,7 @@ Example: } */ func (db *GCDB) QueryTxSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, error) { - stmt, err := db.PrepareSQL(query, tx) - if err != nil { - return nil, err - } - defer stmt.Close() - rows, err := stmt.Query(a...) - if err != nil { - return nil, err - } - return rows, stmt.Close() + return db.Query(&RequestOptions{Tx: tx}, query, a...) } func setupDBConn(cfg *config.SQLConfig) (db *GCDB, err error) { @@ -404,6 +383,7 @@ func setupSqlTestConfig(dbDriver string, dbName string, dbPrefix string) *config } } +// SetupMockDB sets up a mock database connection for testing func SetupMockDB(driver string) (sqlmock.Sqlmock, error) { var err error gcdb, err = setupDBConn(setupSqlTestConfig(driver, "gochan", "")) diff --git a/pkg/gcsql/posts.go b/pkg/gcsql/posts.go index a8493d97..96479c1c 100644 --- a/pkg/gcsql/posts.go +++ b/pkg/gcsql/posts.go @@ -39,23 +39,24 @@ func GetPostFromID(id int, onlyNotDeleted bool) (*Post, error) { query += " AND is_deleted = FALSE" } post := new(Post) - err := QueryRowSQL(query, []any{id}, []any{ + err := QueryRow(nil, query, []any{id}, []any{ &post.ID, &post.ThreadID, &post.IsTopPost, &post.IP, &post.CreatedOn, &post.Name, &post.Tripcode, &post.IsRoleSignature, &post.Email, &post.Subject, &post.Message, &post.MessageRaw, &post.Password, &post.DeletedAt, &post.IsDeleted, &post.BannedMessage, &post.Flag, &post.Country, }) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, ErrPostDoesNotExist - + } else if err != nil { + return nil, err } - return post, err + return post, nil } func GetPostIP(postID int) (string, error) { sql := "SELECT IP_NTOA FROM DBPREFIXposts WHERE id = ?" var ip string - err := QueryRowSQL(sql, []any{postID}, []any{&ip}) + err := QueryRow(nil, sql, []any{postID}, []any{&ip}) return ip, err } @@ -68,7 +69,7 @@ func GetPostsFromIP(ip string, limit int, onlyNotDeleted bool) ([]Post, error) { } sql += " ORDER BY id DESC LIMIT ?" - rows, err := QuerySQL(sql, ip, limit) + rows, err := Query(nil, sql, ip, limit) if err != nil { return nil, err } @@ -109,7 +110,7 @@ func GetTopPostIDsInThreadIDs(threads ...any) (map[any]int, error) { } params := createArrayPlaceholder(threads) query := `SELECT id FROM DBPREFIXposts WHERE thread_id in ` + params + " AND is_top_post" - rows, err := QuerySQL(query, threads...) + rows, err := Query(nil, query, threads...) if err != nil { return nil, err } @@ -130,7 +131,7 @@ func GetTopPostIDsInThreadIDs(threads ...any) (map[any]int, error) { func GetThreadTopPost(threadID int) (*Post, error) { const query = selectPostsBaseSQL + "WHERE thread_id = ? AND is_top_post = TRUE LIMIT 1" post := new(Post) - err := QueryRowSQL(query, []any{threadID}, []any{ + err := QueryRow(nil, query, []any{threadID}, []any{ &post.ID, &post.ThreadID, &post.IsTopPost, &post.IP, &post.CreatedOn, &post.Name, &post.Tripcode, &post.IsRoleSignature, &post.Email, &post.Subject, &post.Message, &post.MessageRaw, &post.Password, &post.DeletedAt, &post.IsDeleted, @@ -180,19 +181,23 @@ func GetBoardTopPosts[B intOrStringConstraint](board B) ([]*Post, error) { func GetPostPassword(id int) (string, error) { const query = `SELECT password FROM DBPREFIXposts WHERE id = ?` var passwordChecksum string - err := QueryRowSQL(query, []any{id}, []any{&passwordChecksum}) + err := QueryRow(nil, query, []any{id}, []any{&passwordChecksum}) return passwordChecksum, err } // PermanentlyRemoveDeletedPosts removes all posts and files marked as deleted from the database -func PermanentlyRemoveDeletedPosts() error { +func PermanentlyRemoveDeletedPosts(opts ...*RequestOptions) error { const sql1 = `DELETE FROM DBPREFIXposts WHERE is_deleted` const sql2 = `DELETE FROM DBPREFIXthreads WHERE is_deleted` - _, err := ExecSQL(sql1) + var useOpts *RequestOptions + if len(opts) > 0 { + useOpts = opts[0] + } + _, err := Exec(useOpts, sql1) if err != nil { return err } - _, err = ExecSQL(sql2) + _, err = Exec(useOpts, sql2) return err } @@ -201,7 +206,7 @@ func PermanentlyRemoveDeletedPosts() error { func SinceLastPost(postIP string) (int, error) { const query = `SELECT COALESCE(MAX(created_on), '1970-01-01 00:00:00') FROM DBPREFIXposts WHERE ip = ?` var whenStr string - err := QueryRowSQL(query, []any{postIP}, []any{&whenStr}) + err := QueryRow(nil, query, []any{postIP}, []any{&whenStr}) if err != nil { return -1, err } @@ -219,7 +224,7 @@ func SinceLastThread(postIP string) (int, error) { const query = `SELECT COALESCE(MAX(created_on), '1970-01-01 00:00:00') FROM DBPREFIXposts WHERE ip = ? AND is_top_post` var whenStr string - err := QueryRowSQL(query, []any{postIP}, []any{&whenStr}) + err := QueryRow(nil, query, []any{postIP}, []any{&whenStr}) if err != nil { return -1, err } @@ -233,7 +238,7 @@ func SinceLastThread(postIP string) (int, error) { // UpdateContents updates the email, subject, and message text of the post func (p *Post) UpdateContents(email string, subject string, message template.HTML, messageRaw string) error { const sqlUpdate = `UPDATE DBPREFIXposts SET email = ?, subject = ?, message = ?, message_raw = ? WHERE ID = ?` - _, err := ExecSQL(sqlUpdate, email, subject, message, messageRaw, p.ID) + _, err := Exec(nil, sqlUpdate, email, subject, message, messageRaw, p.ID) if err != nil { return err } @@ -244,27 +249,27 @@ func (p *Post) UpdateContents(email string, subject string, message template.HTM return nil } -func (p *Post) GetBoardID() (int, error) { +func (p *Post) GetBoardID(opts ...*RequestOptions) (int, error) { const query = `SELECT board_id FROM DBPREFIXthreads where id = ?` var boardID int - err := QueryRowSQL(query, []any{p.ThreadID}, []any{&boardID}) + err := QueryRow(setupOptions(opts...), query, []any{p.ThreadID}, []any{&boardID}) if errors.Is(err, sql.ErrNoRows) { err = ErrBoardDoesNotExist } return boardID, err } -func (p *Post) GetBoardDir() (string, error) { +func (p *Post) GetBoardDir(opts ...*RequestOptions) (string, error) { const query = "SELECT dir FROM DBPREFIXboards" + boardFromPostIdSuffixSQL var dir string - err := QueryRowSQL(query, []any{p.ID}, []any{&dir}) + err := QueryRow(setupOptions(opts...), query, []any{p.ID}, []any{&dir}) return dir, err } -func (p *Post) GetBoard() (*Board, error) { +func (p *Post) GetBoard(opts ...*RequestOptions) (*Board, error) { const query = selectBoardsBaseSQL + boardFromPostIdSuffixSQL board := new(Board) - err := QueryRowSQL(query, []any{p.ID}, []any{ + err := QueryRow(setupOptions(opts...), query, []any{p.ID}, []any{ &board.ID, &board.SectionID, &board.URI, &board.Dir, &board.NavbarPosition, &board.Title, &board.Subtitle, &board.Description, &board.MaxFilesize, &board.MaxThreads, &board.DefaultStyle, &board.Locked, &board.CreatedAt, &board.AnonymousName, &board.ForceAnonymous, &board.AutosageAfter, &board.NoImagesAfter, @@ -284,13 +289,13 @@ func (p *Post) ChangeBoardID(newBoardID int) error { } // TopPostID returns the OP post ID of the thread that p is in -func (p *Post) TopPostID() (int, error) { +func (p *Post) TopPostID(opts ...*RequestOptions) (int, error) { if p.IsTopPost { return p.ID, nil } const query = `SELECT id FROM DBPREFIXposts WHERE thread_id = ? and is_top_post = TRUE ORDER BY id ASC LIMIT 1` var topPostID int - err := QueryRowSQL(query, []any{p.ThreadID}, []any{&topPostID}) + err := QueryRow(setupOptions(opts...), query, []any{p.ThreadID}, []any{&topPostID}) return topPostID, err } @@ -306,13 +311,13 @@ func (p *Post) GetTopPost() (*Post, error) { // GetPostUpload returns the upload info associated with the file as well as any errors encountered. // If the file has no uploads, then *Upload is nil. If the file was removed from the post, then Filename // and OriginalFilename = "deleted" -func (p *Post) GetUpload() (*Upload, error) { +func (p *Post) GetUpload(opts ...*RequestOptions) (*Upload, error) { const query = `SELECT id, post_id, file_order, original_filename, filename, checksum, file_size, is_spoilered, thumbnail_width, thumbnail_height, width, height FROM DBPREFIXfiles WHERE post_id = ?` upload := new(Upload) - err := QueryRowSQL(query, []any{p.ID}, []any{ + err := QueryRow(setupOptions(opts...), query, []any{p.ID}, []any{ &upload.ID, &upload.PostID, &upload.FileOrder, &upload.OriginalFilename, &upload.Filename, &upload.Checksum, &upload.FileSize, &upload.IsSpoilered, &upload.ThumbnailWidth, &upload.ThumbnailHeight, &upload.Width, &upload.Height, }) @@ -325,7 +330,7 @@ func (p *Post) GetUpload() (*Upload, error) { // UnlinkUploads disassociates the post with any uploads in DBPREFIXfiles // that may have been uploaded with it, optionally leaving behind a "File Deleted" // frame where the thumbnail appeared -func (p *Post) UnlinkUploads(leaveDeletedBox bool) error { +func (p *Post) UnlinkUploads(leaveDeletedBox bool, requestOpts ...*RequestOptions) error { var sqlStr string if leaveDeletedBox { // leave a "File Deleted" box @@ -333,7 +338,7 @@ func (p *Post) UnlinkUploads(leaveDeletedBox bool) error { } else { sqlStr = `DELETE FROM DBPREFIXfiles WHERE post_id = ?` } - _, err := ExecSQL(sqlStr, p.ID) + _, err := Exec(setupOptions(requestOpts...), sqlStr, p.ID) return err } @@ -348,17 +353,24 @@ func (p *Post) InCyclicThread() (bool, error) { } // Delete sets the post as deleted and sets the deleted_at timestamp to the current time -func (p *Post) Delete() error { - ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout) - defer cancel() - tx, err := BeginContextTx(ctx) - if err != nil { - return err +func (p *Post) Delete(requestOptions ...*RequestOptions) error { + shouldCommit := len(requestOptions) == 0 + opts := setupOptions(requestOptions...) + if opts.Context == context.Background() { + opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout) + defer opts.Cancel() + } + var err error + if opts.Tx == nil { + opts.Tx, err = BeginContextTx(opts.Context) + if err != nil { + return err + } + defer opts.Tx.Rollback() } - defer tx.Rollback() var rowCount int - err = QueryRowContextSQL(ctx, tx, "SELECT COUNT(*) FROM DBPREFIXposts WHERE id = ?", []any{p.ID}, []any{&rowCount}) + err = QueryRow(opts, "SELECT COUNT(*) FROM DBPREFIXposts WHERE id = ?", []any{p.ID}, []any{&rowCount}) if errors.Is(err, sql.ErrNoRows) { err = ErrPostDoesNotExist } @@ -367,16 +379,33 @@ func (p *Post) Delete() error { } if p.IsTopPost { - return deleteThread(ctx, tx, p.ThreadID) + return deleteThread(opts, p.ThreadID) } - if _, err = ExecContextSQL(ctx, tx, "UPDATE DBPREFIXposts SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = ?", p.ID); err != nil { + if _, err = Exec(opts, "UPDATE DBPREFIXposts SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = ?", p.ID); err != nil { return err } - return tx.Commit() + if shouldCommit { + return opts.Tx.Commit() + } + return nil } -// InsertWithContext inserts the post into the database with the given context and transaction -func (p *Post) InsertWithContext(ctx context.Context, tx *sql.Tx, bumpThread bool, boardID int, locked bool, stickied bool, anchored bool, cyclical bool) error { +// Insert inserts the post into the database with the optional given options +func (p *Post) Insert(bumpThread bool, boardID int, locked bool, stickied bool, anchored bool, cyclical bool, requestOptions ...*RequestOptions) error { + opts := setupOptions(requestOptions...) + if len(requestOptions) == 0 { + opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout) + defer opts.Cancel() + } + var err error + if opts.Tx == nil { + opts.Tx, err = BeginContextTx(opts.Context) + if err != nil { + return err + } + defer opts.Tx.Rollback() + } + if p.ID > 0 { // already inserted return ErrorPostAlreadySent @@ -387,19 +416,18 @@ func (p *Post) InsertWithContext(ctx context.Context, tx *sql.Tx, bumpThread boo VALUES(?,?,PARAM_ATON,CURRENT_TIMESTAMP,?,?,?,?,?,?,?,?,?,?)` bumpSQL := `UPDATE DBPREFIXthreads SET last_bump = CURRENT_TIMESTAMP WHERE id = ?` - var err error if p.ThreadID == 0 { // thread doesn't exist yet, this is a new post p.IsTopPost = true var threadID int - threadID, err = CreateThread(tx, boardID, locked, stickied, anchored, cyclical) + threadID, err = CreateThread(opts, boardID, locked, stickied, anchored, cyclical) if err != nil { return err } p.ThreadID = threadID } else { var threadIsLocked bool - if err = QueryRowTxSQL(tx, "SELECT locked FROM DBPREFIXthreads WHERE id = ?", + if err = QueryRow(opts, "SELECT locked FROM DBPREFIXthreads WHERE id = ?", []any{p.ThreadID}, []any{&threadIsLocked}); err != nil { return err } @@ -408,40 +436,26 @@ func (p *Post) InsertWithContext(ctx context.Context, tx *sql.Tx, bumpThread boo } } - if _, err = ExecContextSQL(ctx, tx, insertSQL, + if _, err = Exec(opts, insertSQL, p.ThreadID, p.IsTopPost, p.IP, p.Name, p.Tripcode, p.IsRoleSignature, p.Email, p.Subject, p.Message, p.MessageRaw, p.Password, p.Flag, p.Country, ); err != nil { return err } - if p.ID, err = getLatestID("DBPREFIXposts", tx); err != nil { + if p.ID, err = getLatestID(opts, "DBPREFIXposts"); err != nil { return err } if bumpThread { - if _, err = ExecContextSQL(ctx, tx, bumpSQL, p.ThreadID); err != nil { + if _, err = Exec(opts, bumpSQL, p.ThreadID); err != nil { return err } } + if len(requestOptions) == 0 { + return opts.Tx.Commit() + } return nil } -func (p *Post) Insert(bumpThread bool, boardID int, locked bool, stickied bool, anchored bool, cyclical bool) error { - ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout) - defer cancel() - - tx, err := BeginContextTx(ctx) - if err != nil { - return err - } - defer tx.Rollback() - - if err = p.InsertWithContext(ctx, tx, bumpThread, boardID, locked, stickied, anchored, cyclical); err != nil { - return err - } - - return tx.Commit() -} - // CyclicThreadPost represents a post that should be deleted in a cyclic thread type CyclicThreadPost struct { PostID int // sql: post_id @@ -515,7 +529,7 @@ func (p *Post) WebPath() string { webRoot := config.GetSystemCriticalConfig().WebRoot const query = "SELECT op_id, dir FROM DBPREFIXv_top_post_board_dir WHERE id = ?" - err := QueryRowSQL(query, []any{p.ID}, []any{&p.opID, &p.boardDir}) + err := QueryRow(nil, query, []any{p.ID}, []any{&p.opID, &p.boardDir}) if err != nil { return webRoot } diff --git a/pkg/gcsql/preload.go b/pkg/gcsql/preload.go index 401851a4..3312577b 100644 --- a/pkg/gcsql/preload.go +++ b/pkg/gcsql/preload.go @@ -39,7 +39,7 @@ func PreloadModule(l *lua.LState) int { }) } - rows, err := QuerySQL(queryStr, queryArgs...) + rows, err := Query(nil, queryStr, queryArgs...) l.Push(luar.New(l, rows)) l.Push(luar.New(l, err)) @@ -57,7 +57,7 @@ func PreloadModule(l *lua.LState) int { execArgs = append(execArgs, arg) }) } - result, err := ExecSQL(execStr) + result, err := Exec(nil, execStr) l.Push(luar.New(l, result)) l.Push(luar.New(l, err)) diff --git a/pkg/gcsql/sections.go b/pkg/gcsql/sections.go index 1a82a74b..ccfa9a0e 100644 --- a/pkg/gcsql/sections.go +++ b/pkg/gcsql/sections.go @@ -96,7 +96,7 @@ func GetSectionFromName(name string) (*Section, error) { // DeleteSection deletes a section from the database and resets the AllSections array func DeleteSection(id int) error { const query = `DELETE FROM DBPREFIXsections WHERE id = ?` - _, err := ExecSQL(query, id) + _, err := Exec(nil, query, id) if err != nil { return err } @@ -105,38 +105,47 @@ func DeleteSection(id int) error { // NewSection creates a new board section in the database and returns a *Section struct pointer. // If position < 0, it will use the ID -func NewSection(name string, abbreviation string, hidden bool, position int) (*Section, error) { +func NewSection(name string, abbreviation string, hidden bool, position int, requestOpts ...*RequestOptions) (*Section, error) { const sqlINSERT = `INSERT INTO DBPREFIXsections (name, abbreviation, hidden, position) VALUES (?,?,?,?)` const sqlPosition = `SELECT COALESCE(MAX(position) + 1, 1) FROM DBPREFIXsections` - - tx, err := BeginTx() - if err != nil { - return nil, err + var opts *RequestOptions + var err error + shouldCommit := len(requestOpts) == 0 + if shouldCommit { + opts = &RequestOptions{} + opts.Tx, err = BeginTx() + if err != nil { + return nil, err + } + opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout) + defer func() { + opts.Cancel() + opts.Tx.Rollback() + }() + } else { + opts = requestOpts[0] } - ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout) - defer func() { - cancel() - tx.Rollback() - }() if position < 0 { // position not specified - err = QueryRowContextSQL(ctx, tx, sqlPosition, nil, []any{&position}) + err = QueryRow(opts, sqlPosition, nil, []any{&position}) if errors.Is(err, sql.ErrNoRows) { position = 1 } else if err != nil { return nil, err } } - if _, err = ExecContextSQL(ctx, tx, sqlINSERT, name, abbreviation, hidden, position); err != nil { + if _, err = Exec(opts, sqlINSERT, name, abbreviation, hidden, position); err != nil { return nil, err } - id, err := getLatestID("DBPREFIXsections", tx) + id, err := getLatestID(opts, "DBPREFIXsections") if err != nil { return nil, err } - if err = tx.Commit(); err != nil { - return nil, err + if shouldCommit { + if err = opts.Tx.Commit(); err != nil { + return nil, err + } } return &Section{ ID: id, @@ -147,9 +156,15 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S }, nil } -func (s *Section) UpdateValues() error { +func (s *Section) UpdateValues(requestOpts ...*RequestOptions) error { + opts := setupOptions(requestOpts...) + if opts.Context == context.Background() { + opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout) + defer opts.Cancel() + } + var count int - err := QueryRowTimeoutSQL(nil, `SELECT COUNT(*) FROM DBPREFIXsections WHERE id = ?`, []any{s.ID}, []any{&count}) + err := QueryRow(opts, `SELECT COUNT(*) FROM DBPREFIXsections WHERE id = ?`, []any{s.ID}, []any{&count}) if errors.Is(err, sql.ErrNoRows) { return ErrSectionDoesNotExist } @@ -157,6 +172,6 @@ func (s *Section) UpdateValues() error { return err } const query = `UPDATE DBPREFIXsections set name = ?, abbreviation = ?, position = ?, hidden = ? WHERE id = ?` - _, err = ExecTimeoutSQL(nil, query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID) + _, err = Exec(opts, query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID) return err } diff --git a/pkg/gcsql/setup_test.go b/pkg/gcsql/setup_test.go index 2ecb059c..9bef7237 100644 --- a/pkg/gcsql/setup_test.go +++ b/pkg/gcsql/setup_test.go @@ -171,7 +171,7 @@ func setupAndProvisionMockDB(t *testing.T, mock sqlmock.Sqlmock, dbType string, ExpectQuery().WithArgs("test"). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - _, err := ExecSQL("CREATE DATABASE gochan") + _, err := Exec(nil, "CREATE DATABASE gochan") if err != nil { return err } diff --git a/pkg/gcsql/threads.go b/pkg/gcsql/threads.go index ae02a980..99f65446 100644 --- a/pkg/gcsql/threads.go +++ b/pkg/gcsql/threads.go @@ -1,7 +1,6 @@ package gcsql import ( - "context" "database/sql" "errors" "fmt" @@ -21,27 +20,27 @@ var ( ) // CreateThread creates a new thread in the database with the given board ID and statuses -func CreateThread(tx *sql.Tx, boardID int, locked bool, stickied bool, anchored bool, cyclic bool) (threadID int, err error) { +func CreateThread(requestOptions *RequestOptions, boardID int, locked bool, stickied bool, anchored bool, cyclic bool) (threadID int, err error) { const lockedQuery = `SELECT locked FROM DBPREFIXboards WHERE id = ?` const insertQuery = `INSERT INTO DBPREFIXthreads (board_id, locked, stickied, anchored, cyclical) VALUES (?,?,?,?,?)` var boardIsLocked bool - if err = QueryRowTxSQL(tx, lockedQuery, []any{boardID}, []any{&boardIsLocked}); err != nil { + if err = QueryRow(requestOptions, lockedQuery, []any{boardID}, []any{&boardIsLocked}); err != nil { return 0, err } if boardIsLocked { return 0, ErrBoardIsLocked } - if _, err = ExecTxSQL(tx, insertQuery, boardID, locked, stickied, anchored, cyclic); err != nil { + if _, err = Exec(requestOptions, insertQuery, boardID, locked, stickied, anchored, cyclic); err != nil { return 0, err } - return threadID, QueryRowTxSQL(tx, "SELECT MAX(id) FROM DBPREFIXthreads", nil, []any{&threadID}) + return threadID, QueryRow(requestOptions, "SELECT MAX(id) FROM DBPREFIXthreads", nil, []any{&threadID}) } // GetThread returns a a thread object from the database, given its ID func GetThread(threadID int) (*Thread, error) { const query = selectThreadsBaseSQL + `WHERE id = ?` thread := new(Thread) - err := QueryRowSQL(query, []any{threadID}, []any{ + err := QueryRow(nil, query, []any{threadID}, []any{ &thread.ID, &thread.BoardID, &thread.Locked, &thread.Stickied, &thread.Anchored, &thread.Cyclic, &thread.LastBump, &thread.DeletedAt, &thread.IsDeleted, }) @@ -52,7 +51,7 @@ func GetThread(threadID int) (*Thread, error) { func GetPostThread(opID int) (*Thread, error) { const query = selectThreadsBaseSQL + `WHERE id = (SELECT thread_id FROM DBPREFIXposts WHERE id = ? LIMIT 1)` thread := new(Thread) - err := QueryRowSQL(query, []any{opID}, []any{ + err := QueryRow(nil, query, []any{opID}, []any{ &thread.ID, &thread.BoardID, &thread.Locked, &thread.Stickied, &thread.Anchored, &thread.Cyclic, &thread.LastBump, &thread.DeletedAt, &thread.IsDeleted, }) @@ -66,7 +65,7 @@ func GetPostThread(opID int) (*Thread, error) { func GetTopPostThreadID(opID int) (int, error) { const query = `SELECT thread_id FROM DBPREFIXposts WHERE id = ? and is_top_post` var threadID int - err := QueryRowSQL(query, []any{opID}, []any{&threadID}) + err := QueryRow(nil, query, []any{opID}, []any{&threadID}) if err == sql.ErrNoRows { err = ErrThreadDoesNotExist } @@ -81,7 +80,7 @@ func GetThreadsWithBoardID(boardID int, onlyNotDeleted bool) ([]Thread, error) { if onlyNotDeleted { query += " AND is_deleted = FALSE" } - rows, err := QuerySQL(query, boardID) + rows, err := Query(nil, query, boardID) if err != nil { return nil, err } @@ -104,7 +103,7 @@ func GetThreadReplyCountFromOP(opID int) (int, error) { const query = `SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = ( SELECT thread_id FROM DBPREFIXposts WHERE id = ?) AND is_deleted = FALSE AND is_top_post = FALSE` var num int - err := QueryRowSQL(query, []any{opID}, []any{&num}) + err := QueryRow(nil, query, []any{opID}, []any{&num}) return num, err } @@ -113,7 +112,7 @@ func ChangeThreadBoardID(threadID int, newBoardID int) error { if !DoesBoardExistByID(newBoardID) { return ErrBoardDoesNotExist } - _, err := ExecSQL(`UPDATE DBPREFIXthreads SET board_id = ? WHERE id = ?`, newBoardID, threadID) + _, err := Exec(nil, "UPDATE DBPREFIXthreads SET board_id = ? WHERE id = ?", newBoardID, threadID) return err } @@ -135,15 +134,15 @@ func (t *Thread) GetReplyFileCount() (int, error) { const query = `SELECT COUNT(filename) FROM DBPREFIXfiles WHERE post_id IN ( SELECT id FROM DBPREFIXposts WHERE thread_id = ? AND is_deleted = FALSE)` var fileCount int - err := QueryRowSQL(query, []any{t.ID}, []any{&fileCount}) + err := QueryRow(nil, query, []any{t.ID}, []any{&fileCount}) return fileCount, err } // GetReplyCount returns the number of posts in the thread, not including the top post or any deleted posts func (t *Thread) GetReplyCount() (int, error) { - const query = `SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = ? AND is_top_post = FALSE AND is_deleted = FALSE` + const query = "SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = ? AND is_top_post = FALSE AND is_deleted = FALSE" var numReplies int - err := QueryRowSQL(query, []any{t.ID}, []any{&numReplies}) + err := QueryRow(nil, query, []any{t.ID}, []any{&numReplies}) return numReplies, err } @@ -161,7 +160,7 @@ func (t *Thread) GetPosts(repliesOnly bool, boardPage bool, limit int) ([]Post, query += " LIMIT " + strconv.Itoa(limit) } - rows, err := QuerySQL(query, t.ID) + rows, err := Query(nil, query, t.ID) if err != nil { return nil, err } @@ -185,7 +184,7 @@ func (t *Thread) GetPosts(repliesOnly bool, boardPage bool, limit int) ([]Post, func (t *Thread) GetUploads() ([]Upload, error) { const query = selectFilesBaseSQL + ` WHERE post_id IN ( SELECT id FROM DBPREFIXposts WHERE thread_id = ? and is_deleted = FALSE) AND filename != 'deleted'` - rows, err := QuerySQL(query, t.ID) + rows, err := Query(nil, query, t.ID) if err != nil { return nil, err } @@ -223,18 +222,18 @@ func (t *Thread) UpdateAttribute(attribute string, value bool) error { return fmt.Errorf("invalid thread attribute %q", attribute) } updateSQL += attribute + " = ? WHERE id = ?" - _, err := ExecSQL(updateSQL, value, t.ID) + _, err := Exec(nil, updateSQL, value, t.ID) return err } // deleteThread updates the thread and sets it as deleted, as well as the posts where thread_id = threadID -func deleteThread(ctx context.Context, tx *sql.Tx, threadID int) error { +func deleteThread(opts *RequestOptions, threadID int) error { const checkPostExistsSQL = `SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = ?` const deletePostsSQL = `UPDATE DBPREFIXposts SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE thread_id = ?` const deleteThreadSQL = `UPDATE DBPREFIXthreads SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = ?` var rowCount int - err := QueryRowContextSQL(ctx, tx, checkPostExistsSQL, []any{threadID}, []any{&rowCount}) + err := QueryRow(opts, checkPostExistsSQL, []any{threadID}, []any{&rowCount}) if err != nil { return err } @@ -242,10 +241,10 @@ func deleteThread(ctx context.Context, tx *sql.Tx, threadID int) error { return ErrThreadDoesNotExist } - _, err = ExecContextSQL(ctx, tx, deletePostsSQL, threadID) + _, err = Exec(opts, deletePostsSQL, threadID) if err != nil { return err } - _, err = ExecContextSQL(ctx, tx, deleteThreadSQL, threadID) + _, err = Exec(opts, deleteThreadSQL, threadID) return err } diff --git a/pkg/gcsql/uploads.go b/pkg/gcsql/uploads.go index c6c8202b..9629889d 100644 --- a/pkg/gcsql/uploads.go +++ b/pkg/gcsql/uploads.go @@ -1,7 +1,6 @@ package gcsql import ( - "context" "database/sql" "errors" "fmt" @@ -25,7 +24,7 @@ func GetThreadFiles(post *Post) ([]Upload, error) { query := selectFilesBaseSQL + `WHERE post_id IN ( SELECT id FROM DBPREFIXposts WHERE thread_id = ( SELECT thread_id FROM DBPREFIXposts WHERE id = ?)) AND filename != 'deleted'` - rows, err := QuerySQL(query, post.ID) + rows, err := Query(nil, query, post.ID) if err != nil { return nil, err } @@ -45,17 +44,28 @@ func GetThreadFiles(post *Post) ([]Upload, error) { } // NextFileOrder gets what would be the next file_order value (not particularly useful until multi-file posting is implemented) -func (p *Post) NextFileOrder(ctx context.Context, tx *sql.Tx) (int, error) { +func (p *Post) NextFileOrder(requestOpts ...*RequestOptions) (int, error) { + opts := setupOptions(requestOpts...) const query = `SELECT COALESCE(MAX(file_order) + 1, 0) FROM DBPREFIXfiles WHERE post_id = ?` var next int - err := QueryRowContextSQL(ctx, tx, query, []any{p.ID}, []any{&next}) + err := QueryRow(opts, query, []any{p.ID}, []any{&next}) return next, err } -func (p *Post) AttachFileTx(tx *sql.Tx, upload *Upload) error { +func (p *Post) AttachFile(upload *Upload, requestOpts ...*RequestOptions) error { if upload == nil { return nil // no upload to attach, so no error } + opts := setupOptions(requestOpts...) + shouldCommit := opts.Tx == nil + var err error + if shouldCommit { + opts.Tx, err = BeginTx() + if err != nil { + return err + } + defer opts.Tx.Rollback() + } _, err, recovered := events.TriggerEvent("incoming-upload", upload) if recovered { @@ -72,35 +82,30 @@ func (p *Post) AttachFileTx(tx *sql.Tx, upload *Upload) error { if upload.ID > 0 { return ErrAlreadyAttached } - if upload.FileOrder < 1 { - upload.FileOrder, err = p.NextFileOrder(context.Background(), tx) + upload.FileOrder, err = p.NextFileOrder(opts) if err != nil { return err } } upload.PostID = p.ID - if _, err = ExecTxSQL(tx, insertSQL, + if _, err = Exec(opts, insertSQL, &upload.PostID, &upload.FileOrder, &upload.OriginalFilename, &upload.Filename, &upload.Checksum, &upload.FileSize, &upload.IsSpoilered, &upload.ThumbnailWidth, &upload.ThumbnailHeight, &upload.Width, &upload.Height, ); err != nil { return err } - upload.ID, err = getLatestID("DBPREFIXfiles", tx) - return err -} - -func (p *Post) AttachFile(upload *Upload) error { - tx, err := BeginTx() + upload.ID, err = getLatestID(opts, "DBPREFIXfiles") if err != nil { return err } - defer tx.Rollback() - if err = p.AttachFileTx(tx, upload); err != nil { - return err + if shouldCommit { + if err = opts.Tx.Commit(); err != nil { + return err + } } - return tx.Commit() + return nil } // GetUploadFilenameAndBoard returns the filename (or an empty string) and @@ -112,7 +117,7 @@ func GetUploadFilenameAndBoard(postID int) (string, string, error) { JOIN DBPREFIXboards ON DBPREFIXboards.id = board_id WHERE DBPREFIXposts.id = ?` var filename, dir string - err := QueryRowSQL(query, []any{postID}, []any{&filename, &dir}) + err := QueryRow(nil, query, []any{postID}, []any{&filename, &dir}) if errors.Is(err, sql.ErrNoRows) { return "", "", nil } else if err != nil { diff --git a/pkg/gcsql/util.go b/pkg/gcsql/util.go index 872b286d..a82ad968 100644 --- a/pkg/gcsql/util.go +++ b/pkg/gcsql/util.go @@ -51,6 +51,39 @@ type intOrStringConstraint interface { int | string } +// RequestOptions is used to pass an optional context, transaction, and any other things to the various SQL functions +// in a future-proof way +type RequestOptions struct { + Context context.Context + Tx *sql.Tx + Cancel context.CancelFunc +} + +func setupOptions(opts ...*RequestOptions) *RequestOptions { + if len(opts) == 0 || opts[0] == nil { + return &RequestOptions{Context: context.Background()} + } + if opts[0].Context == nil { + opts[0].Context = context.Background() + } + return opts[0] +} + +// Query is a wrapper for QueryContextSQL that uses the given options, or defaults to a background context if nil +func Query(opts *RequestOptions, query string, a ...any) (*sql.Rows, error) { + return gcdb.Query(opts, query, a...) +} + +// QueryRow is a wrapper for QueryRowContextSQL that uses the given options, or defaults to a background context if nil +func QueryRow(opts *RequestOptions, query string, values, out []any) error { + return gcdb.QueryRow(opts, query, values, out) +} + +// Exec is a wrapper for ExecContextSQL that uses the given options, or defaults to a background context if nil +func Exec(opts *RequestOptions, query string, values ...any) (sql.Result, error) { + return gcdb.Exec(opts, query, values...) +} + // BeginTx begins a new transaction for the gochan database. It uses a background context func BeginTx() (*sql.Tx, error) { return BeginContextTx(context.Background()) @@ -111,7 +144,7 @@ func SetupSQLString(query string, dbConn *GCDB) (string, error) { return prepared, err } -// Close closes the connection to the SQL database +// Close closes the connection to the SQL database if it is open func Close() error { if gcdb != nil { return gcdb.Close() @@ -121,6 +154,7 @@ func Close() error { /* ExecSQL executes the given SQL statement with the given parameters + Example: var intVal int @@ -136,8 +170,8 @@ func ExecSQL(query string, values ...any) (sql.Result, error) { } /* -ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil) - +ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil). +Deprecated: Use Exec instead Example: ctx, cancel := context.WithTimeout(context.Background(), time.Duration(sqlCfg.DBTimeoutSeconds) * time.Second) @@ -154,6 +188,7 @@ func ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...an return gcdb.ExecContextSQL(ctx, tx, sqlStr, values...) } +// ExecTimeoutSQL is a helper function for executing a SQL statement with the configured timeout in seconds func ExecTimeoutSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) { ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout) defer cancel() @@ -162,7 +197,9 @@ func ExecTimeoutSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error } /* -ExecTxSQL automatically escapes the given values and caches the statement +ExecTxSQL executes the given SQL statement with the given transaction and parameters. +Deprecated: Use Exec instead with a transaction in the RequestOptions + Example: tx, err := BeginTx() @@ -190,8 +227,8 @@ func ExecTxSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) { } /* -QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[] -Automatically escapes the given values and caches the query +QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]. +Deprecated: Use QueryRow instead Example: @@ -211,7 +248,8 @@ func QueryRowSQL(query string, values, out []any) error { /* QueryRowContextSQL gets a row from the database with the values in values[] and fills the respective pointers in out[] -using the given context as a deadline, and the given transaction (if non-nil) +using the given context as a deadline, and the given transaction (if non-nil). +Deprecated: Use QueryRow instead with an optional context and/or tx in the RequestOptions Example: @@ -239,8 +277,8 @@ func QueryRowTimeoutSQL(tx *sql.Tx, query string, values, out []any) error { } /* -QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[] -Automatically escapes the given values and caches the query +QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]. +Deprecated: Use QueryRow instead with a transaction in the RequestOptions Example: @@ -261,8 +299,8 @@ func QueryRowTxSQL(tx *sql.Tx, query string, values, out []any) error { } /* -QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[] -Automatically escapes the given values and caches the query +QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[]. +Deprecated: Use Query instead Example: @@ -285,7 +323,8 @@ func QuerySQL(query string, a ...any) (*sql.Rows, error) { /* QueryContextSQL queries the database with a prepared statement and the given parameters, using the given context -for a deadline +for a deadline. +Deprecated: Use Query instead with an optional context/transaction in the RequestOptions Example: @@ -317,7 +356,9 @@ func QueryTimeoutSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, context.Can /* QueryTxSQL gets all rows from the db using the transaction tx with the values in values[] and fills the -respective pointers in out[]. Automatically escapes the given values and caches the query +respective pointers in out[]. +Deprecated: Use Query instead with a transaction in the RequestOptions + Example: tx, err := BeginTx() @@ -347,6 +388,7 @@ func QueryTxSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, error) { return rows, stmt.Close() } +// ParseSQLTimeString attempts to parse a string into a time.Time object using the known SQL date/time formats func ParseSQLTimeString(str string) (time.Time, error) { var t time.Time var err error @@ -359,22 +401,10 @@ func ParseSQLTimeString(str string) (time.Time, error) { } // getLatestID returns the latest inserted id column value from the given table -func getLatestID(tableName string, tx *sql.Tx) (id int, err error) { +func getLatestID(opts *RequestOptions, tableName string) (id int, err error) { + opts = setupOptions(opts) query := `SELECT MAX(id) FROM ` + tableName - if tx != nil { - var stmt *sql.Stmt - stmt, err = PrepareSQL(query, tx) - if err != nil { - return 0, err - } - defer stmt.Close() - if err = stmt.QueryRow().Scan(&id); err != nil { - return - } - err = stmt.Close() - } else { - err = QueryRowSQL(query, nil, []any{&id}) - } + QueryRow(opts, query, nil, []any{&id}) return } @@ -393,7 +423,7 @@ func doesTableExist(tableName string) (bool, error) { } var count int - err := QueryRowSQL(existQuery, []any{config.GetSystemCriticalConfig().DBprefix + tableName}, []any{&count}) + err := QueryRow(nil, existQuery, []any{config.GetSystemCriticalConfig().DBprefix + tableName}, []any{&count}) if err != nil { return false, err } @@ -404,7 +434,7 @@ func doesTableExist(tableName string) (bool, error) { func GetComponentVersion(componentKey string) (int, error) { const sql = `SELECT version FROM DBPREFIXdatabase_version WHERE component = ?` var version int - err := QueryRowSQL(sql, []any{componentKey}, []any{&version}) + err := QueryRow(nil, sql, []any{componentKey}, []any{&version}) return version, err } @@ -434,7 +464,7 @@ func doesGochanPrefixTableExist() (bool, error) { } var count int - err := QueryRowSQL(prefixTableExist, []any{}, []any{&count}) + err := QueryRow(nil, prefixTableExist, []any{}, []any{&count}) if err != nil && err != sql.ErrNoRows { return false, err }