1
0
Fork 0
mirror of https://github.com/Eggbertx/gochan.git synced 2025-09-09 18:16:22 -07:00
gochan/pkg/gcsql/connect.go
2024-11-07 12:11:47 -08:00

75 lines
1.8 KiB
Go

package gcsql
import (
"database/sql"
"os"
"regexp"
"strings"
"github.com/gochan-org/gochan/pkg/config"
"github.com/gochan-org/gochan/pkg/gcutil/testutil"
)
// ConnectToDB initializes the database connection and exits if there are any errors
func ConnectToDB(cfg *config.SQLConfig) error {
var err error
gcdb, err = Open(cfg)
return err
}
// SetDB sets the global database connection (mainly used by gochan-migration)
func SetDB(db *GCDB) {
gcdb = db
}
func SetTestingDB(dbDriver string, dbName string, dbPrefix string, db *sql.DB) (err error) {
testutil.PanicIfNotTest()
sqlConfig := config.GetSQLConfig()
if sqlConfig.DBname == "" {
return ErrNotConnected
}
gcdb, err = setupDBConn(&config.SQLConfig{
DBtype: dbDriver,
DBhost: "localhost",
DBname: dbName,
DBusername: "gochan",
DBpassword: "gochan",
DBprefix: dbPrefix,
DBTimeoutSeconds: config.DefaultSQLTimeout,
DBMaxOpenConnections: config.DefaultSQLMaxConns,
DBMaxIdleConnections: config.DefaultSQLMaxConns,
DBConnMaxLifetimeMin: config.DefaultSQLConnMaxLifetimeMin,
})
if err != nil {
return
}
gcdb.db = db
return
}
// RunSQLFile cuts a given sql file into individual statements and runs it.
func RunSQLFile(path string) error {
sqlBytes, err := os.ReadFile(path)
if err != nil {
return err
}
sqlStr := regexp.MustCompile("--.*\n?").ReplaceAllString(string(sqlBytes), " ")
sqlArr := strings.Split(gcdb.replacer.Replace(sqlStr), ";")
tx, err := BeginTx()
if err != nil {
return err
}
defer tx.Rollback()
for _, statement := range sqlArr {
statement = strings.Trim(statement, " \n\r\t")
if len(statement) > 0 {
if _, err = ExecTxSQL(tx, statement); err != nil {
return err
}
}
}
return tx.Commit()
}