mirror of
https://github.com/Eggbertx/gochan.git
synced 2025-08-02 15:06:23 -07:00
Add SQL timeout wrapper functions
This commit is contained in:
parent
b2b58213e2
commit
be77ec64f9
3 changed files with 59 additions and 49 deletions
|
@ -1,6 +1,7 @@
|
||||||
package gcsql
|
package gcsql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
@ -21,11 +22,14 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
|
||||||
}
|
}
|
||||||
query += " ORDER BY position ASC, name ASC"
|
query += " ORDER BY position ASC, name ASC"
|
||||||
|
|
||||||
rows, err := QuerySQL(query)
|
rows, cancel, err := QueryTimeoutSQL(nil, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
rows.Close()
|
||||||
|
}()
|
||||||
var sections []Section
|
var sections []Section
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var section Section
|
var section Section
|
||||||
|
@ -35,7 +39,7 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
|
||||||
}
|
}
|
||||||
sections = append(sections, section)
|
sections = append(sections, section)
|
||||||
}
|
}
|
||||||
return sections, nil
|
return sections, rows.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrCreateDefaultSectionID creates the default section if no sections have been created yet,
|
// getOrCreateDefaultSectionID creates the default section if no sections have been created yet,
|
||||||
|
@ -43,7 +47,8 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
|
||||||
func getOrCreateDefaultSectionID() (sectionID int, err error) {
|
func getOrCreateDefaultSectionID() (sectionID int, err error) {
|
||||||
const query = `SELECT id FROM DBPREFIXsections WHERE name = 'Main'`
|
const query = `SELECT id FROM DBPREFIXsections WHERE name = 'Main'`
|
||||||
var id int
|
var id int
|
||||||
err = QueryRowSQL(query, nil, []any{&id})
|
|
||||||
|
err = QueryRowTimeoutSQL(nil, query, nil, []any{&id})
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
var section *Section
|
var section *Section
|
||||||
if section, err = NewSection("Main", "main", false, -1); err != nil {
|
if section, err = NewSection("Main", "main", false, -1); err != nil {
|
||||||
|
@ -60,7 +65,7 @@ func getOrCreateDefaultSectionID() (sectionID int, err error) {
|
||||||
func GetSectionFromID(id int) (*Section, error) {
|
func GetSectionFromID(id int) (*Section, error) {
|
||||||
const query = `SELECT id, name, abbreviation, position, hidden FROM DBPREFIXsections WHERE id = ?`
|
const query = `SELECT id, name, abbreviation, position, hidden FROM DBPREFIXsections WHERE id = ?`
|
||||||
var section Section
|
var section Section
|
||||||
err := QueryRowSQL(query, []any{id}, []any{
|
err := QueryRowTimeoutSQL(nil, query, []any{id}, []any{
|
||||||
§ion.ID, §ion.Name, §ion.Abbreviation, §ion.Position, §ion.Hidden,
|
§ion.ID, §ion.Name, §ion.Abbreviation, §ion.Position, §ion.Hidden,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -88,18 +93,22 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
if position < 0 {
|
if position < 0 {
|
||||||
// position not specified
|
// position not specified
|
||||||
err = QueryRowTxSQL(tx, sqlPosition, nil, []any{&position})
|
err = QueryRowContextSQL(ctx, tx, sqlPosition, nil, []any{&position})
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
position = 1
|
position = 1
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, err = ExecTxSQL(tx, sqlINSERT, name, abbreviation, hidden, position); err != nil {
|
if _, err = ExecContextSQL(ctx, tx, sqlINSERT, name, abbreviation, hidden, position); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
id, err := getLatestID("DBPREFIXsections", tx)
|
id, err := getLatestID("DBPREFIXsections", tx)
|
||||||
|
@ -120,6 +129,6 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S
|
||||||
|
|
||||||
func (s *Section) UpdateValues() error {
|
func (s *Section) UpdateValues() error {
|
||||||
const query = `UPDATE DBPREFIXsections set name = ?, abbreviation = ?, position = ?, hidden = ? WHERE id = ?`
|
const query = `UPDATE DBPREFIXsections set name = ?, abbreviation = ?, position = ?, hidden = ? WHERE id = ?`
|
||||||
_, err := ExecSQL(query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
|
_, err := ExecTimeoutSQL(nil, query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,10 +22,7 @@ func createDefaultAdminIfNoStaff() error {
|
||||||
const query = `SELECT COUNT(id) FROM DBPREFIXstaff`
|
const query = `SELECT COUNT(id) FROM DBPREFIXstaff`
|
||||||
var count int
|
var count int
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
err := QueryRowTimeoutSQL(nil, query, nil, []any{&count})
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := QueryRowContextSQL(ctx, nil, query, nil, []any{&count})
|
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -42,10 +39,7 @@ func NewStaff(username string, password string, rank int) (*Staff, error) {
|
||||||
VALUES(?,?,?)`
|
VALUES(?,?,?)`
|
||||||
passwordChecksum := gcutil.BcryptSum(password)
|
passwordChecksum := gcutil.BcryptSum(password)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
_, err := ExecTimeoutSQL(nil, sqlINSERT, username, passwordChecksum, rank)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err := ExecContextSQL(ctx, nil, sqlINSERT, username, passwordChecksum, rank)
|
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -62,10 +56,7 @@ func NewStaff(username string, password string, rank int) (*Staff, error) {
|
||||||
func (s *Staff) SetActive(active bool) error {
|
func (s *Staff) SetActive(active bool) error {
|
||||||
const updateActive = `UPDATE DBPREFIXstaff SET is_active = FALSE WHERE username = ?`
|
const updateActive = `UPDATE DBPREFIXstaff SET is_active = FALSE WHERE username = ?`
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
_, err := ExecTimeoutSQL(nil, updateActive, s.Username)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err := ExecContextSQL(ctx, nil, updateActive, s.Username)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -109,10 +100,7 @@ func UpdatePassword(username string, newPassword string) error {
|
||||||
const sqlUPDATE = `UPDATE DBPREFIXstaff SET password_checksum = ? WHERE username = ?`
|
const sqlUPDATE = `UPDATE DBPREFIXstaff SET password_checksum = ? WHERE username = ?`
|
||||||
checksum := gcutil.BcryptSum(newPassword)
|
checksum := gcutil.BcryptSum(newPassword)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
_, err := ExecTimeoutSQL(nil, sqlUPDATE, checksum, username)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err := ExecContextSQL(ctx, nil, sqlUPDATE, checksum, username)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,10 +143,7 @@ func GetStaffUsernameFromID(id int) (string, error) {
|
||||||
const query = `SELECT username FROM DBPREFIXstaff WHERE id = ?`
|
const query = `SELECT username FROM DBPREFIXstaff WHERE id = ?`
|
||||||
var username string
|
var username string
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
err := QueryRowTimeoutSQL(nil, query, []any{id}, []any{&username})
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := QueryRowContextSQL(ctx, nil, query, []any{id}, []any{&username})
|
|
||||||
return username, err
|
return username, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,32 +151,20 @@ func GetStaffID(username string) (int, error) {
|
||||||
const query = `SELECT id FROM DBPREFIXstaff WHERE username = ?`
|
const query = `SELECT id FROM DBPREFIXstaff WHERE username = ?`
|
||||||
var id int
|
var id int
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
err := QueryRowTimeoutSQL(nil, query, []any{username}, []any{&id})
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := QueryRowContextSQL(ctx, nil, query, []any{username}, []any{&id})
|
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStaffBySession gets the staff that is logged in in the given session
|
// GetStaffBySession gets the staff that is logged in in the given session
|
||||||
func GetStaffBySession(session string) (*Staff, error) {
|
func GetStaffBySession(session string) (*Staff, error) {
|
||||||
const query = `SELECT
|
const query = `SELECT
|
||||||
staff.id,
|
staff.id, staff.username, staff.password_checksum, staff.global_rank, staff.added_on, staff.last_login
|
||||||
staff.username,
|
|
||||||
staff.password_checksum,
|
|
||||||
staff.global_rank,
|
|
||||||
staff.added_on,
|
|
||||||
staff.last_login
|
|
||||||
FROM DBPREFIXstaff as staff
|
FROM DBPREFIXstaff as staff
|
||||||
JOIN DBPREFIXsessions as sessions
|
JOIN DBPREFIXsessions as sessions ON sessions.staff_id = staff.id
|
||||||
ON sessions.staff_id = staff.id
|
|
||||||
WHERE sessions.data = ?`
|
WHERE sessions.data = ?`
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
staff := new(Staff)
|
staff := new(Staff)
|
||||||
err := QueryRowContextSQL(ctx, nil, query, []any{session}, []any{
|
err := QueryRowTimeoutSQL(nil, query, []any{session}, []any{
|
||||||
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn, &staff.LastLogin})
|
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn, &staff.LastLogin})
|
||||||
return staff, err
|
return staff, err
|
||||||
}
|
}
|
||||||
|
@ -203,11 +176,8 @@ func GetStaffByUsername(username string, onlyActive bool) (*Staff, error) {
|
||||||
if onlyActive {
|
if onlyActive {
|
||||||
query += ` AND is_active = TRUE`
|
query += ` AND is_active = TRUE`
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
staff := new(Staff)
|
staff := new(Staff)
|
||||||
err := QueryRowContextSQL(ctx, nil, query, []any{username}, []any{
|
err := QueryRowTimeoutSQL(nil, query, []any{username}, []any{
|
||||||
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn,
|
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn,
|
||||||
&staff.LastLogin, &staff.IsActive,
|
&staff.LastLogin, &staff.IsActive,
|
||||||
})
|
})
|
||||||
|
|
|
@ -123,6 +123,13 @@ func ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...an
|
||||||
return gcdb.ExecContextSQL(ctx, tx, sqlStr, values...)
|
return gcdb.ExecContextSQL(ctx, tx, sqlStr, values...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ExecTimeoutSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return ExecContextSQL(ctx, tx, sqlStr, values...)
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
ExecTxSQL automatically escapes the given values and caches the statement
|
ExecTxSQL automatically escapes the given values and caches the statement
|
||||||
Example:
|
Example:
|
||||||
|
@ -191,6 +198,15 @@ func QueryRowContextSQL(ctx context.Context, tx *sql.Tx, query string, values, o
|
||||||
return gcdb.QueryRowContextSQL(ctx, tx, query, values, out)
|
return gcdb.QueryRowContextSQL(ctx, tx, query, values, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryRowTimeoutSQL is a helper function for querying a single row with the configured default timeout.
|
||||||
|
// It creates a context with the default timeout to only be used for this query and then disposed.
|
||||||
|
// It should only be used by a function that does a single SQL query, otherwise use QueryRowContextSQL
|
||||||
|
func QueryRowTimeoutSQL(tx *sql.Tx, query string, values, out []any) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return QueryRowContextSQL(ctx, tx, query, values, out)
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]
|
QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]
|
||||||
Automatically escapes the given values and caches the query
|
Automatically escapes the given values and caches the query
|
||||||
|
@ -253,6 +269,21 @@ func QueryContextSQL(ctx context.Context, tx *sql.Tx, query string, a ...any) (*
|
||||||
return gcdb.QueryContextSQL(ctx, tx, query, a...)
|
return gcdb.QueryContextSQL(ctx, tx, query, a...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryTimeoutSQL creates a new context with the configured default timeout and passes it and
|
||||||
|
// the given transaction, query, and parameters to QueryContextSQL. If it returns an error,
|
||||||
|
// the context is cancelled, and the error is returned. Otherwise, it returns the rows,
|
||||||
|
// cancel function (for the calling function to call later), and nil error. It should only be used
|
||||||
|
// if the calling function is only doing one SQL query, otherwise use QueryContextSQL.
|
||||||
|
func QueryTimeoutSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, context.CancelFunc, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
|
||||||
|
rows, err := QueryContextSQL(ctx, tx, query, a...)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
return nil, cancel, err
|
||||||
|
}
|
||||||
|
return rows, cancel, nil
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
QueryTxSQL gets all rows from the db using the transaction tx with the values in values[] and fills the
|
QueryTxSQL gets all rows from the db using the transaction tx with the values in values[] and fills the
|
||||||
respective pointers in out[]. Automatically escapes the given values and caches the query
|
respective pointers in out[]. Automatically escapes the given values and caches the query
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue