1
0
Fork 0
mirror of https://github.com/Eggbertx/gochan.git synced 2025-08-18 11:46:23 -07:00

Add future-proof functions using struct for context, tx, etc

This commit is contained in:
Eggbertx 2025-02-05 17:32:10 -08:00
parent a38a519e4e
commit 10e0da4492
13 changed files with 332 additions and 286 deletions

View file

@ -80,7 +80,7 @@ func (m *Pre2021Migrator) migrateBan(tx *sql.Tx, ban *migrationBan, boardID *int
migratedBan.Message = ban.reason
migratedBan.StaffID = ban.staffID
migratedBan.StaffNote = ban.staffNote
if err := gcsql.NewIPBanTx(tx, migratedBan); err != nil {
if err := gcsql.NewIPBan(migratedBan, &gcsql.RequestOptions{Tx: tx}); err != nil {
errEv.Err(err).Caller().
Int("oldID", ban.oldID).Msg("Failed to migrate ban")
return err

View file

@ -37,7 +37,7 @@ func (m *Pre2021Migrator) migrateSections() error {
}
var sectionsToBeCreated []gcsql.Section
rows, err := m.db.QuerySQL(sectionsQuery)
rows, err := m.db.Query(nil, sectionsQuery)
if err != nil {
errEv.Err(err).Caller().Msg("Failed to query old database sections")
return err
@ -126,7 +126,7 @@ func (m *Pre2021Migrator) MigrateBoards() error {
}
// get boards from old db
rows, err := m.db.QuerySQL(boardsQuery)
rows, err := m.db.Query(nil, boardsQuery)
if err != nil {
errEv.Err(err).Caller().Msg("Failed to query old database boards")
return err
@ -162,7 +162,7 @@ func (m *Pre2021Migrator) MigrateBoards() error {
Int("migratedBoardID", newBoard.ID).
Msg("Board already exists in new db, updating values")
// don't update other values in the array since they don't affect migrating threads or posts
if _, err = gcsql.ExecSQL(`UPDATE DBPREFIXboards
if _, err = gcsql.Exec(nil, `UPDATE DBPREFIXboards
SET uri = ?, navbar_position = ?, title = ?, subtitle = ?, description = ?,
max_file_size = ?, max_threads = ?, default_style = ?, locked = ?,
anonymous_name = ?, force_anonymous = ?, autosage_after = ?, no_images_after = ?, max_message_length = ?,

View file

@ -1,7 +1,6 @@
package pre2021
import (
"context"
"database/sql"
"time"
@ -35,10 +34,10 @@ type migrationPost struct {
func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *zerolog.Event) error {
var err error
opts := &gcsql.RequestOptions{Tx: tx}
if post.oldParentID == 0 {
// migrating post was a thread OP, create the row in the threads table
if post.ThreadID, err = gcsql.CreateThread(tx, post.boardID, false, post.stickied, post.autosage, false); err != nil {
if post.ThreadID, err = gcsql.CreateThread(opts, post.boardID, false, post.stickied, post.autosage, false); err != nil {
errEv.Err(err).Caller().
Int("boardID", post.boardID).
Msg("Failed to create thread")
@ -46,7 +45,7 @@ func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *ze
}
// insert thread top post
if err = post.InsertWithContext(context.Background(), tx, true, post.boardID, false, post.stickied, post.autosage, false); err != nil {
if err = post.Insert(true, post.boardID, false, post.stickied, post.autosage, false, opts); err != nil {
errEv.Err(err).Caller().
Int("boardID", post.boardID).
Int("threadID", post.ThreadID).
@ -54,7 +53,7 @@ func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *ze
}
if post.filename != "" {
if err = post.AttachFileTx(tx, &gcsql.Upload{
if err = post.AttachFile(&gcsql.Upload{
PostID: post.ID,
OriginalFilename: post.filenameOriginal,
Filename: post.filename,
@ -64,7 +63,7 @@ func (m *Pre2021Migrator) migratePost(tx *sql.Tx, post *migrationPost, errEv *ze
ThumbnailHeight: post.thumbH,
Width: post.imageW,
Height: post.imageH,
}); err != nil {
}, opts); err != nil {
errEv.Err(err).Caller().
Int("oldPostID", post.oldID).
Msg("Failed to attach upload to migrated post")
@ -85,7 +84,7 @@ func (m *Pre2021Migrator) MigratePosts() error {
}
defer tx.Rollback()
rows, err := m.db.QuerySQL(threadsQuery)
rows, err := m.db.Query(nil, threadsQuery)
if err != nil {
errEv.Err(err).Caller().Msg("Failed to get threads")
return err
@ -126,7 +125,7 @@ func (m *Pre2021Migrator) MigratePosts() error {
}
// get and insert replies
replyRows, err := m.db.QuerySQL(postsQuery+" AND parentid = ?", thread.oldID)
replyRows, err := m.db.Query(nil, postsQuery+" AND parentid = ?", thread.oldID)
if err != nil {
errEv.Err(err).Caller().
Int("parentID", thread.oldID).
@ -156,7 +155,7 @@ func (m *Pre2021Migrator) MigratePosts() error {
}
if thread.locked {
if _, err = gcsql.ExecTxSQL(tx, "UPDATE DBPREFIXthreads SET locked = TRUE WHERE id = ?", thread.ThreadID); err != nil {
if _, err = gcsql.Exec(&gcsql.RequestOptions{Tx: tx}, "UPDATE DBPREFIXthreads SET locked = TRUE WHERE id = ?", thread.ThreadID); err != nil {
errEv.Err(err).Caller().
Int("threadID", thread.ThreadID).
Msg("Unable to re-lock migrated thread")

View file

@ -25,38 +25,42 @@ type Ban interface {
Deactivate(int) error
}
func NewIPBanTx(tx *sql.Tx, ban *IPBan) error {
func NewIPBan(ban *IPBan, requestOpts ...*RequestOptions) error {
const query = `INSERT INTO DBPREFIXip_ban
(staff_id, board_id, banned_for_post_id, copy_post_text, is_thread_ban,
is_active, range_start, range_end, appeal_at, expires_at,
permanent, staff_note, message, can_appeal)
VALUES(?, ?, ?, ?, ?, ?, PARAM_ATON, PARAM_ATON, ?, ?, ?, ?, ?, ?)`
opts := setupOptions(requestOpts...)
shouldCommit := opts.Tx == nil
var err error
if shouldCommit {
opts.Tx, err = BeginTx()
if err != nil {
return err
}
defer opts.Tx.Rollback()
}
if ban.ID > 0 {
return ErrBanAlreadyInserted
}
_, err := ExecTxSQL(tx, query, ban.StaffID, ban.BoardID, ban.BannedForPostID, ban.CopyPostText,
if _, err = Exec(opts, query, ban.StaffID, ban.BoardID, ban.BannedForPostID, ban.CopyPostText,
ban.IsThreadBan, ban.IsActive, ban.RangeStart, ban.RangeEnd, ban.AppealAt,
ban.ExpiresAt, ban.Permanent, ban.StaffNote, ban.Message, ban.CanAppeal)
ban.ExpiresAt, ban.Permanent, ban.StaffNote, ban.Message, ban.CanAppeal,
); err != nil {
return err
}
ban.ID, err = getLatestID(opts, "DBPREFIXip_ban")
if err != nil {
return err
}
ban.ID, err = getLatestID("DBPREFIXip_ban", tx)
return err
}
func NewIPBan(ban *IPBan) error {
tx, err := BeginTx()
if err != nil {
return err
}
defer tx.Rollback()
if err = NewIPBanTx(tx, ban); err != nil {
return err
if shouldCommit {
return opts.Tx.Commit()
}
return tx.Commit()
return nil
}
// CheckIPBan returns the latest active IP ban for the given IP, as well as any
@ -74,7 +78,7 @@ func CheckIPBan(ip string, boardID int) (*IPBan, error) {
(expires_at > CURRENT_TIMESTAMP OR permanent)
ORDER BY id DESC LIMIT 1`
var ban IPBan
err := QueryRowSQL(query, []any{ip, ip, boardID}, []any{
err := QueryRow(nil, query, []any{ip, ip, boardID}, []any{
&ban.ID, &ban.StaffID, &ban.BoardID, &ban.BannedForPostID, &ban.CopyPostText,
&ban.IsThreadBan, &ban.IsActive, &ban.RangeStart, &ban.RangeEnd, &ban.IssuedAt,
&ban.AppealAt, &ban.ExpiresAt, &ban.Permanent, &ban.StaffNote, &ban.Message,
@ -90,7 +94,7 @@ func CheckIPBan(ip string, boardID int) (*IPBan, error) {
func GetIPBanByID(id int) (*IPBan, error) {
const query = ipBanQueryBase + " WHERE id = ?"
var ban IPBan
err := QueryRowSQL(query, []any{id}, []any{
err := QueryRow(nil, query, []any{id}, []any{
&ban.ID, &ban.StaffID, &ban.BoardID, &ban.BannedForPostID, &ban.CopyPostText,
&ban.IsThreadBan, &ban.IsActive, &ban.RangeStart, &ban.RangeEnd, &ban.IssuedAt,
&ban.AppealAt, &ban.ExpiresAt, &ban.Permanent, &ban.StaffNote, &ban.Message,
@ -110,14 +114,15 @@ func GetIPBans(boardID int, limit int, onlyActive bool) ([]IPBan, error) {
var rows *sql.Rows
var err error
if boardID > 0 {
rows, err = QuerySQL(query, boardID)
rows, err = Query(nil, query, boardID)
} else {
rows, err = QuerySQL(query)
rows, err = Query(nil, query)
}
if err != nil {
return nil, err
}
var bans []IPBan
defer rows.Close()
for rows.Next() {
var ban IPBan
if err = rows.Scan(
@ -125,7 +130,6 @@ func GetIPBans(boardID int, limit int, onlyActive bool) ([]IPBan, error) {
&ban.IsActive, &ban.RangeStart, &ban.RangeEnd, &ban.IssuedAt, &ban.AppealAt, &ban.ExpiresAt,
&ban.Permanent, &ban.StaffNote, &ban.Message, &ban.CanAppeal,
); err != nil {
rows.Close()
return nil, err
}
if onlyActive && !ban.IsActive {
@ -138,7 +142,7 @@ func GetIPBans(boardID int, limit int, onlyActive bool) ([]IPBan, error) {
func (ipb *IPBan) Appeal(msg string) error {
const query = `INSERT INTO DBPREFIXip_ban_appeals (ip_ban_id, appeal_text, is_denied) VALUES(?, ?, FALSE)`
_, err := ExecSQL(query, ipb.ID, msg)
_, err := Exec(nil, query, ipb.ID, msg)
return err
}
@ -163,10 +167,10 @@ func (ipb *IPBan) Deactivate(_ int) error {
return err
}
defer tx.Rollback()
if _, err = ExecTxSQL(tx, deactivateQuery, ipb.ID); err != nil {
if _, err = Exec(&RequestOptions{Tx: tx}, deactivateQuery, ipb.ID); err != nil {
return err
}
if _, err = ExecTxSQL(tx, auditInsertQuery, ipb.ID); err != nil {
if _, err = Exec(&RequestOptions{Tx: tx}, auditInsertQuery, ipb.ID); err != nil {
return err
}
return tx.Commit()

View file

@ -459,7 +459,7 @@ func (board *Board) ModifyInDB() error {
require_file = ?,
enable_catalog = ?
WHERE id = ?`
_, err := ExecSQL(query,
_, err := Exec(nil, query,
board.SectionID, board.NavbarPosition, board.Title, board.Subtitle, board.Description,
board.MaxFilesize, board.MaxThreads, board.DefaultStyle, board.Locked, board.AnonymousName,
board.ForceAnonymous, board.AutosageAfter, board.NoImagesAfter, board.MaxMessageLength,

View file

@ -118,21 +118,16 @@ func (db *GCDB) PrepareContextSQL(ctx context.Context, query string, tx *sql.Tx)
return stmt, sqlVersionError(err, db.driver, &prepared)
}
/*
ExecSQL executes the given SQL statement with the given parameters
Example:
var intVal int
var stringVal string
result, err := db.ExecSQL("INSERT INTO tablename (intval,stringval) VALUES(?,?)", intVal, stringVal)
*/
func (db *GCDB) ExecSQL(query string, values ...any) (sql.Result, error) {
stmt, err := db.PrepareSQL(query, nil)
// Exec executes the given SQL statement with the given parameters, optionally with the given RequestOptions struct
// or a background context and transaction if nil
func (db *GCDB) Exec(opts *RequestOptions, query string, values ...any) (sql.Result, error) {
opts = setupOptions(opts)
stmt, err := db.PrepareContextSQL(opts.Context, query, opts.Tx)
if err != nil {
return nil, err
}
defer stmt.Close()
result, err := stmt.Exec(values...)
result, err := stmt.ExecContext(opts.Context, values...)
if err != nil {
return nil, err
}
@ -140,7 +135,22 @@ func (db *GCDB) ExecSQL(query string, values ...any) (sql.Result, error) {
}
/*
ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil)
ExecSQL executes the given SQL statement with the given parameters.
Deprecated: Use Exec instead
Example:
var intVal int
var stringVal string
result, err := db.ExecSQL("INSERT INTO tablename (intval,stringval) VALUES(?,?)", intVal, stringVal)
*/
func (db *GCDB) ExecSQL(query string, values ...any) (sql.Result, error) {
return db.Exec(nil, query, values...)
}
/*
ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil).
Deprecated: Use Exec instead, with a RequestOptions struct for the context and transaction
Example:
@ -152,21 +162,12 @@ Example:
intVal, stringVal)
*/
func (db *GCDB) ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) {
stmt, err := db.PrepareContextSQL(ctx, sqlStr, tx)
if err != nil {
return nil, err
}
defer stmt.Close()
result, err := stmt.ExecContext(ctx, values...)
if err != nil {
return nil, err
}
return result, stmt.Close()
return db.Exec(&RequestOptions{Context: ctx, Tx: tx}, sqlStr, values...)
}
/*
ExecTxSQL executes the given SQL statemtnt, optionally with the given transaction (if non-nil)
ExecTxSQL executes the given SQL statemtnt, optionally with the given transaction (if non-nil).
Deprecated: Use Exec instead, with a RequestOptions struct for the transaction
Example:
@ -179,16 +180,7 @@ Example:
intVal, stringVal)
*/
func (db *GCDB) ExecTxSQL(tx *sql.Tx, query string, values ...any) (sql.Result, error) {
stmt, err := db.PrepareSQL(query, tx)
if err != nil {
return nil, err
}
defer stmt.Close()
res, err := stmt.Exec(values...)
if err != nil {
return res, err
}
return res, stmt.Close()
return db.Exec(&RequestOptions{Tx: tx}, query, values...)
}
/*
@ -209,9 +201,25 @@ func (db *GCDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, erro
return db.db.BeginTx(ctx, opts)
}
// QueryRow gets a row from the db with the values in values[] and fills the respective pointers in out[],
// with an optional RequestOptions struct for the context and transaction
func (db *GCDB) QueryRow(opts *RequestOptions, query string, values []any, out []any) error {
opts = setupOptions(opts)
stmt, err := db.PrepareContextSQL(opts.Context, query, opts.Tx)
if err != nil {
return err
}
defer stmt.Close()
if err = stmt.QueryRowContext(opts.Context, values...).Scan(out...); err != nil {
return err
}
return stmt.Close()
}
/*
QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]
Automatically escapes the given values and caches the query
QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[].
Deprecated: Use QueryRow instead
Example:
id := 32
@ -222,12 +230,7 @@ Example:
[]any{&intVal, &stringVal})
*/
func (db *GCDB) QueryRowSQL(query string, values, out []any) error {
stmt, err := db.PrepareSQL(query, nil)
if err != nil {
return err
}
defer stmt.Close()
return stmt.QueryRow(values...).Scan(out...)
return db.QueryRow(nil, query, values, out)
}
/*
@ -244,16 +247,7 @@ Example:
[]any{id}, []any{&name})
*/
func (db *GCDB) QueryRowContextSQL(ctx context.Context, tx *sql.Tx, query string, values, out []any) error {
stmt, err := db.PrepareContextSQL(ctx, query, tx)
if err != nil {
return err
}
defer stmt.Close()
if err = stmt.QueryRowContext(ctx, values...).Scan(out...); err != nil {
return err
}
return stmt.Close()
return db.QueryRow(&RequestOptions{Context: ctx, Tx: tx}, query, values, out)
}
/*
@ -273,21 +267,29 @@ Example:
[]any{&intVal, &stringVal})
*/
func (db *GCDB) QueryRowTxSQL(tx *sql.Tx, query string, values, out []any) error {
stmt, err := db.PrepareSQL(query, tx)
return db.QueryRow(&RequestOptions{Tx: tx}, query, values, out)
}
// Query sends the query to the database with the given options (or a background context if nil), and the given parameters
func (db *GCDB) Query(opts *RequestOptions, query string, a ...any) (*sql.Rows, error) {
opts = setupOptions(opts)
stmt, err := db.PrepareContextSQL(opts.Context, query, opts.Tx)
if err != nil {
return err
return nil, err
}
defer stmt.Close()
if err = stmt.QueryRow(values...).Scan(out...); err != nil {
return err
rows, err := stmt.QueryContext(opts.Context, a...)
if err != nil {
return rows, err
}
return stmt.Close()
return rows, stmt.Close()
}
/*
QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[]
Automatically escapes the given values and caches the query
QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[].
Deprecated: Use Query instead
Example:
rows, err := db.QuerySQL("SELECT * FROM table")
@ -301,17 +303,13 @@ Example:
}
*/
func (db *GCDB) QuerySQL(query string, a ...any) (*sql.Rows, error) {
stmt, err := db.PrepareSQL(query, nil)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Query(a...)
return db.Query(nil, query, a...)
}
/*
QueryContextSQL queries the database with a prepared statement and the given parameters, using the given context
for a deadline
for a deadline.
Deprecated: Use Query instead, with a RequestOptions struct for the context and transaction
Example:
@ -320,22 +318,12 @@ Example:
rows, err := db.QueryContextSQL(ctx, nil, "SELECT name from posts where NOT is_deleted")
*/
func (db *GCDB) QueryContextSQL(ctx context.Context, tx *sql.Tx, query string, a ...any) (*sql.Rows, error) {
stmt, err := db.PrepareContextSQL(ctx, query, tx)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx, a...)
if err != nil {
return rows, err
}
return rows, stmt.Close()
return db.Query(&RequestOptions{Context: ctx, Tx: tx}, query, a...)
}
/*
QueryTxSQL gets all rows from the db with the values in values[] and fills the respective pointers in out[]
Automatically escapes the given values and caches the query
QueryTxSQL gets all rows from the db with the values in values[] and fills the respective pointers in out[].
Deprecated: Use Query instead, with a RequestOptions struct for the transaction
Example:
tx, _ := db.Begin()
@ -350,16 +338,7 @@ Example:
}
*/
func (db *GCDB) QueryTxSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, error) {
stmt, err := db.PrepareSQL(query, tx)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(a...)
if err != nil {
return nil, err
}
return rows, stmt.Close()
return db.Query(&RequestOptions{Tx: tx}, query, a...)
}
func setupDBConn(cfg *config.SQLConfig) (db *GCDB, err error) {
@ -404,6 +383,7 @@ func setupSqlTestConfig(dbDriver string, dbName string, dbPrefix string) *config
}
}
// SetupMockDB sets up a mock database connection for testing
func SetupMockDB(driver string) (sqlmock.Sqlmock, error) {
var err error
gcdb, err = setupDBConn(setupSqlTestConfig(driver, "gochan", ""))

View file

@ -39,23 +39,24 @@ func GetPostFromID(id int, onlyNotDeleted bool) (*Post, error) {
query += " AND is_deleted = FALSE"
}
post := new(Post)
err := QueryRowSQL(query, []any{id}, []any{
err := QueryRow(nil, query, []any{id}, []any{
&post.ID, &post.ThreadID, &post.IsTopPost, &post.IP, &post.CreatedOn, &post.Name,
&post.Tripcode, &post.IsRoleSignature, &post.Email, &post.Subject, &post.Message,
&post.MessageRaw, &post.Password, &post.DeletedAt, &post.IsDeleted,
&post.BannedMessage, &post.Flag, &post.Country,
})
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrPostDoesNotExist
} else if err != nil {
return nil, err
}
return post, err
return post, nil
}
func GetPostIP(postID int) (string, error) {
sql := "SELECT IP_NTOA FROM DBPREFIXposts WHERE id = ?"
var ip string
err := QueryRowSQL(sql, []any{postID}, []any{&ip})
err := QueryRow(nil, sql, []any{postID}, []any{&ip})
return ip, err
}
@ -68,7 +69,7 @@ func GetPostsFromIP(ip string, limit int, onlyNotDeleted bool) ([]Post, error) {
}
sql += " ORDER BY id DESC LIMIT ?"
rows, err := QuerySQL(sql, ip, limit)
rows, err := Query(nil, sql, ip, limit)
if err != nil {
return nil, err
}
@ -109,7 +110,7 @@ func GetTopPostIDsInThreadIDs(threads ...any) (map[any]int, error) {
}
params := createArrayPlaceholder(threads)
query := `SELECT id FROM DBPREFIXposts WHERE thread_id in ` + params + " AND is_top_post"
rows, err := QuerySQL(query, threads...)
rows, err := Query(nil, query, threads...)
if err != nil {
return nil, err
}
@ -130,7 +131,7 @@ func GetTopPostIDsInThreadIDs(threads ...any) (map[any]int, error) {
func GetThreadTopPost(threadID int) (*Post, error) {
const query = selectPostsBaseSQL + "WHERE thread_id = ? AND is_top_post = TRUE LIMIT 1"
post := new(Post)
err := QueryRowSQL(query, []any{threadID}, []any{
err := QueryRow(nil, query, []any{threadID}, []any{
&post.ID, &post.ThreadID, &post.IsTopPost, &post.IP, &post.CreatedOn, &post.Name,
&post.Tripcode, &post.IsRoleSignature, &post.Email, &post.Subject, &post.Message,
&post.MessageRaw, &post.Password, &post.DeletedAt, &post.IsDeleted,
@ -180,19 +181,23 @@ func GetBoardTopPosts[B intOrStringConstraint](board B) ([]*Post, error) {
func GetPostPassword(id int) (string, error) {
const query = `SELECT password FROM DBPREFIXposts WHERE id = ?`
var passwordChecksum string
err := QueryRowSQL(query, []any{id}, []any{&passwordChecksum})
err := QueryRow(nil, query, []any{id}, []any{&passwordChecksum})
return passwordChecksum, err
}
// PermanentlyRemoveDeletedPosts removes all posts and files marked as deleted from the database
func PermanentlyRemoveDeletedPosts() error {
func PermanentlyRemoveDeletedPosts(opts ...*RequestOptions) error {
const sql1 = `DELETE FROM DBPREFIXposts WHERE is_deleted`
const sql2 = `DELETE FROM DBPREFIXthreads WHERE is_deleted`
_, err := ExecSQL(sql1)
var useOpts *RequestOptions
if len(opts) > 0 {
useOpts = opts[0]
}
_, err := Exec(useOpts, sql1)
if err != nil {
return err
}
_, err = ExecSQL(sql2)
_, err = Exec(useOpts, sql2)
return err
}
@ -201,7 +206,7 @@ func PermanentlyRemoveDeletedPosts() error {
func SinceLastPost(postIP string) (int, error) {
const query = `SELECT COALESCE(MAX(created_on), '1970-01-01 00:00:00') FROM DBPREFIXposts WHERE ip = ?`
var whenStr string
err := QueryRowSQL(query, []any{postIP}, []any{&whenStr})
err := QueryRow(nil, query, []any{postIP}, []any{&whenStr})
if err != nil {
return -1, err
}
@ -219,7 +224,7 @@ func SinceLastThread(postIP string) (int, error) {
const query = `SELECT COALESCE(MAX(created_on), '1970-01-01 00:00:00') FROM DBPREFIXposts WHERE ip = ? AND is_top_post`
var whenStr string
err := QueryRowSQL(query, []any{postIP}, []any{&whenStr})
err := QueryRow(nil, query, []any{postIP}, []any{&whenStr})
if err != nil {
return -1, err
}
@ -233,7 +238,7 @@ func SinceLastThread(postIP string) (int, error) {
// UpdateContents updates the email, subject, and message text of the post
func (p *Post) UpdateContents(email string, subject string, message template.HTML, messageRaw string) error {
const sqlUpdate = `UPDATE DBPREFIXposts SET email = ?, subject = ?, message = ?, message_raw = ? WHERE ID = ?`
_, err := ExecSQL(sqlUpdate, email, subject, message, messageRaw, p.ID)
_, err := Exec(nil, sqlUpdate, email, subject, message, messageRaw, p.ID)
if err != nil {
return err
}
@ -244,27 +249,27 @@ func (p *Post) UpdateContents(email string, subject string, message template.HTM
return nil
}
func (p *Post) GetBoardID() (int, error) {
func (p *Post) GetBoardID(opts ...*RequestOptions) (int, error) {
const query = `SELECT board_id FROM DBPREFIXthreads where id = ?`
var boardID int
err := QueryRowSQL(query, []any{p.ThreadID}, []any{&boardID})
err := QueryRow(setupOptions(opts...), query, []any{p.ThreadID}, []any{&boardID})
if errors.Is(err, sql.ErrNoRows) {
err = ErrBoardDoesNotExist
}
return boardID, err
}
func (p *Post) GetBoardDir() (string, error) {
func (p *Post) GetBoardDir(opts ...*RequestOptions) (string, error) {
const query = "SELECT dir FROM DBPREFIXboards" + boardFromPostIdSuffixSQL
var dir string
err := QueryRowSQL(query, []any{p.ID}, []any{&dir})
err := QueryRow(setupOptions(opts...), query, []any{p.ID}, []any{&dir})
return dir, err
}
func (p *Post) GetBoard() (*Board, error) {
func (p *Post) GetBoard(opts ...*RequestOptions) (*Board, error) {
const query = selectBoardsBaseSQL + boardFromPostIdSuffixSQL
board := new(Board)
err := QueryRowSQL(query, []any{p.ID}, []any{
err := QueryRow(setupOptions(opts...), query, []any{p.ID}, []any{
&board.ID, &board.SectionID, &board.URI, &board.Dir, &board.NavbarPosition, &board.Title, &board.Subtitle,
&board.Description, &board.MaxFilesize, &board.MaxThreads, &board.DefaultStyle, &board.Locked,
&board.CreatedAt, &board.AnonymousName, &board.ForceAnonymous, &board.AutosageAfter, &board.NoImagesAfter,
@ -284,13 +289,13 @@ func (p *Post) ChangeBoardID(newBoardID int) error {
}
// TopPostID returns the OP post ID of the thread that p is in
func (p *Post) TopPostID() (int, error) {
func (p *Post) TopPostID(opts ...*RequestOptions) (int, error) {
if p.IsTopPost {
return p.ID, nil
}
const query = `SELECT id FROM DBPREFIXposts WHERE thread_id = ? and is_top_post = TRUE ORDER BY id ASC LIMIT 1`
var topPostID int
err := QueryRowSQL(query, []any{p.ThreadID}, []any{&topPostID})
err := QueryRow(setupOptions(opts...), query, []any{p.ThreadID}, []any{&topPostID})
return topPostID, err
}
@ -306,13 +311,13 @@ func (p *Post) GetTopPost() (*Post, error) {
// GetPostUpload returns the upload info associated with the file as well as any errors encountered.
// If the file has no uploads, then *Upload is nil. If the file was removed from the post, then Filename
// and OriginalFilename = "deleted"
func (p *Post) GetUpload() (*Upload, error) {
func (p *Post) GetUpload(opts ...*RequestOptions) (*Upload, error) {
const query = `SELECT
id, post_id, file_order, original_filename, filename, checksum,
file_size, is_spoilered, thumbnail_width, thumbnail_height, width, height
FROM DBPREFIXfiles WHERE post_id = ?`
upload := new(Upload)
err := QueryRowSQL(query, []any{p.ID}, []any{
err := QueryRow(setupOptions(opts...), query, []any{p.ID}, []any{
&upload.ID, &upload.PostID, &upload.FileOrder, &upload.OriginalFilename, &upload.Filename, &upload.Checksum,
&upload.FileSize, &upload.IsSpoilered, &upload.ThumbnailWidth, &upload.ThumbnailHeight, &upload.Width, &upload.Height,
})
@ -325,7 +330,7 @@ func (p *Post) GetUpload() (*Upload, error) {
// UnlinkUploads disassociates the post with any uploads in DBPREFIXfiles
// that may have been uploaded with it, optionally leaving behind a "File Deleted"
// frame where the thumbnail appeared
func (p *Post) UnlinkUploads(leaveDeletedBox bool) error {
func (p *Post) UnlinkUploads(leaveDeletedBox bool, requestOpts ...*RequestOptions) error {
var sqlStr string
if leaveDeletedBox {
// leave a "File Deleted" box
@ -333,7 +338,7 @@ func (p *Post) UnlinkUploads(leaveDeletedBox bool) error {
} else {
sqlStr = `DELETE FROM DBPREFIXfiles WHERE post_id = ?`
}
_, err := ExecSQL(sqlStr, p.ID)
_, err := Exec(setupOptions(requestOpts...), sqlStr, p.ID)
return err
}
@ -348,17 +353,24 @@ func (p *Post) InCyclicThread() (bool, error) {
}
// Delete sets the post as deleted and sets the deleted_at timestamp to the current time
func (p *Post) Delete() error {
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
tx, err := BeginContextTx(ctx)
if err != nil {
return err
func (p *Post) Delete(requestOptions ...*RequestOptions) error {
shouldCommit := len(requestOptions) == 0
opts := setupOptions(requestOptions...)
if opts.Context == context.Background() {
opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer opts.Cancel()
}
var err error
if opts.Tx == nil {
opts.Tx, err = BeginContextTx(opts.Context)
if err != nil {
return err
}
defer opts.Tx.Rollback()
}
defer tx.Rollback()
var rowCount int
err = QueryRowContextSQL(ctx, tx, "SELECT COUNT(*) FROM DBPREFIXposts WHERE id = ?", []any{p.ID}, []any{&rowCount})
err = QueryRow(opts, "SELECT COUNT(*) FROM DBPREFIXposts WHERE id = ?", []any{p.ID}, []any{&rowCount})
if errors.Is(err, sql.ErrNoRows) {
err = ErrPostDoesNotExist
}
@ -367,16 +379,33 @@ func (p *Post) Delete() error {
}
if p.IsTopPost {
return deleteThread(ctx, tx, p.ThreadID)
return deleteThread(opts, p.ThreadID)
}
if _, err = ExecContextSQL(ctx, tx, "UPDATE DBPREFIXposts SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = ?", p.ID); err != nil {
if _, err = Exec(opts, "UPDATE DBPREFIXposts SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = ?", p.ID); err != nil {
return err
}
return tx.Commit()
if shouldCommit {
return opts.Tx.Commit()
}
return nil
}
// InsertWithContext inserts the post into the database with the given context and transaction
func (p *Post) InsertWithContext(ctx context.Context, tx *sql.Tx, bumpThread bool, boardID int, locked bool, stickied bool, anchored bool, cyclical bool) error {
// Insert inserts the post into the database with the optional given options
func (p *Post) Insert(bumpThread bool, boardID int, locked bool, stickied bool, anchored bool, cyclical bool, requestOptions ...*RequestOptions) error {
opts := setupOptions(requestOptions...)
if len(requestOptions) == 0 {
opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer opts.Cancel()
}
var err error
if opts.Tx == nil {
opts.Tx, err = BeginContextTx(opts.Context)
if err != nil {
return err
}
defer opts.Tx.Rollback()
}
if p.ID > 0 {
// already inserted
return ErrorPostAlreadySent
@ -387,19 +416,18 @@ func (p *Post) InsertWithContext(ctx context.Context, tx *sql.Tx, bumpThread boo
VALUES(?,?,PARAM_ATON,CURRENT_TIMESTAMP,?,?,?,?,?,?,?,?,?,?)`
bumpSQL := `UPDATE DBPREFIXthreads SET last_bump = CURRENT_TIMESTAMP WHERE id = ?`
var err error
if p.ThreadID == 0 {
// thread doesn't exist yet, this is a new post
p.IsTopPost = true
var threadID int
threadID, err = CreateThread(tx, boardID, locked, stickied, anchored, cyclical)
threadID, err = CreateThread(opts, boardID, locked, stickied, anchored, cyclical)
if err != nil {
return err
}
p.ThreadID = threadID
} else {
var threadIsLocked bool
if err = QueryRowTxSQL(tx, "SELECT locked FROM DBPREFIXthreads WHERE id = ?",
if err = QueryRow(opts, "SELECT locked FROM DBPREFIXthreads WHERE id = ?",
[]any{p.ThreadID}, []any{&threadIsLocked}); err != nil {
return err
}
@ -408,40 +436,26 @@ func (p *Post) InsertWithContext(ctx context.Context, tx *sql.Tx, bumpThread boo
}
}
if _, err = ExecContextSQL(ctx, tx, insertSQL,
if _, err = Exec(opts, insertSQL,
p.ThreadID, p.IsTopPost, p.IP, p.Name, p.Tripcode, p.IsRoleSignature, p.Email, p.Subject,
p.Message, p.MessageRaw, p.Password, p.Flag, p.Country,
); err != nil {
return err
}
if p.ID, err = getLatestID("DBPREFIXposts", tx); err != nil {
if p.ID, err = getLatestID(opts, "DBPREFIXposts"); err != nil {
return err
}
if bumpThread {
if _, err = ExecContextSQL(ctx, tx, bumpSQL, p.ThreadID); err != nil {
if _, err = Exec(opts, bumpSQL, p.ThreadID); err != nil {
return err
}
}
if len(requestOptions) == 0 {
return opts.Tx.Commit()
}
return nil
}
func (p *Post) Insert(bumpThread bool, boardID int, locked bool, stickied bool, anchored bool, cyclical bool) error {
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
tx, err := BeginContextTx(ctx)
if err != nil {
return err
}
defer tx.Rollback()
if err = p.InsertWithContext(ctx, tx, bumpThread, boardID, locked, stickied, anchored, cyclical); err != nil {
return err
}
return tx.Commit()
}
// CyclicThreadPost represents a post that should be deleted in a cyclic thread
type CyclicThreadPost struct {
PostID int // sql: post_id
@ -515,7 +529,7 @@ func (p *Post) WebPath() string {
webRoot := config.GetSystemCriticalConfig().WebRoot
const query = "SELECT op_id, dir FROM DBPREFIXv_top_post_board_dir WHERE id = ?"
err := QueryRowSQL(query, []any{p.ID}, []any{&p.opID, &p.boardDir})
err := QueryRow(nil, query, []any{p.ID}, []any{&p.opID, &p.boardDir})
if err != nil {
return webRoot
}

View file

@ -39,7 +39,7 @@ func PreloadModule(l *lua.LState) int {
})
}
rows, err := QuerySQL(queryStr, queryArgs...)
rows, err := Query(nil, queryStr, queryArgs...)
l.Push(luar.New(l, rows))
l.Push(luar.New(l, err))
@ -57,7 +57,7 @@ func PreloadModule(l *lua.LState) int {
execArgs = append(execArgs, arg)
})
}
result, err := ExecSQL(execStr)
result, err := Exec(nil, execStr)
l.Push(luar.New(l, result))
l.Push(luar.New(l, err))

View file

@ -96,7 +96,7 @@ func GetSectionFromName(name string) (*Section, error) {
// DeleteSection deletes a section from the database and resets the AllSections array
func DeleteSection(id int) error {
const query = `DELETE FROM DBPREFIXsections WHERE id = ?`
_, err := ExecSQL(query, id)
_, err := Exec(nil, query, id)
if err != nil {
return err
}
@ -105,38 +105,47 @@ func DeleteSection(id int) error {
// NewSection creates a new board section in the database and returns a *Section struct pointer.
// If position < 0, it will use the ID
func NewSection(name string, abbreviation string, hidden bool, position int) (*Section, error) {
func NewSection(name string, abbreviation string, hidden bool, position int, requestOpts ...*RequestOptions) (*Section, error) {
const sqlINSERT = `INSERT INTO DBPREFIXsections (name, abbreviation, hidden, position) VALUES (?,?,?,?)`
const sqlPosition = `SELECT COALESCE(MAX(position) + 1, 1) FROM DBPREFIXsections`
tx, err := BeginTx()
if err != nil {
return nil, err
var opts *RequestOptions
var err error
shouldCommit := len(requestOpts) == 0
if shouldCommit {
opts = &RequestOptions{}
opts.Tx, err = BeginTx()
if err != nil {
return nil, err
}
opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer func() {
opts.Cancel()
opts.Tx.Rollback()
}()
} else {
opts = requestOpts[0]
}
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer func() {
cancel()
tx.Rollback()
}()
if position < 0 {
// position not specified
err = QueryRowContextSQL(ctx, tx, sqlPosition, nil, []any{&position})
err = QueryRow(opts, sqlPosition, nil, []any{&position})
if errors.Is(err, sql.ErrNoRows) {
position = 1
} else if err != nil {
return nil, err
}
}
if _, err = ExecContextSQL(ctx, tx, sqlINSERT, name, abbreviation, hidden, position); err != nil {
if _, err = Exec(opts, sqlINSERT, name, abbreviation, hidden, position); err != nil {
return nil, err
}
id, err := getLatestID("DBPREFIXsections", tx)
id, err := getLatestID(opts, "DBPREFIXsections")
if err != nil {
return nil, err
}
if err = tx.Commit(); err != nil {
return nil, err
if shouldCommit {
if err = opts.Tx.Commit(); err != nil {
return nil, err
}
}
return &Section{
ID: id,
@ -147,9 +156,15 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S
}, nil
}
func (s *Section) UpdateValues() error {
func (s *Section) UpdateValues(requestOpts ...*RequestOptions) error {
opts := setupOptions(requestOpts...)
if opts.Context == context.Background() {
opts.Context, opts.Cancel = context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer opts.Cancel()
}
var count int
err := QueryRowTimeoutSQL(nil, `SELECT COUNT(*) FROM DBPREFIXsections WHERE id = ?`, []any{s.ID}, []any{&count})
err := QueryRow(opts, `SELECT COUNT(*) FROM DBPREFIXsections WHERE id = ?`, []any{s.ID}, []any{&count})
if errors.Is(err, sql.ErrNoRows) {
return ErrSectionDoesNotExist
}
@ -157,6 +172,6 @@ func (s *Section) UpdateValues() error {
return err
}
const query = `UPDATE DBPREFIXsections set name = ?, abbreviation = ?, position = ?, hidden = ? WHERE id = ?`
_, err = ExecTimeoutSQL(nil, query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
_, err = Exec(opts, query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
return err
}

View file

@ -171,7 +171,7 @@ func setupAndProvisionMockDB(t *testing.T, mock sqlmock.Sqlmock, dbType string,
ExpectQuery().WithArgs("test").
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
_, err := ExecSQL("CREATE DATABASE gochan")
_, err := Exec(nil, "CREATE DATABASE gochan")
if err != nil {
return err
}

View file

@ -1,7 +1,6 @@
package gcsql
import (
"context"
"database/sql"
"errors"
"fmt"
@ -21,27 +20,27 @@ var (
)
// CreateThread creates a new thread in the database with the given board ID and statuses
func CreateThread(tx *sql.Tx, boardID int, locked bool, stickied bool, anchored bool, cyclic bool) (threadID int, err error) {
func CreateThread(requestOptions *RequestOptions, boardID int, locked bool, stickied bool, anchored bool, cyclic bool) (threadID int, err error) {
const lockedQuery = `SELECT locked FROM DBPREFIXboards WHERE id = ?`
const insertQuery = `INSERT INTO DBPREFIXthreads (board_id, locked, stickied, anchored, cyclical) VALUES (?,?,?,?,?)`
var boardIsLocked bool
if err = QueryRowTxSQL(tx, lockedQuery, []any{boardID}, []any{&boardIsLocked}); err != nil {
if err = QueryRow(requestOptions, lockedQuery, []any{boardID}, []any{&boardIsLocked}); err != nil {
return 0, err
}
if boardIsLocked {
return 0, ErrBoardIsLocked
}
if _, err = ExecTxSQL(tx, insertQuery, boardID, locked, stickied, anchored, cyclic); err != nil {
if _, err = Exec(requestOptions, insertQuery, boardID, locked, stickied, anchored, cyclic); err != nil {
return 0, err
}
return threadID, QueryRowTxSQL(tx, "SELECT MAX(id) FROM DBPREFIXthreads", nil, []any{&threadID})
return threadID, QueryRow(requestOptions, "SELECT MAX(id) FROM DBPREFIXthreads", nil, []any{&threadID})
}
// GetThread returns a a thread object from the database, given its ID
func GetThread(threadID int) (*Thread, error) {
const query = selectThreadsBaseSQL + `WHERE id = ?`
thread := new(Thread)
err := QueryRowSQL(query, []any{threadID}, []any{
err := QueryRow(nil, query, []any{threadID}, []any{
&thread.ID, &thread.BoardID, &thread.Locked, &thread.Stickied, &thread.Anchored, &thread.Cyclic,
&thread.LastBump, &thread.DeletedAt, &thread.IsDeleted,
})
@ -52,7 +51,7 @@ func GetThread(threadID int) (*Thread, error) {
func GetPostThread(opID int) (*Thread, error) {
const query = selectThreadsBaseSQL + `WHERE id = (SELECT thread_id FROM DBPREFIXposts WHERE id = ? LIMIT 1)`
thread := new(Thread)
err := QueryRowSQL(query, []any{opID}, []any{
err := QueryRow(nil, query, []any{opID}, []any{
&thread.ID, &thread.BoardID, &thread.Locked, &thread.Stickied, &thread.Anchored, &thread.Cyclic,
&thread.LastBump, &thread.DeletedAt, &thread.IsDeleted,
})
@ -66,7 +65,7 @@ func GetPostThread(opID int) (*Thread, error) {
func GetTopPostThreadID(opID int) (int, error) {
const query = `SELECT thread_id FROM DBPREFIXposts WHERE id = ? and is_top_post`
var threadID int
err := QueryRowSQL(query, []any{opID}, []any{&threadID})
err := QueryRow(nil, query, []any{opID}, []any{&threadID})
if err == sql.ErrNoRows {
err = ErrThreadDoesNotExist
}
@ -81,7 +80,7 @@ func GetThreadsWithBoardID(boardID int, onlyNotDeleted bool) ([]Thread, error) {
if onlyNotDeleted {
query += " AND is_deleted = FALSE"
}
rows, err := QuerySQL(query, boardID)
rows, err := Query(nil, query, boardID)
if err != nil {
return nil, err
}
@ -104,7 +103,7 @@ func GetThreadReplyCountFromOP(opID int) (int, error) {
const query = `SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = (
SELECT thread_id FROM DBPREFIXposts WHERE id = ?) AND is_deleted = FALSE AND is_top_post = FALSE`
var num int
err := QueryRowSQL(query, []any{opID}, []any{&num})
err := QueryRow(nil, query, []any{opID}, []any{&num})
return num, err
}
@ -113,7 +112,7 @@ func ChangeThreadBoardID(threadID int, newBoardID int) error {
if !DoesBoardExistByID(newBoardID) {
return ErrBoardDoesNotExist
}
_, err := ExecSQL(`UPDATE DBPREFIXthreads SET board_id = ? WHERE id = ?`, newBoardID, threadID)
_, err := Exec(nil, "UPDATE DBPREFIXthreads SET board_id = ? WHERE id = ?", newBoardID, threadID)
return err
}
@ -135,15 +134,15 @@ func (t *Thread) GetReplyFileCount() (int, error) {
const query = `SELECT COUNT(filename) FROM DBPREFIXfiles WHERE post_id IN (
SELECT id FROM DBPREFIXposts WHERE thread_id = ? AND is_deleted = FALSE)`
var fileCount int
err := QueryRowSQL(query, []any{t.ID}, []any{&fileCount})
err := QueryRow(nil, query, []any{t.ID}, []any{&fileCount})
return fileCount, err
}
// GetReplyCount returns the number of posts in the thread, not including the top post or any deleted posts
func (t *Thread) GetReplyCount() (int, error) {
const query = `SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = ? AND is_top_post = FALSE AND is_deleted = FALSE`
const query = "SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = ? AND is_top_post = FALSE AND is_deleted = FALSE"
var numReplies int
err := QueryRowSQL(query, []any{t.ID}, []any{&numReplies})
err := QueryRow(nil, query, []any{t.ID}, []any{&numReplies})
return numReplies, err
}
@ -161,7 +160,7 @@ func (t *Thread) GetPosts(repliesOnly bool, boardPage bool, limit int) ([]Post,
query += " LIMIT " + strconv.Itoa(limit)
}
rows, err := QuerySQL(query, t.ID)
rows, err := Query(nil, query, t.ID)
if err != nil {
return nil, err
}
@ -185,7 +184,7 @@ func (t *Thread) GetPosts(repliesOnly bool, boardPage bool, limit int) ([]Post,
func (t *Thread) GetUploads() ([]Upload, error) {
const query = selectFilesBaseSQL + ` WHERE post_id IN (
SELECT id FROM DBPREFIXposts WHERE thread_id = ? and is_deleted = FALSE) AND filename != 'deleted'`
rows, err := QuerySQL(query, t.ID)
rows, err := Query(nil, query, t.ID)
if err != nil {
return nil, err
}
@ -223,18 +222,18 @@ func (t *Thread) UpdateAttribute(attribute string, value bool) error {
return fmt.Errorf("invalid thread attribute %q", attribute)
}
updateSQL += attribute + " = ? WHERE id = ?"
_, err := ExecSQL(updateSQL, value, t.ID)
_, err := Exec(nil, updateSQL, value, t.ID)
return err
}
// deleteThread updates the thread and sets it as deleted, as well as the posts where thread_id = threadID
func deleteThread(ctx context.Context, tx *sql.Tx, threadID int) error {
func deleteThread(opts *RequestOptions, threadID int) error {
const checkPostExistsSQL = `SELECT COUNT(*) FROM DBPREFIXposts WHERE thread_id = ?`
const deletePostsSQL = `UPDATE DBPREFIXposts SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE thread_id = ?`
const deleteThreadSQL = `UPDATE DBPREFIXthreads SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = ?`
var rowCount int
err := QueryRowContextSQL(ctx, tx, checkPostExistsSQL, []any{threadID}, []any{&rowCount})
err := QueryRow(opts, checkPostExistsSQL, []any{threadID}, []any{&rowCount})
if err != nil {
return err
}
@ -242,10 +241,10 @@ func deleteThread(ctx context.Context, tx *sql.Tx, threadID int) error {
return ErrThreadDoesNotExist
}
_, err = ExecContextSQL(ctx, tx, deletePostsSQL, threadID)
_, err = Exec(opts, deletePostsSQL, threadID)
if err != nil {
return err
}
_, err = ExecContextSQL(ctx, tx, deleteThreadSQL, threadID)
_, err = Exec(opts, deleteThreadSQL, threadID)
return err
}

View file

@ -1,7 +1,6 @@
package gcsql
import (
"context"
"database/sql"
"errors"
"fmt"
@ -25,7 +24,7 @@ func GetThreadFiles(post *Post) ([]Upload, error) {
query := selectFilesBaseSQL + `WHERE post_id IN (
SELECT id FROM DBPREFIXposts WHERE thread_id = (
SELECT thread_id FROM DBPREFIXposts WHERE id = ?)) AND filename != 'deleted'`
rows, err := QuerySQL(query, post.ID)
rows, err := Query(nil, query, post.ID)
if err != nil {
return nil, err
}
@ -45,17 +44,28 @@ func GetThreadFiles(post *Post) ([]Upload, error) {
}
// NextFileOrder gets what would be the next file_order value (not particularly useful until multi-file posting is implemented)
func (p *Post) NextFileOrder(ctx context.Context, tx *sql.Tx) (int, error) {
func (p *Post) NextFileOrder(requestOpts ...*RequestOptions) (int, error) {
opts := setupOptions(requestOpts...)
const query = `SELECT COALESCE(MAX(file_order) + 1, 0) FROM DBPREFIXfiles WHERE post_id = ?`
var next int
err := QueryRowContextSQL(ctx, tx, query, []any{p.ID}, []any{&next})
err := QueryRow(opts, query, []any{p.ID}, []any{&next})
return next, err
}
func (p *Post) AttachFileTx(tx *sql.Tx, upload *Upload) error {
func (p *Post) AttachFile(upload *Upload, requestOpts ...*RequestOptions) error {
if upload == nil {
return nil // no upload to attach, so no error
}
opts := setupOptions(requestOpts...)
shouldCommit := opts.Tx == nil
var err error
if shouldCommit {
opts.Tx, err = BeginTx()
if err != nil {
return err
}
defer opts.Tx.Rollback()
}
_, err, recovered := events.TriggerEvent("incoming-upload", upload)
if recovered {
@ -72,35 +82,30 @@ func (p *Post) AttachFileTx(tx *sql.Tx, upload *Upload) error {
if upload.ID > 0 {
return ErrAlreadyAttached
}
if upload.FileOrder < 1 {
upload.FileOrder, err = p.NextFileOrder(context.Background(), tx)
upload.FileOrder, err = p.NextFileOrder(opts)
if err != nil {
return err
}
}
upload.PostID = p.ID
if _, err = ExecTxSQL(tx, insertSQL,
if _, err = Exec(opts, insertSQL,
&upload.PostID, &upload.FileOrder, &upload.OriginalFilename, &upload.Filename, &upload.Checksum, &upload.FileSize,
&upload.IsSpoilered, &upload.ThumbnailWidth, &upload.ThumbnailHeight, &upload.Width, &upload.Height,
); err != nil {
return err
}
upload.ID, err = getLatestID("DBPREFIXfiles", tx)
return err
}
func (p *Post) AttachFile(upload *Upload) error {
tx, err := BeginTx()
upload.ID, err = getLatestID(opts, "DBPREFIXfiles")
if err != nil {
return err
}
defer tx.Rollback()
if err = p.AttachFileTx(tx, upload); err != nil {
return err
if shouldCommit {
if err = opts.Tx.Commit(); err != nil {
return err
}
}
return tx.Commit()
return nil
}
// GetUploadFilenameAndBoard returns the filename (or an empty string) and
@ -112,7 +117,7 @@ func GetUploadFilenameAndBoard(postID int) (string, string, error) {
JOIN DBPREFIXboards ON DBPREFIXboards.id = board_id
WHERE DBPREFIXposts.id = ?`
var filename, dir string
err := QueryRowSQL(query, []any{postID}, []any{&filename, &dir})
err := QueryRow(nil, query, []any{postID}, []any{&filename, &dir})
if errors.Is(err, sql.ErrNoRows) {
return "", "", nil
} else if err != nil {

View file

@ -51,6 +51,39 @@ type intOrStringConstraint interface {
int | string
}
// RequestOptions is used to pass an optional context, transaction, and any other things to the various SQL functions
// in a future-proof way
type RequestOptions struct {
Context context.Context
Tx *sql.Tx
Cancel context.CancelFunc
}
func setupOptions(opts ...*RequestOptions) *RequestOptions {
if len(opts) == 0 || opts[0] == nil {
return &RequestOptions{Context: context.Background()}
}
if opts[0].Context == nil {
opts[0].Context = context.Background()
}
return opts[0]
}
// Query is a wrapper for QueryContextSQL that uses the given options, or defaults to a background context if nil
func Query(opts *RequestOptions, query string, a ...any) (*sql.Rows, error) {
return gcdb.Query(opts, query, a...)
}
// QueryRow is a wrapper for QueryRowContextSQL that uses the given options, or defaults to a background context if nil
func QueryRow(opts *RequestOptions, query string, values, out []any) error {
return gcdb.QueryRow(opts, query, values, out)
}
// Exec is a wrapper for ExecContextSQL that uses the given options, or defaults to a background context if nil
func Exec(opts *RequestOptions, query string, values ...any) (sql.Result, error) {
return gcdb.Exec(opts, query, values...)
}
// BeginTx begins a new transaction for the gochan database. It uses a background context
func BeginTx() (*sql.Tx, error) {
return BeginContextTx(context.Background())
@ -111,7 +144,7 @@ func SetupSQLString(query string, dbConn *GCDB) (string, error) {
return prepared, err
}
// Close closes the connection to the SQL database
// Close closes the connection to the SQL database if it is open
func Close() error {
if gcdb != nil {
return gcdb.Close()
@ -121,6 +154,7 @@ func Close() error {
/*
ExecSQL executes the given SQL statement with the given parameters
Example:
var intVal int
@ -136,8 +170,8 @@ func ExecSQL(query string, values ...any) (sql.Result, error) {
}
/*
ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil)
ExecContextSQL executes the given SQL statement with the given context, optionally with the given transaction (if non-nil).
Deprecated: Use Exec instead
Example:
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(sqlCfg.DBTimeoutSeconds) * time.Second)
@ -154,6 +188,7 @@ func ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...an
return gcdb.ExecContextSQL(ctx, tx, sqlStr, values...)
}
// ExecTimeoutSQL is a helper function for executing a SQL statement with the configured timeout in seconds
func ExecTimeoutSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) {
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
@ -162,7 +197,9 @@ func ExecTimeoutSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error
}
/*
ExecTxSQL automatically escapes the given values and caches the statement
ExecTxSQL executes the given SQL statement with the given transaction and parameters.
Deprecated: Use Exec instead with a transaction in the RequestOptions
Example:
tx, err := BeginTx()
@ -190,8 +227,8 @@ func ExecTxSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) {
}
/*
QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]
Automatically escapes the given values and caches the query
QueryRowSQL gets a row from the db with the values in values[] and fills the respective pointers in out[].
Deprecated: Use QueryRow instead
Example:
@ -211,7 +248,8 @@ func QueryRowSQL(query string, values, out []any) error {
/*
QueryRowContextSQL gets a row from the database with the values in values[] and fills the respective pointers in out[]
using the given context as a deadline, and the given transaction (if non-nil)
using the given context as a deadline, and the given transaction (if non-nil).
Deprecated: Use QueryRow instead with an optional context and/or tx in the RequestOptions
Example:
@ -239,8 +277,8 @@ func QueryRowTimeoutSQL(tx *sql.Tx, query string, values, out []any) error {
}
/*
QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]
Automatically escapes the given values and caches the query
QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[].
Deprecated: Use QueryRow instead with a transaction in the RequestOptions
Example:
@ -261,8 +299,8 @@ func QueryRowTxSQL(tx *sql.Tx, query string, values, out []any) error {
}
/*
QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[]
Automatically escapes the given values and caches the query
QuerySQL gets all rows from the db with the values in values[] and fills the respective pointers in out[].
Deprecated: Use Query instead
Example:
@ -285,7 +323,8 @@ func QuerySQL(query string, a ...any) (*sql.Rows, error) {
/*
QueryContextSQL queries the database with a prepared statement and the given parameters, using the given context
for a deadline
for a deadline.
Deprecated: Use Query instead with an optional context/transaction in the RequestOptions
Example:
@ -317,7 +356,9 @@ func QueryTimeoutSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, context.Can
/*
QueryTxSQL gets all rows from the db using the transaction tx with the values in values[] and fills the
respective pointers in out[]. Automatically escapes the given values and caches the query
respective pointers in out[].
Deprecated: Use Query instead with a transaction in the RequestOptions
Example:
tx, err := BeginTx()
@ -347,6 +388,7 @@ func QueryTxSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, error) {
return rows, stmt.Close()
}
// ParseSQLTimeString attempts to parse a string into a time.Time object using the known SQL date/time formats
func ParseSQLTimeString(str string) (time.Time, error) {
var t time.Time
var err error
@ -359,22 +401,10 @@ func ParseSQLTimeString(str string) (time.Time, error) {
}
// getLatestID returns the latest inserted id column value from the given table
func getLatestID(tableName string, tx *sql.Tx) (id int, err error) {
func getLatestID(opts *RequestOptions, tableName string) (id int, err error) {
opts = setupOptions(opts)
query := `SELECT MAX(id) FROM ` + tableName
if tx != nil {
var stmt *sql.Stmt
stmt, err = PrepareSQL(query, tx)
if err != nil {
return 0, err
}
defer stmt.Close()
if err = stmt.QueryRow().Scan(&id); err != nil {
return
}
err = stmt.Close()
} else {
err = QueryRowSQL(query, nil, []any{&id})
}
QueryRow(opts, query, nil, []any{&id})
return
}
@ -393,7 +423,7 @@ func doesTableExist(tableName string) (bool, error) {
}
var count int
err := QueryRowSQL(existQuery, []any{config.GetSystemCriticalConfig().DBprefix + tableName}, []any{&count})
err := QueryRow(nil, existQuery, []any{config.GetSystemCriticalConfig().DBprefix + tableName}, []any{&count})
if err != nil {
return false, err
}
@ -404,7 +434,7 @@ func doesTableExist(tableName string) (bool, error) {
func GetComponentVersion(componentKey string) (int, error) {
const sql = `SELECT version FROM DBPREFIXdatabase_version WHERE component = ?`
var version int
err := QueryRowSQL(sql, []any{componentKey}, []any{&version})
err := QueryRow(nil, sql, []any{componentKey}, []any{&version})
return version, err
}
@ -434,7 +464,7 @@ func doesGochanPrefixTableExist() (bool, error) {
}
var count int
err := QueryRowSQL(prefixTableExist, []any{}, []any{&count})
err := QueryRow(nil, prefixTableExist, []any{}, []any{&count})
if err != nil && err != sql.ErrNoRows {
return false, err
}