mirror of
https://github.com/Eggbertx/gochan.git
synced 2025-08-02 15:06:23 -07:00
Add SQL timeout wrapper functions
This commit is contained in:
parent
b2b58213e2
commit
be77ec64f9
3 changed files with 59 additions and 49 deletions
|
@ -1,6 +1,7 @@
|
|||
package gcsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
@ -21,11 +22,14 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
|
|||
}
|
||||
query += " ORDER BY position ASC, name ASC"
|
||||
|
||||
rows, err := QuerySQL(query)
|
||||
rows, cancel, err := QueryTimeoutSQL(nil, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() {
|
||||
cancel()
|
||||
rows.Close()
|
||||
}()
|
||||
var sections []Section
|
||||
for rows.Next() {
|
||||
var section Section
|
||||
|
@ -35,7 +39,7 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
|
|||
}
|
||||
sections = append(sections, section)
|
||||
}
|
||||
return sections, nil
|
||||
return sections, rows.Close()
|
||||
}
|
||||
|
||||
// getOrCreateDefaultSectionID creates the default section if no sections have been created yet,
|
||||
|
@ -43,7 +47,8 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
|
|||
func getOrCreateDefaultSectionID() (sectionID int, err error) {
|
||||
const query = `SELECT id FROM DBPREFIXsections WHERE name = 'Main'`
|
||||
var id int
|
||||
err = QueryRowSQL(query, nil, []any{&id})
|
||||
|
||||
err = QueryRowTimeoutSQL(nil, query, nil, []any{&id})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
var section *Section
|
||||
if section, err = NewSection("Main", "main", false, -1); err != nil {
|
||||
|
@ -60,7 +65,7 @@ func getOrCreateDefaultSectionID() (sectionID int, err error) {
|
|||
func GetSectionFromID(id int) (*Section, error) {
|
||||
const query = `SELECT id, name, abbreviation, position, hidden FROM DBPREFIXsections WHERE id = ?`
|
||||
var section Section
|
||||
err := QueryRowSQL(query, []any{id}, []any{
|
||||
err := QueryRowTimeoutSQL(nil, query, []any{id}, []any{
|
||||
§ion.ID, §ion.Name, §ion.Abbreviation, §ion.Position, §ion.Hidden,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -88,18 +93,22 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer func() {
|
||||
cancel()
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
if position < 0 {
|
||||
// position not specified
|
||||
err = QueryRowTxSQL(tx, sqlPosition, nil, []any{&position})
|
||||
err = QueryRowContextSQL(ctx, tx, sqlPosition, nil, []any{&position})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
position = 1
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if _, err = ExecTxSQL(tx, sqlINSERT, name, abbreviation, hidden, position); err != nil {
|
||||
if _, err = ExecContextSQL(ctx, tx, sqlINSERT, name, abbreviation, hidden, position); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := getLatestID("DBPREFIXsections", tx)
|
||||
|
@ -120,6 +129,6 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S
|
|||
|
||||
func (s *Section) UpdateValues() error {
|
||||
const query = `UPDATE DBPREFIXsections set name = ?, abbreviation = ?, position = ?, hidden = ? WHERE id = ?`
|
||||
_, err := ExecSQL(query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
|
||||
_, err := ExecTimeoutSQL(nil, query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -22,10 +22,7 @@ func createDefaultAdminIfNoStaff() error {
|
|||
const query = `SELECT COUNT(id) FROM DBPREFIXstaff`
|
||||
var count int
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
err := QueryRowContextSQL(ctx, nil, query, nil, []any{&count})
|
||||
err := QueryRowTimeoutSQL(nil, query, nil, []any{&count})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
@ -42,10 +39,7 @@ func NewStaff(username string, password string, rank int) (*Staff, error) {
|
|||
VALUES(?,?,?)`
|
||||
passwordChecksum := gcutil.BcryptSum(password)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := ExecContextSQL(ctx, nil, sqlINSERT, username, passwordChecksum, rank)
|
||||
_, err := ExecTimeoutSQL(nil, sqlINSERT, username, passwordChecksum, rank)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -62,10 +56,7 @@ func NewStaff(username string, password string, rank int) (*Staff, error) {
|
|||
func (s *Staff) SetActive(active bool) error {
|
||||
const updateActive = `UPDATE DBPREFIXstaff SET is_active = FALSE WHERE username = ?`
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := ExecContextSQL(ctx, nil, updateActive, s.Username)
|
||||
_, err := ExecTimeoutSQL(nil, updateActive, s.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -109,10 +100,7 @@ func UpdatePassword(username string, newPassword string) error {
|
|||
const sqlUPDATE = `UPDATE DBPREFIXstaff SET password_checksum = ? WHERE username = ?`
|
||||
checksum := gcutil.BcryptSum(newPassword)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := ExecContextSQL(ctx, nil, sqlUPDATE, checksum, username)
|
||||
_, err := ExecTimeoutSQL(nil, sqlUPDATE, checksum, username)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -155,10 +143,7 @@ func GetStaffUsernameFromID(id int) (string, error) {
|
|||
const query = `SELECT username FROM DBPREFIXstaff WHERE id = ?`
|
||||
var username string
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
err := QueryRowContextSQL(ctx, nil, query, []any{id}, []any{&username})
|
||||
err := QueryRowTimeoutSQL(nil, query, []any{id}, []any{&username})
|
||||
return username, err
|
||||
}
|
||||
|
||||
|
@ -166,32 +151,20 @@ func GetStaffID(username string) (int, error) {
|
|||
const query = `SELECT id FROM DBPREFIXstaff WHERE username = ?`
|
||||
var id int
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
err := QueryRowContextSQL(ctx, nil, query, []any{username}, []any{&id})
|
||||
err := QueryRowTimeoutSQL(nil, query, []any{username}, []any{&id})
|
||||
return id, err
|
||||
}
|
||||
|
||||
// GetStaffBySession gets the staff that is logged in in the given session
|
||||
func GetStaffBySession(session string) (*Staff, error) {
|
||||
const query = `SELECT
|
||||
staff.id,
|
||||
staff.username,
|
||||
staff.password_checksum,
|
||||
staff.global_rank,
|
||||
staff.added_on,
|
||||
staff.last_login
|
||||
staff.id, staff.username, staff.password_checksum, staff.global_rank, staff.added_on, staff.last_login
|
||||
FROM DBPREFIXstaff as staff
|
||||
JOIN DBPREFIXsessions as sessions
|
||||
ON sessions.staff_id = staff.id
|
||||
JOIN DBPREFIXsessions as sessions ON sessions.staff_id = staff.id
|
||||
WHERE sessions.data = ?`
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
staff := new(Staff)
|
||||
err := QueryRowContextSQL(ctx, nil, query, []any{session}, []any{
|
||||
err := QueryRowTimeoutSQL(nil, query, []any{session}, []any{
|
||||
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn, &staff.LastLogin})
|
||||
return staff, err
|
||||
}
|
||||
|
@ -203,11 +176,8 @@ func GetStaffByUsername(username string, onlyActive bool) (*Staff, error) {
|
|||
if onlyActive {
|
||||
query += ` AND is_active = TRUE`
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
staff := new(Staff)
|
||||
err := QueryRowContextSQL(ctx, nil, query, []any{username}, []any{
|
||||
err := QueryRowTimeoutSQL(nil, query, []any{username}, []any{
|
||||
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn,
|
||||
&staff.LastLogin, &staff.IsActive,
|
||||
})
|
||||
|
|
|
@ -123,6 +123,13 @@ func ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...an
|
|||
return gcdb.ExecContextSQL(ctx, tx, sqlStr, values...)
|
||||
}
|
||||
|
||||
func ExecTimeoutSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return ExecContextSQL(ctx, tx, sqlStr, values...)
|
||||
}
|
||||
|
||||
/*
|
||||
ExecTxSQL automatically escapes the given values and caches the statement
|
||||
Example:
|
||||
|
@ -191,6 +198,15 @@ func QueryRowContextSQL(ctx context.Context, tx *sql.Tx, query string, values, o
|
|||
return gcdb.QueryRowContextSQL(ctx, tx, query, values, out)
|
||||
}
|
||||
|
||||
// QueryRowTimeoutSQL is a helper function for querying a single row with the configured default timeout.
|
||||
// It creates a context with the default timeout to only be used for this query and then disposed.
|
||||
// It should only be used by a function that does a single SQL query, otherwise use QueryRowContextSQL
|
||||
func QueryRowTimeoutSQL(tx *sql.Tx, query string, values, out []any) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
defer cancel()
|
||||
return QueryRowContextSQL(ctx, tx, query, values, out)
|
||||
}
|
||||
|
||||
/*
|
||||
QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]
|
||||
Automatically escapes the given values and caches the query
|
||||
|
@ -253,6 +269,21 @@ func QueryContextSQL(ctx context.Context, tx *sql.Tx, query string, a ...any) (*
|
|||
return gcdb.QueryContextSQL(ctx, tx, query, a...)
|
||||
}
|
||||
|
||||
// QueryTimeoutSQL creates a new context with the configured default timeout and passes it and
|
||||
// the given transaction, query, and parameters to QueryContextSQL. If it returns an error,
|
||||
// the context is cancelled, and the error is returned. Otherwise, it returns the rows,
|
||||
// cancel function (for the calling function to call later), and nil error. It should only be used
|
||||
// if the calling function is only doing one SQL query, otherwise use QueryContextSQL.
|
||||
func QueryTimeoutSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, context.CancelFunc, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||
rows, err := QueryContextSQL(ctx, tx, query, a...)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, cancel, err
|
||||
}
|
||||
return rows, cancel, nil
|
||||
}
|
||||
|
||||
/*
|
||||
QueryTxSQL gets all rows from the db using the transaction tx with the values in values[] and fills the
|
||||
respective pointers in out[]. Automatically escapes the given values and caches the query
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue