mirror of
https://github.com/Eggbertx/gochan.git
synced 2025-09-05 11:06:23 -07:00
Fix SQL string replacer messing with prepared statements in Postgres
This commit is contained in:
parent
121959fa15
commit
d294462968
3 changed files with 37 additions and 30 deletions
|
@ -54,14 +54,14 @@ func (db *GCDB) Close() error {
|
|||
func (db *GCDB) PrepareSQL(query string, tx *sql.Tx) (*sql.Stmt, error) {
|
||||
var prepared string
|
||||
var err error
|
||||
if prepared, err = SetupSQLString(query, db); err != nil {
|
||||
if prepared, err = SetupSQLString(db.replacer.Replace(query), db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var stmt *sql.Stmt
|
||||
if tx != nil {
|
||||
stmt, err = tx.Prepare(db.replacer.Replace(prepared))
|
||||
stmt, err = tx.Prepare(prepared)
|
||||
} else {
|
||||
stmt, err = db.db.Prepare(db.replacer.Replace(prepared))
|
||||
stmt, err = db.db.Prepare(prepared)
|
||||
}
|
||||
if err != nil {
|
||||
return stmt, err
|
||||
|
@ -221,10 +221,15 @@ func Open(host, dbDriver, dbName, username, password, prefix string) (db *GCDB,
|
|||
db = &GCDB{
|
||||
driver: dbDriver,
|
||||
}
|
||||
if dbDriver == "mysql" {
|
||||
db.replacer = strings.NewReplacer(
|
||||
"DBNAME", dbName,
|
||||
"DBPREFIX", prefix,
|
||||
replacerArr := []string{
|
||||
"DBNAME", dbName,
|
||||
"DBPREFIX", prefix,
|
||||
"\n", " ",
|
||||
}
|
||||
switch dbDriver {
|
||||
case "mysql":
|
||||
db.connStr = fmt.Sprintf(mysqlConnStr, username, password, host, dbName)
|
||||
replacerArr = append(replacerArr,
|
||||
"RANGE_START_ATON", "INET6_ATON(range_start)",
|
||||
"RANGE_START_NTOA", "INET6_NTOA(range_start)",
|
||||
"RANGE_END_ATON", "INET6_ATON(range_end)",
|
||||
|
@ -233,11 +238,10 @@ func Open(host, dbDriver, dbName, username, password, prefix string) (db *GCDB,
|
|||
"IP_NTOA", "INET6_NTOA(ip)",
|
||||
"PARAM_ATON", "INET6_ATON(?)",
|
||||
"PARAM_NTOA", "INET6_NTOA(?)",
|
||||
"\n", " ")
|
||||
} else {
|
||||
db.replacer = strings.NewReplacer(
|
||||
"DBNAME", dbName,
|
||||
"DBPREFIX", prefix,
|
||||
)
|
||||
case "postgres":
|
||||
db.connStr = fmt.Sprintf(postgresConnStr, username, password, host, dbName)
|
||||
replacerArr = append(replacerArr,
|
||||
"RANGE_START_ATON", "range_start",
|
||||
"RANGE_START_NTOA", "range_start",
|
||||
"RANGE_END_ATON", "range_end",
|
||||
|
@ -246,26 +250,27 @@ func Open(host, dbDriver, dbName, username, password, prefix string) (db *GCDB,
|
|||
"IP_NTOA", "ip",
|
||||
"PARAM_ATON", "?",
|
||||
"PARAM_NTOA", "?",
|
||||
"\n", " ")
|
||||
}
|
||||
|
||||
if dbDriver != "sqlite3" {
|
||||
)
|
||||
case "sqlite3":
|
||||
addrMatches := tcpHostIsolator.FindAllStringSubmatch(host, -1)
|
||||
if len(addrMatches) > 0 && len(addrMatches[0]) > 2 {
|
||||
host = addrMatches[0][2]
|
||||
}
|
||||
}
|
||||
|
||||
switch dbDriver {
|
||||
case "mysql":
|
||||
db.connStr = fmt.Sprintf(mysqlConnStr, username, password, host, dbName)
|
||||
case "sqlite3":
|
||||
db.connStr = fmt.Sprintf(sqlite3ConnStr, host, username, password)
|
||||
case "postgres":
|
||||
db.connStr = fmt.Sprintf(postgresConnStr, username, password, host, dbName)
|
||||
replacerArr = append(replacerArr,
|
||||
"RANGE_START_ATON", "range_start",
|
||||
"RANGE_START_NTOA", "range_start",
|
||||
"RANGE_END_ATON", "range_end",
|
||||
"RANGE_END_NTOA", "range_end",
|
||||
"IP_ATON", "ip",
|
||||
"IP_NTOA", "ip",
|
||||
"PARAM_ATON", "?",
|
||||
"PARAM_NTOA", "?",
|
||||
)
|
||||
default:
|
||||
return nil, ErrUnsupportedDB
|
||||
}
|
||||
db.replacer = strings.NewReplacer(replacerArr...)
|
||||
db.db, err = sql.Open(db.driver, db.connStr)
|
||||
if err != nil {
|
||||
db.db.SetConnMaxLifetime(time.Minute * 3)
|
||||
|
|
|
@ -155,6 +155,9 @@ var funcMap = template.FuncMap{
|
|||
},
|
||||
"banMask": func(ban gcsql.IPBan) string {
|
||||
if ban.ID < 1 {
|
||||
if ban.RangeStart == ban.RangeEnd {
|
||||
return ban.RangeStart
|
||||
}
|
||||
return ""
|
||||
}
|
||||
ipn, err := gcutil.GetIPRangeSubnet(ban.RangeStart, ban.RangeEnd)
|
||||
|
|
|
@ -47,11 +47,11 @@ func showBanpage(ban *gcsql.IPBan, post *gcsql.Post, postBoard *gcsql.Board, wri
|
|||
func checkIpBan(post *gcsql.Post, postBoard *gcsql.Board, writer http.ResponseWriter, request *http.Request) bool {
|
||||
ipBan, err := gcsql.CheckIPBan(post.IP, postBoard.ID)
|
||||
if err != nil {
|
||||
gcutil.LogError(err).
|
||||
gcutil.LogError(err).Caller().
|
||||
Str("IP", post.IP).
|
||||
Str("boardDir", postBoard.Dir).
|
||||
Msg("Error getting IP banned status")
|
||||
server.ServeErrorPage(writer, "Error getting ban info"+err.Error())
|
||||
server.ServeErrorPage(writer, "Error checking banned status: "+err.Error())
|
||||
return true
|
||||
}
|
||||
if ipBan == nil {
|
||||
|
@ -73,7 +73,7 @@ func checkUsernameBan(post *gcsql.Post, postBoard *gcsql.Board, writer http.Resp
|
|||
|
||||
nameBan, err := gcsql.CheckNameBan(nameTrip, postBoard.ID)
|
||||
if err != nil {
|
||||
gcutil.LogError(err).
|
||||
gcutil.LogError(err).Caller().
|
||||
Str("IP", post.IP).
|
||||
Str("nameTrip", nameTrip).
|
||||
Str("boardDir", postBoard.Dir).
|
||||
|
@ -119,13 +119,12 @@ func handleAppeal(writer http.ResponseWriter, request *http.Request, infoEv *zer
|
|||
|
||||
ban, err := gcsql.GetIPBanByID(banID)
|
||||
if err != nil {
|
||||
errEv.Err(err).
|
||||
Caller().Send()
|
||||
errEv.Err(err).Caller().Send()
|
||||
server.ServeErrorPage(writer, "Error getting ban info: "+err.Error())
|
||||
return
|
||||
}
|
||||
if ban == nil {
|
||||
errEv.Caller().Msg("GetIPBanByID returned a nil ban (presumably not banned)")
|
||||
infoEv.Caller().Msg("GetIPBanByID returned a nil ban (presumably not banned)")
|
||||
server.ServeErrorPage(writer, fmt.Sprintf("Invalid ban ID %d", banID))
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue