1
0
Fork 0
mirror of https://github.com/Eggbertx/gochan.git synced 2025-08-02 10:56:25 -07:00

Add SQL timeout wrapper functions

This commit is contained in:
Eggbertx 2024-05-30 13:16:13 -07:00
parent b2b58213e2
commit be77ec64f9
3 changed files with 59 additions and 49 deletions

View file

@ -1,6 +1,7 @@
package gcsql
import (
"context"
"database/sql"
"errors"
)
@ -21,11 +22,14 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
}
query += " ORDER BY position ASC, name ASC"
rows, err := QuerySQL(query)
rows, cancel, err := QueryTimeoutSQL(nil, query)
if err != nil {
return nil, err
}
defer rows.Close()
defer func() {
cancel()
rows.Close()
}()
var sections []Section
for rows.Next() {
var section Section
@ -35,7 +39,7 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
}
sections = append(sections, section)
}
return sections, nil
return sections, rows.Close()
}
// getOrCreateDefaultSectionID creates the default section if no sections have been created yet,
@ -43,7 +47,8 @@ func GetAllSections(onlyNonHidden bool) ([]Section, error) {
func getOrCreateDefaultSectionID() (sectionID int, err error) {
const query = `SELECT id FROM DBPREFIXsections WHERE name = 'Main'`
var id int
err = QueryRowSQL(query, nil, []any{&id})
err = QueryRowTimeoutSQL(nil, query, nil, []any{&id})
if errors.Is(err, sql.ErrNoRows) {
var section *Section
if section, err = NewSection("Main", "main", false, -1); err != nil {
@ -60,7 +65,7 @@ func getOrCreateDefaultSectionID() (sectionID int, err error) {
func GetSectionFromID(id int) (*Section, error) {
const query = `SELECT id, name, abbreviation, position, hidden FROM DBPREFIXsections WHERE id = ?`
var section Section
err := QueryRowSQL(query, []any{id}, []any{
err := QueryRowTimeoutSQL(nil, query, []any{id}, []any{
&section.ID, &section.Name, &section.Abbreviation, &section.Position, &section.Hidden,
})
if err != nil {
@ -88,18 +93,22 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S
if err != nil {
return nil, err
}
defer tx.Rollback()
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer func() {
cancel()
tx.Rollback()
}()
if position < 0 {
// position not specified
err = QueryRowTxSQL(tx, sqlPosition, nil, []any{&position})
err = QueryRowContextSQL(ctx, tx, sqlPosition, nil, []any{&position})
if errors.Is(err, sql.ErrNoRows) {
position = 1
} else if err != nil {
return nil, err
}
}
if _, err = ExecTxSQL(tx, sqlINSERT, name, abbreviation, hidden, position); err != nil {
if _, err = ExecContextSQL(ctx, tx, sqlINSERT, name, abbreviation, hidden, position); err != nil {
return nil, err
}
id, err := getLatestID("DBPREFIXsections", tx)
@ -120,6 +129,6 @@ func NewSection(name string, abbreviation string, hidden bool, position int) (*S
func (s *Section) UpdateValues() error {
const query = `UPDATE DBPREFIXsections set name = ?, abbreviation = ?, position = ?, hidden = ? WHERE id = ?`
_, err := ExecSQL(query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
_, err := ExecTimeoutSQL(nil, query, s.Name, s.Abbreviation, s.Position, s.Hidden, s.ID)
return err
}

View file

@ -22,10 +22,7 @@ func createDefaultAdminIfNoStaff() error {
const query = `SELECT COUNT(id) FROM DBPREFIXstaff`
var count int
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
err := QueryRowContextSQL(ctx, nil, query, nil, []any{&count})
err := QueryRowTimeoutSQL(nil, query, nil, []any{&count})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
@ -42,10 +39,7 @@ func NewStaff(username string, password string, rank int) (*Staff, error) {
VALUES(?,?,?)`
passwordChecksum := gcutil.BcryptSum(password)
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
_, err := ExecContextSQL(ctx, nil, sqlINSERT, username, passwordChecksum, rank)
_, err := ExecTimeoutSQL(nil, sqlINSERT, username, passwordChecksum, rank)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
@ -62,10 +56,7 @@ func NewStaff(username string, password string, rank int) (*Staff, error) {
func (s *Staff) SetActive(active bool) error {
const updateActive = `UPDATE DBPREFIXstaff SET is_active = FALSE WHERE username = ?`
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
_, err := ExecContextSQL(ctx, nil, updateActive, s.Username)
_, err := ExecTimeoutSQL(nil, updateActive, s.Username)
if err != nil {
return err
}
@ -109,10 +100,7 @@ func UpdatePassword(username string, newPassword string) error {
const sqlUPDATE = `UPDATE DBPREFIXstaff SET password_checksum = ? WHERE username = ?`
checksum := gcutil.BcryptSum(newPassword)
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
_, err := ExecContextSQL(ctx, nil, sqlUPDATE, checksum, username)
_, err := ExecTimeoutSQL(nil, sqlUPDATE, checksum, username)
return err
}
@ -155,10 +143,7 @@ func GetStaffUsernameFromID(id int) (string, error) {
const query = `SELECT username FROM DBPREFIXstaff WHERE id = ?`
var username string
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
err := QueryRowContextSQL(ctx, nil, query, []any{id}, []any{&username})
err := QueryRowTimeoutSQL(nil, query, []any{id}, []any{&username})
return username, err
}
@ -166,32 +151,20 @@ func GetStaffID(username string) (int, error) {
const query = `SELECT id FROM DBPREFIXstaff WHERE username = ?`
var id int
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
err := QueryRowContextSQL(ctx, nil, query, []any{username}, []any{&id})
err := QueryRowTimeoutSQL(nil, query, []any{username}, []any{&id})
return id, err
}
// GetStaffBySession gets the staff that is logged in in the given session
func GetStaffBySession(session string) (*Staff, error) {
const query = `SELECT
staff.id,
staff.username,
staff.password_checksum,
staff.global_rank,
staff.added_on,
staff.last_login
staff.id, staff.username, staff.password_checksum, staff.global_rank, staff.added_on, staff.last_login
FROM DBPREFIXstaff as staff
JOIN DBPREFIXsessions as sessions
ON sessions.staff_id = staff.id
JOIN DBPREFIXsessions as sessions ON sessions.staff_id = staff.id
WHERE sessions.data = ?`
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
staff := new(Staff)
err := QueryRowContextSQL(ctx, nil, query, []any{session}, []any{
err := QueryRowTimeoutSQL(nil, query, []any{session}, []any{
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn, &staff.LastLogin})
return staff, err
}
@ -203,11 +176,8 @@ func GetStaffByUsername(username string, onlyActive bool) (*Staff, error) {
if onlyActive {
query += ` AND is_active = TRUE`
}
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
staff := new(Staff)
err := QueryRowContextSQL(ctx, nil, query, []any{username}, []any{
err := QueryRowTimeoutSQL(nil, query, []any{username}, []any{
&staff.ID, &staff.Username, &staff.PasswordChecksum, &staff.Rank, &staff.AddedOn,
&staff.LastLogin, &staff.IsActive,
})

View file

@ -123,6 +123,13 @@ func ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...an
return gcdb.ExecContextSQL(ctx, tx, sqlStr, values...)
}
func ExecTimeoutSQL(tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) {
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
return ExecContextSQL(ctx, tx, sqlStr, values...)
}
/*
ExecTxSQL automatically escapes the given values and caches the statement
Example:
@ -191,6 +198,15 @@ func QueryRowContextSQL(ctx context.Context, tx *sql.Tx, query string, values, o
return gcdb.QueryRowContextSQL(ctx, tx, query, values, out)
}
// QueryRowTimeoutSQL is a helper function for querying a single row with the configured default timeout.
// It creates a context with the default timeout to only be used for this query and then disposed.
// It should only be used by a function that does a single SQL query, otherwise use QueryRowContextSQL
func QueryRowTimeoutSQL(tx *sql.Tx, query string, values, out []any) error {
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
defer cancel()
return QueryRowContextSQL(ctx, tx, query, values, out)
}
/*
QueryRowTxSQL gets a row from the db with the values in values[] and fills the respective pointers in out[]
Automatically escapes the given values and caches the query
@ -253,6 +269,21 @@ func QueryContextSQL(ctx context.Context, tx *sql.Tx, query string, a ...any) (*
return gcdb.QueryContextSQL(ctx, tx, query, a...)
}
// QueryTimeoutSQL creates a new context with the configured default timeout and passes it and
// the given transaction, query, and parameters to QueryContextSQL. If it returns an error,
// the context is cancelled, and the error is returned. Otherwise, it returns the rows,
// cancel function (for the calling function to call later), and nil error. It should only be used
// if the calling function is only doing one SQL query, otherwise use QueryContextSQL.
func QueryTimeoutSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, context.CancelFunc, error) {
ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
rows, err := QueryContextSQL(ctx, tx, query, a...)
if err != nil {
cancel()
return nil, cancel, err
}
return rows, cancel, nil
}
/*
QueryTxSQL gets all rows from the db using the transaction tx with the values in values[] and fills the
respective pointers in out[]. Automatically escapes the given values and caches the query