1
0
Fork 0
mirror of https://github.com/Eggbertx/gochan.git synced 2025-08-02 23:26:23 -07:00
gochan/pkg/gcsql/database.go
2024-05-24 16:10:07 -07:00

487 lines
13 KiB
Go

package gcsql
import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"strings"
	"sync"
	"time"

	_ "github.com/go-sql-driver/mysql"
	_ "github.com/lib/pq"
	_ "github.com/mattn/go-sqlite3"

	"github.com/gochan-org/gochan/pkg/config"
)
// gochanVersionKeyConstant is the key under which base gochan's (database) version is stored in,
// and read back from, the database's version table.
const gochanVersionKeyConstant = "gochan"

// UnsupportedSQLVersionMsg is the error format used when the SQL driver reports a syntax error.
// Its single %s verb receives the driver's original error text.
const UnsupportedSQLVersionMsg = `syntax error in SQL query, confirm you are using a supported driver and SQL server (error text: %s)`

// Connection string (DSN) formats for the supported drivers.
const (
	// mysqlConnStr: user:password@tcp(host)/dbname with time parsing and utf8mb4 collation enabled.
	mysqlConnStr = "%s:%s@tcp(%s)/%s?parseTime=true&collation=utf8mb4_unicode_ci"
	// postgresConnStr: postgres://user:password@host/dbname with SSL disabled.
	postgresConnStr = "postgres://%s:%s@%s/%s?sslmode=disable"
	// sqlite3ConnStr: database file path, then the auth username and password.
	sqlite3ConnStr = "file:%s?_auth&_auth_user=%s&_auth_pass=%s&_auth_crypt=sha1"
)
// gcdb is the package-wide database handle. The optimize* helpers in this file read its default
// timeout.
var gcdb *GCDB

// mysqlReplacerArr holds (placeholder, replacement) pairs for strings.NewReplacer. On MySQL the IP
// range columns, the ip column, and bound parameters are wrapped in INET6_ATON / INET6_NTOA.
var mysqlReplacerArr = []string{
	"RANGE_START_ATON", "INET6_ATON(range_start)",
	"RANGE_START_NTOA", "INET6_NTOA(range_start)",
	"RANGE_END_ATON", "INET6_ATON(range_end)",
	"RANGE_END_NTOA", "INET6_NTOA(range_end)",
	"IP_ATON", "INET6_ATON(ip)",
	"IP_NTOA", "INET6_NTOA(ip)",
	"PARAM_ATON", "INET6_ATON(?)",
	"PARAM_NTOA", "INET6_NTOA(?)",
}

// postgresReplacerArr holds (placeholder, replacement) pairs for strings.NewReplacer. On PostgreSQL
// every IP placeholder becomes the bare column name or a bare "?" parameter.
var postgresReplacerArr = []string{
	"RANGE_START_ATON", "range_start",
	"RANGE_START_NTOA", "range_start",
	"RANGE_END_ATON", "range_end",
	"RANGE_END_NTOA", "range_end",
	"IP_ATON", "ip",
	"IP_NTOA", "ip",
	"PARAM_ATON", "?",
	"PARAM_NTOA", "?",
}

// sqlite3ReplacerArr holds (placeholder, replacement) pairs for strings.NewReplacer. Like PostgreSQL,
// SQLite uses the bare column name or a bare "?" parameter for every IP placeholder.
var sqlite3ReplacerArr = []string{
	"RANGE_START_ATON", "range_start",
	"RANGE_START_NTOA", "range_start",
	"RANGE_END_ATON", "range_end",
	"RANGE_END_NTOA", "range_end",
	"IP_ATON", "ip",
	"IP_NTOA", "ip",
	"PARAM_ATON", "?",
	"PARAM_NTOA", "?",
}
// GCDB bundles a *sql.DB with its driver name, connection string, default timeout for statement
// preparation, and the replacer used to rewrite gochan's query placeholders for that driver.
type GCDB struct {
	db             *sql.DB           // underlying connection pool (nil until opened)
	connStr        string            // connection string handed to sql.Open
	driver         string            // SQL driver name, e.g. "mysql", "postgres", "sqlite3"
	defaultTimeout time.Duration     // used when a caller's context carries no deadline
	replacer       *strings.Replacer // rewrites DBNAME/DBPREFIX and driver-specific placeholders
}

// ConnectionString returns the connection string used to open the database.
func (db *GCDB) ConnectionString() string {
	connStr := db.connStr
	return connStr
}

// Connection returns the underlying *sql.DB, which is nil if the database has not been opened.
func (db *GCDB) Connection() *sql.DB {
	pool := db.db
	return pool
}

// SQLDriver returns the name of the SQL driver this GCDB was configured with.
func (db *GCDB) SQLDriver() string {
	driverName := db.driver
	return driverName
}

// Close closes the underlying connection pool. Closing a GCDB that was never opened is a no-op.
func (db *GCDB) Close() error {
	if db.db == nil {
		return nil
	}
	return db.db.Close()
}
// PrepareSQL prepares query on tx (or on the connection pool when tx is nil) using a background
// context. See PrepareContextSQL for the placeholder rewriting and the default timeout that apply.
func (db *GCDB) PrepareSQL(query string, tx *sql.Tx) (*sql.Stmt, error) {
	bg := context.Background()
	return db.PrepareContextSQL(bg, query, tx)
}
// PrepareContextSQL rewrites query for the configured driver and prepares it. The rewrite applies
// db.replacer (DBNAME, DBPREFIX and the driver-specific placeholders) and then SetupSQLString.
// The statement is prepared on tx when tx is non-nil, otherwise on the connection pool.
//
// A nil ctx is treated as context.Background(). If ctx has no deadline and the GCDB's default
// timeout is positive, preparation is bounded by that timeout. The timeout context is derived from
// ctx, so the caller's cancellation still applies.
//
// Preparation errors pass through sqlVersionError, which reports driver syntax errors as a possible
// unsupported SQL server/driver.
func (db *GCDB) PrepareContextSQL(ctx context.Context, query string, tx *sql.Tx) (*sql.Stmt, error) {
	prepared, err := SetupSQLString(db.replacer.Replace(query), db)
	if err != nil {
		return nil, err
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if _, hasDeadline := ctx.Deadline(); !hasDeadline && db.defaultTimeout > 0 {
		// Derive from ctx (not context.Background) so caller cancellation is preserved. A
		// non-positive timeout is skipped because it would expire immediately and fail every prepare.
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, db.defaultTimeout)
		defer cancel()
	}

	var stmt *sql.Stmt
	if tx != nil {
		stmt, err = tx.PrepareContext(ctx, prepared)
	} else {
		stmt, err = db.db.PrepareContext(ctx, prepared)
	}
	if err != nil {
		return nil, sqlVersionError(err, db.driver, &prepared)
	}
	return stmt, nil
}
/*
ExecSQL prepares query (see PrepareSQL) outside of any transaction, executes it with the given
parameters and returns the result. An error from closing the statement is reported to the caller.

Example:

	var intVal int
	var stringVal string
	result, err := db.ExecSQL("INSERT INTO tablename (intval,stringval) VALUES(?,?)", intVal, stringVal)
*/
func (db *GCDB) ExecSQL(query string, values ...any) (sql.Result, error) {
	stmt, err := db.PrepareSQL(query, nil)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	var result sql.Result
	if result, err = stmt.Exec(values...); err != nil {
		return nil, err
	}
	// Close explicitly so that a failure to close is surfaced instead of being dropped by the defer.
	err = stmt.Close()
	return result, err
}
/*
ExecContextSQL prepares sqlStr (see PrepareContextSQL) and executes it with the given parameters,
inside tx when tx is non-nil. ctx bounds both preparation and execution. A nil ctx is treated as
context.Background(), matching PrepareContextSQL, so it no longer panics at execution time.

Example:

	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(sqlCfg.DBTimeoutSeconds) * time.Second)
	defer cancel()
	var intVal int
	var stringVal string
	result, err := db.ExecContextSQL(ctx, nil, "INSERT INTO tablename (intval,stringval) VALUES(?,?)",
		intVal, stringVal)
*/
func (db *GCDB) ExecContextSQL(ctx context.Context, tx *sql.Tx, sqlStr string, values ...any) (sql.Result, error) {
	if ctx == nil {
		// stmt.ExecContext would panic on a nil context.
		ctx = context.Background()
	}
	stmt, err := db.PrepareContextSQL(ctx, sqlStr, tx)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	result, err := stmt.ExecContext(ctx, values...)
	if err != nil {
		return nil, err
	}
	// Close explicitly so that a close failure reaches the caller.
	return result, stmt.Close()
}
/*
ExecTxSQL prepares query (see PrepareSQL) on tx, or on the connection pool when tx is nil, and
executes it with the given parameters. An error from closing the statement is reported to the caller.

Example:

	tx, err := db.Begin()
	// do error handling stuff
	defer tx.Rollback()
	var intVal int
	var stringVal string
	result, err := db.ExecTxSQL(tx, "INSERT INTO tablename (intval,stringval) VALUES(?,?)",
		intVal, stringVal)
*/
func (db *GCDB) ExecTxSQL(tx *sql.Tx, query string, values ...any) (sql.Result, error) {
	stmt, err := db.PrepareSQL(query, tx)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	res, execErr := stmt.Exec(values...)
	if execErr == nil {
		return res, stmt.Close()
	}
	return res, execErr
}
/*
Begin starts a new transaction on the GCDB's connection pool.

Queries run directly on the returned *sql.Tx bypass gochan's placeholder rewriting (DBPREFIX,
DBNAME, etc.). Use this sparingly, pass the transaction to the GCDB *TxSQL / *ContextSQL helpers,
or run the SQL through gcsql.SetupSQLString first.
*/
func (db *GCDB) Begin() (*sql.Tx, error) {
	tx, err := db.db.Begin()
	return tx, err
}
/*
BeginTx starts a new transaction on the GCDB's connection pool with the given context and
transaction options.

As with Begin, queries run directly on the returned *sql.Tx bypass gochan's placeholder rewriting
(DBPREFIX, DBNAME, etc.). Use this sparingly or together with gcsql.SetupSQLString.
*/
func (db *GCDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
	tx, err := db.db.BeginTx(ctx, opts)
	return tx, err
}
/*
QueryRowSQL prepares query (see PrepareSQL) outside of any transaction, runs it with the
parameters in values, and scans the first result row into the pointers in out. Values are sent as
bound statement parameters rather than spliced into the SQL text.

Example:

	id := 32
	var intVal int
	var stringVal string
	err := db.QueryRowSQL("SELECT intval,stringval FROM table WHERE id = ?",
		[]any{id},
		[]any{&intVal, &stringVal})
*/
func (db *GCDB) QueryRowSQL(query string, values, out []any) error {
	stmt, err := db.PrepareSQL(query, nil)
	if err == nil {
		defer stmt.Close()
		row := stmt.QueryRow(values...)
		err = row.Scan(out...)
	}
	return err
}
/*
QueryRowContextSQL prepares query (see PrepareContextSQL), runs it with the parameters in values,
inside tx when tx is non-nil, and scans the first result row into the pointers in out. ctx bounds
preparation and the query. A nil ctx is treated as context.Background(), matching
PrepareContextSQL, so QueryRowContext no longer panics on it.

Example:

	id := 32
	var name string
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(sqlCfg.DBTimeoutSeconds) * time.Second)
	defer cancel()
	err := db.QueryRowContextSQL(ctx, nil, "SELECT name FROM DBPREFIXposts WHERE id = ? LIMIT 1",
		[]any{id}, []any{&name})
*/
func (db *GCDB) QueryRowContextSQL(ctx context.Context, tx *sql.Tx, query string, values, out []any) error {
	if ctx == nil {
		// stmt.QueryRowContext would panic on a nil context.
		ctx = context.Background()
	}
	stmt, err := db.PrepareContextSQL(ctx, query, tx)
	if err != nil {
		return err
	}
	defer stmt.Close()

	if err = stmt.QueryRowContext(ctx, values...).Scan(out...); err != nil {
		return err
	}
	// Close explicitly so that a close failure reaches the caller.
	return stmt.Close()
}
/*
QueryRowTxSQL prepares query (see PrepareSQL) on tx, or on the connection pool when tx is nil, runs
it with the parameters in values, and scans the first result row into the pointers in out. An error
from closing the statement is reported to the caller.

Example:

	id := 32
	var intVal int
	var stringVal string
	tx, err := db.Begin()
	// do error handling stuff
	defer tx.Rollback()
	err = db.QueryRowTxSQL(tx, "SELECT intval,stringval FROM table WHERE id = ?",
		[]any{id},
		[]any{&intVal, &stringVal})
*/
func (db *GCDB) QueryRowTxSQL(tx *sql.Tx, query string, values, out []any) error {
	stmt, err := db.PrepareSQL(query, tx)
	if err != nil {
		return err
	}
	defer stmt.Close()

	row := stmt.QueryRow(values...)
	if scanErr := row.Scan(out...); scanErr != nil {
		return scanErr
	}
	return stmt.Close()
}
/*
QuerySQL prepares query (see PrepareSQL) outside of any transaction, runs it with the given
parameters, and returns the resulting rows. The caller is responsible for closing the rows.

Example:

	rows, err := db.QuerySQL("SELECT intval, stringval FROM table")
	if err == nil {
		defer rows.Close()
		for rows.Next() {
			var intVal int
			var stringVal string
			rows.Scan(&intVal, &stringVal)
			// do something with intVal and stringVal
		}
	}
*/
func (db *GCDB) QuerySQL(query string, a ...any) (*sql.Rows, error) {
	stmt, prepErr := db.PrepareSQL(query, nil)
	if prepErr != nil {
		return nil, prepErr
	}
	defer stmt.Close()
	rows, queryErr := stmt.Query(a...)
	return rows, queryErr
}
/*
QueryContextSQL prepares query (see PrepareContextSQL) and runs it with the given parameters and
context, inside tx when tx is non-nil. The caller is responsible for closing the returned rows.
A nil ctx is treated as context.Background(), matching PrepareContextSQL.

When tx is non-nil, the prepared statement is deliberately left open. Tx.Commit/Tx.Rollback close
transaction statements, and closing one early would release the driver statement while the returned
rows are still being read.

Example:

	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(sqlCfg.DBTimeoutSeconds) * time.Second)
	defer cancel()
	rows, err := db.QueryContextSQL(ctx, nil, "SELECT name from posts where NOT is_deleted")
	if err == nil {
		defer rows.Close()
		// iterate rows
	}
*/
func (db *GCDB) QueryContextSQL(ctx context.Context, tx *sql.Tx, query string, a ...any) (*sql.Rows, error) {
	if ctx == nil {
		// stmt.QueryContext would panic on a nil context.
		ctx = context.Background()
	}
	stmt, err := db.PrepareContextSQL(ctx, query, tx)
	if err != nil {
		return nil, err
	}

	rows, err := stmt.QueryContext(ctx, a...)
	if err != nil {
		// No rows depend on the statement, so it is safe to close it now. The query error is the one
		// worth reporting.
		stmt.Close()
		return nil, err
	}
	if tx != nil {
		// Closed by Tx.Commit/Tx.Rollback; see the doc comment.
		return rows, nil
	}
	if err = stmt.Close(); err != nil {
		// Don't hand back live rows with an error: callers would not close them and the
		// connection would leak.
		rows.Close()
		return nil, err
	}
	return rows, nil
}
/*
QueryTxSQL prepares query (see PrepareSQL) on tx, or on the connection pool when tx is nil, runs it
with the given parameters, and returns the resulting rows. The caller is responsible for closing the
rows.

When tx is non-nil, the prepared statement is deliberately left open. Tx.Commit/Tx.Rollback close
transaction statements, and closing one early would release the driver statement while the returned
rows are still being read.

Example:

	tx, _ := db.Begin()
	defer tx.Rollback()
	rows, err := db.QueryTxSQL(tx, "SELECT intval, stringval FROM table")
	if err == nil {
		defer rows.Close()
		for rows.Next() {
			var intVal int
			var stringVal string
			rows.Scan(&intVal, &stringVal)
			// do something with intVal and stringVal
		}
	}
*/
func (db *GCDB) QueryTxSQL(tx *sql.Tx, query string, a ...any) (*sql.Rows, error) {
	stmt, err := db.PrepareSQL(query, tx)
	if err != nil {
		return nil, err
	}
	if tx == nil {
		defer stmt.Close()
		return stmt.Query(a...)
	}

	rows, err := stmt.Query(a...)
	if err != nil {
		// Nothing depends on the statement yet, so release it right away.
		stmt.Close()
		return nil, err
	}
	// The statement is closed by Tx.Commit/Tx.Rollback; see the doc comment.
	return rows, nil
}
// setupDBConn builds a GCDB from cfg without opening it. It sets the driver, the default timeout,
// the connection string, and the placeholder replacer. The replacer covers DBNAME, DBPREFIX and
// newlines, plus the driver-specific IP placeholder pairs.
//
// Credentials are escaped in the postgres URL and the sqlite3 query string, so passwords containing
// characters such as '@', ':', '/', '&' or '%' no longer corrupt the connection string. For sqlite3,
// cfg.DBhost is overwritten in place with the second capture group of tcpHostIsolator's first match,
// when there is one. Unsupported driver names return ErrUnsupportedDB.
func setupDBConn(cfg *config.SQLConfig) (db *GCDB, err error) {
	db = &GCDB{
		driver:         cfg.DBtype,
		defaultTimeout: time.Duration(cfg.DBTimeoutSeconds) * time.Second,
	}
	replacerArr := []string{
		"DBNAME", cfg.DBname,
		"DBPREFIX", cfg.DBprefix,
		"\n", " ",
	}
	switch cfg.DBtype {
	case "mysql":
		db.connStr = fmt.Sprintf(mysqlConnStr, cfg.DBusername, cfg.DBpassword, cfg.DBhost, cfg.DBname)
		replacerArr = append(replacerArr, mysqlReplacerArr...)
	case "postgres":
		// Build the URL with net/url so the userinfo and database name are properly escaped. For
		// ordinary values this gives the same string as the postgresConnStr format.
		pgURL := url.URL{
			Scheme:   "postgres",
			User:     url.UserPassword(cfg.DBusername, cfg.DBpassword),
			Host:     cfg.DBhost,
			Path:     "/" + cfg.DBname,
			RawQuery: "sslmode=disable",
		}
		db.connStr = pgURL.String()
		replacerArr = append(replacerArr, postgresReplacerArr...)
	case "sqlite3":
		addrMatches := tcpHostIsolator.FindAllStringSubmatch(cfg.DBhost, -1)
		if len(addrMatches) > 0 && len(addrMatches[0]) > 2 {
			cfg.DBhost = addrMatches[0][2]
		}
		// The credentials sit in the DSN's query string, so they must be query-escaped.
		db.connStr = fmt.Sprintf(sqlite3ConnStr, cfg.DBhost,
			url.QueryEscape(cfg.DBusername), url.QueryEscape(cfg.DBpassword))
		replacerArr = append(replacerArr, sqlite3ReplacerArr...)
	default:
		return nil, ErrUnsupportedDB
	}
	db.replacer = strings.NewReplacer(replacerArr...)
	return db, nil
}
// Open opens and returns a new gochan database handle from cfg: host, driver, DB name, username,
// password and table prefix. On success the pool settings from cfg are applied: connection max
// lifetime in minutes, max open connections and max idle connections. On failure a nil *GCDB is
// returned together with the error.
func Open(cfg *config.SQLConfig) (db *GCDB, err error) {
	db, err = setupDBConn(cfg)
	if err != nil {
		return nil, err
	}
	db.db, err = sql.Open(db.driver, db.connStr)
	if err != nil {
		// db.db may be nil here, so it must not be configured. The previous inverted check did exactly that.
		return nil, err
	}
	db.db.SetConnMaxLifetime(time.Minute * time.Duration(cfg.DBConnMaxLifetimeMin))
	db.db.SetMaxOpenConns(cfg.DBMaxOpenConnections)
	db.db.SetMaxIdleConns(cfg.DBMaxIdleConnections)
	return db, nil
}
// optimizeMySQL runs OPTIMIZE TABLE on every table in the current MySQL database.
//
// The full table list is read, and its result set closed, before any OPTIMIZE statement is issued.
// The statements then run concurrently and the first error encountered, if any, is returned. All
// work shares one context bounded by gcdb's default timeout.
func optimizeMySQL() error {
	ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
	defer cancel()

	rows, err := QueryContextSQL(ctx, nil, "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE()")
	if err != nil {
		return err
	}
	var tables []string
	for rows.Next() {
		var table string
		if err = rows.Scan(&table); err != nil {
			rows.Close()
			return err
		}
		tables = append(tables, table)
	}
	if err = rows.Err(); err != nil {
		rows.Close()
		return err
	}
	// Release the result set (and its connection) before the OPTIMIZE statements need connections.
	if err = rows.Close(); err != nil {
		return err
	}

	var (
		wg       sync.WaitGroup
		mu       sync.Mutex
		firstErr error
	)
	for _, table := range tables {
		wg.Add(1)
		go func(table string) {
			// Deferred so that a failing OPTIMIZE can never leave wg.Wait blocked forever.
			defer wg.Done()
			if _, execErr := ExecContextSQL(ctx, nil, "OPTIMIZE TABLE "+table); execErr != nil {
				mu.Lock()
				if firstErr == nil {
					firstErr = execErr
				}
				mu.Unlock()
			}
		}(table)
	}
	wg.Wait()
	return firstErr
}
// optimizePostgres reindexes the configured PostgreSQL database with REINDEX DATABASE, bounded by
// gcdb's default timeout.
func optimizePostgres() error {
	ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
	defer cancel()
	reindexStmt := "REINDEX DATABASE " + config.GetSQLConfig().DBname
	if _, err := ExecContextSQL(ctx, nil, reindexStmt); err != nil {
		return err
	}
	return nil
}
// optimizeSqlite3 runs VACUUM on the SQLite database, bounded by gcdb's default timeout.
func optimizeSqlite3() error {
	ctx, cancel := context.WithTimeout(context.Background(), gcdb.defaultTimeout)
	defer cancel()
	if _, err := ExecContextSQL(ctx, nil, "VACUUM"); err != nil {
		return err
	}
	return nil
}
// OptimizeDatabase performs a driver-specific database optimisation:
//   - MySQL: OPTIMIZE TABLE on every table
//   - PostgreSQL: REINDEX DATABASE
//   - SQLite: VACUUM
//
// It returns ErrUnsupportedDB for any other configured driver.
func OptimizeDatabase() error {
	switch config.GetSQLConfig().DBtype {
	case "mysql":
		return optimizeMySQL()
	case "postgres", "postgresql":
		// "postgres" is the driver name setupDBConn accepts. "postgresql" is kept for compatibility.
		return optimizePostgres()
	case "sqlite3":
		return optimizeSqlite3()
	default:
		// this shouldn't happen under normal circumstances since this is assumed to have already been checked
		return ErrUnsupportedDB
	}
}
// sqlVersionError returns err unchanged unless it is a driver syntax error. Syntax errors are
// replaced with an UnsupportedSQLVersionMsg error, since they usually mean an unsupported SQL
// server/driver. A nil err yields nil.
//
// When the system config is verbose and query is non-nil, the offending query is appended. It is
// passed as an argument, not concatenated into the format string, so any '%' in the SQL (for
// example a LIKE pattern) cannot corrupt the message.
func sqlVersionError(err error, dbDriver string, query *string) error {
	if err == nil {
		return nil
	}
	errText := err.Error()
	switch dbDriver {
	case "mysql":
		if !strings.Contains(errText, "You have an error in your SQL syntax") {
			return err
		}
	case "postgres":
		if !strings.Contains(errText, "syntax error at or near") {
			return err
		}
	case "sqlite3":
		// Without this case every SQLite error (missing table, constraint failure, ...) was
		// mislabeled as a syntax error.
		if !strings.Contains(errText, "syntax error") {
			return err
		}
	}
	if query != nil && config.GetSystemCriticalConfig().Verbose {
		return fmt.Errorf(UnsupportedSQLVersionMsg+"\nQuery: %s", errText, *query)
	}
	return fmt.Errorf(UnsupportedSQLVersionMsg, errText)
}