1
0
Fork 0
mirror of https://github.com/Eggbertx/gochan.git synced 2025-09-05 11:06:23 -07:00

Fix SQL string replacer messing with prepared statements in Postgres

This commit is contained in:
Eggbertx 2024-01-02 17:02:42 -08:00
parent 121959fa15
commit d294462968
3 changed files with 37 additions and 30 deletions

View file

@ -54,14 +54,14 @@ func (db *GCDB) Close() error {
func (db *GCDB) PrepareSQL(query string, tx *sql.Tx) (*sql.Stmt, error) {
var prepared string
var err error
if prepared, err = SetupSQLString(query, db); err != nil {
if prepared, err = SetupSQLString(db.replacer.Replace(query), db); err != nil {
return nil, err
}
var stmt *sql.Stmt
if tx != nil {
stmt, err = tx.Prepare(db.replacer.Replace(prepared))
stmt, err = tx.Prepare(prepared)
} else {
stmt, err = db.db.Prepare(db.replacer.Replace(prepared))
stmt, err = db.db.Prepare(prepared)
}
if err != nil {
return stmt, err
@ -221,10 +221,15 @@ func Open(host, dbDriver, dbName, username, password, prefix string) (db *GCDB,
db = &GCDB{
driver: dbDriver,
}
if dbDriver == "mysql" {
db.replacer = strings.NewReplacer(
"DBNAME", dbName,
"DBPREFIX", prefix,
replacerArr := []string{
"DBNAME", dbName,
"DBPREFIX", prefix,
"\n", " ",
}
switch dbDriver {
case "mysql":
db.connStr = fmt.Sprintf(mysqlConnStr, username, password, host, dbName)
replacerArr = append(replacerArr,
"RANGE_START_ATON", "INET6_ATON(range_start)",
"RANGE_START_NTOA", "INET6_NTOA(range_start)",
"RANGE_END_ATON", "INET6_ATON(range_end)",
@ -233,11 +238,10 @@ func Open(host, dbDriver, dbName, username, password, prefix string) (db *GCDB,
"IP_NTOA", "INET6_NTOA(ip)",
"PARAM_ATON", "INET6_ATON(?)",
"PARAM_NTOA", "INET6_NTOA(?)",
"\n", " ")
} else {
db.replacer = strings.NewReplacer(
"DBNAME", dbName,
"DBPREFIX", prefix,
)
case "postgres":
db.connStr = fmt.Sprintf(postgresConnStr, username, password, host, dbName)
replacerArr = append(replacerArr,
"RANGE_START_ATON", "range_start",
"RANGE_START_NTOA", "range_start",
"RANGE_END_ATON", "range_end",
@ -246,26 +250,27 @@ func Open(host, dbDriver, dbName, username, password, prefix string) (db *GCDB,
"IP_NTOA", "ip",
"PARAM_ATON", "?",
"PARAM_NTOA", "?",
"\n", " ")
}
if dbDriver != "sqlite3" {
)
case "sqlite3":
addrMatches := tcpHostIsolator.FindAllStringSubmatch(host, -1)
if len(addrMatches) > 0 && len(addrMatches[0]) > 2 {
host = addrMatches[0][2]
}
}
switch dbDriver {
case "mysql":
db.connStr = fmt.Sprintf(mysqlConnStr, username, password, host, dbName)
case "sqlite3":
db.connStr = fmt.Sprintf(sqlite3ConnStr, host, username, password)
case "postgres":
db.connStr = fmt.Sprintf(postgresConnStr, username, password, host, dbName)
replacerArr = append(replacerArr,
"RANGE_START_ATON", "range_start",
"RANGE_START_NTOA", "range_start",
"RANGE_END_ATON", "range_end",
"RANGE_END_NTOA", "range_end",
"IP_ATON", "ip",
"IP_NTOA", "ip",
"PARAM_ATON", "?",
"PARAM_NTOA", "?",
)
default:
return nil, ErrUnsupportedDB
}
db.replacer = strings.NewReplacer(replacerArr...)
db.db, err = sql.Open(db.driver, db.connStr)
if err != nil {
db.db.SetConnMaxLifetime(time.Minute * 3)

View file

@ -155,6 +155,9 @@ var funcMap = template.FuncMap{
},
"banMask": func(ban gcsql.IPBan) string {
if ban.ID < 1 {
if ban.RangeStart == ban.RangeEnd {
return ban.RangeStart
}
return ""
}
ipn, err := gcutil.GetIPRangeSubnet(ban.RangeStart, ban.RangeEnd)

View file

@ -47,11 +47,11 @@ func showBanpage(ban *gcsql.IPBan, post *gcsql.Post, postBoard *gcsql.Board, wri
func checkIpBan(post *gcsql.Post, postBoard *gcsql.Board, writer http.ResponseWriter, request *http.Request) bool {
ipBan, err := gcsql.CheckIPBan(post.IP, postBoard.ID)
if err != nil {
gcutil.LogError(err).
gcutil.LogError(err).Caller().
Str("IP", post.IP).
Str("boardDir", postBoard.Dir).
Msg("Error getting IP banned status")
server.ServeErrorPage(writer, "Error getting ban info"+err.Error())
server.ServeErrorPage(writer, "Error checking banned status: "+err.Error())
return true
}
if ipBan == nil {
@ -73,7 +73,7 @@ func checkUsernameBan(post *gcsql.Post, postBoard *gcsql.Board, writer http.Resp
nameBan, err := gcsql.CheckNameBan(nameTrip, postBoard.ID)
if err != nil {
gcutil.LogError(err).
gcutil.LogError(err).Caller().
Str("IP", post.IP).
Str("nameTrip", nameTrip).
Str("boardDir", postBoard.Dir).
@ -119,13 +119,12 @@ func handleAppeal(writer http.ResponseWriter, request *http.Request, infoEv *zer
ban, err := gcsql.GetIPBanByID(banID)
if err != nil {
errEv.Err(err).
Caller().Send()
errEv.Err(err).Caller().Send()
server.ServeErrorPage(writer, "Error getting ban info: "+err.Error())
return
}
if ban == nil {
errEv.Caller().Msg("GetIPBanByID returned a nil ban (presumably not banned)")
infoEv.Caller().Msg("GetIPBanByID returned a nil ban (presumably not banned)")
server.ServeErrorPage(writer, fmt.Sprintf("Invalid ban ID %d", banID))
return
}