From 28963eb813e5d43a99d09f72b741b66c13b20467 Mon Sep 17 00:00:00 2001 From: Eggbertx Date: Thu, 28 Dec 2023 00:36:24 -0800 Subject: [PATCH] Add IP Range parsing, start working on adding range bans --- build.py | 2 +- .../internal/gcupdate/gcupdate.go | 51 +++++++++++- pkg/gcsql/provisioning.go | 2 +- pkg/gcsql/tables.go | 2 + pkg/gcutil/util.go | 78 +++++++++++++++++++ sql/initdb_master.sql | 5 +- 6 files changed, 134 insertions(+), 6 deletions(-) diff --git a/build.py b/build.py index c59f9cae..7220c5e5 100755 --- a/build.py +++ b/build.py @@ -39,7 +39,7 @@ release_files = ( ) GOCHAN_VERSION = "3.9.0" -DATABASE_VERSION = "2" # stored in DBNAME.DBPREFIXdatabase_version +DATABASE_VERSION = "3" # stored in DBNAME.DBPREFIXdatabase_version PATH_NOTHING = -1 PATH_UNKNOWN = 0 diff --git a/cmd/gochan-migration/internal/gcupdate/gcupdate.go b/cmd/gochan-migration/internal/gcupdate/gcupdate.go index f684efff..da9c1f9d 100644 --- a/cmd/gochan-migration/internal/gcupdate/gcupdate.go +++ b/cmd/gochan-migration/internal/gcupdate/gcupdate.go @@ -65,7 +65,9 @@ func (dbu *GCDatabaseUpdater) MigrateDB() (bool, error) { if err != nil { return false, err } - defer tx.Rollback() + defer func() { + tx.Rollback() + }() switch criticalConfig.DBtype { case "mysql": @@ -108,7 +110,9 @@ func (dbu *GCDatabaseUpdater) MigrateDB() (bool, error) { if err != nil { return false, err } - defer rows.Close() + defer func() { + rows.Close() + }() for rows.Next() { var tableName string err = rows.Scan(&tableName) @@ -120,6 +124,49 @@ func (dbu *GCDatabaseUpdater) MigrateDB() (bool, error) { return false, err } } + if err = rows.Close(); err != nil { + return false, err + } + query = `SELECT COUNT(*) FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'DBPREFIXip_ban' + AND COLUMN_NAME = 'ip'` + if err = dbu.db.QueryRowTxSQL(tx, query, nil, []any{&numColumns}); err != nil { + return false, err + } + if numColumns > 0 { + // add range_start and range_end columns + query = `ALTER TABLE DBPREFIXip_ban + ADD COLUMN IF NOT EXISTS range_start VARBINARY(16) NOT NULL + ADD COLUMN IF NOT EXISTS range_end VARBINARY(16) NOT NULL` + if _, err = gcsql.ExecTxSQL(tx, query); err != nil { + return false, err + } + // convert string to IP range + if rows, err = dbu.db.QuerySQL(`SELECT id, ip FROM DBPREFIXip_ban`); err != nil { + return false, err + } + var rangeStart string + var rangeEnd string + for rows.Next() { + var id int + var ipOrCIDR string + if err = rows.Scan(&id, &ipOrCIDR); err != nil { + return false, err + } + if rangeStart, rangeEnd, err = gcutil.ParseIPRange(ipOrCIDR); err != nil { + return false, err + } + query = `UPDATE DBPREFIXip_ban SET range_start = INET6_ATON(?), range_end = ? WHERE id = ?` + if _, err = gcsql.ExecTxSQL(tx, query, rangeStart, rangeEnd, id); err != nil { + return false, err + } + query = `ALTER TABLE DBPREFIXip_ban DROP COLUMN ip` + if _, err = gcsql.ExecTxSQL(tx, query); err != nil { + return false, err + } + } + } err = nil case "postgres": _, err = gcsql.ExecSQL(`ALTER TABLE DBPREFIXwordfilters DROP CONSTRAINT IF EXISTS board_id_fk`) diff --git a/pkg/gcsql/provisioning.go b/pkg/gcsql/provisioning.go index 776efd18..da9b507e 100644 --- a/pkg/gcsql/provisioning.go +++ b/pkg/gcsql/provisioning.go @@ -16,7 +16,7 @@ const ( DBUpToDate DBModernButAhead - targetDatabaseVersion = 2 + targetDatabaseVersion = 3 ) var ( diff --git a/pkg/gcsql/tables.go b/pkg/gcsql/tables.go index 32ff83f0..158ebfca 100644 --- a/pkg/gcsql/tables.go +++ b/pkg/gcsql/tables.go @@ -114,6 +114,8 @@ type IPBan struct { BannedForPostID *int CopyPostText template.HTML IP string + IPRangeStart string + IPRangeEnd string IssuedAt time.Time ipBanBase } diff --git a/pkg/gcutil/util.go b/pkg/gcutil/util.go index af71b01f..67fb26b8 100644 --- a/pkg/gcutil/util.go +++ b/pkg/gcutil/util.go @@ -10,6 +10,7 @@ import ( "math/rand" "net" "net/http" + "net/netip" "os" "path" "path/filepath" @@ -157,6 +158,83 @@ func MarshalJSON(data interface{}, indent bool) (string, error) { return string(jsonBytes), err } +// ParseIPRange takes a single IP address or an IP range of the form "networkIP/netmaskbits" and +// gives the starting IP and ending IP in the subnet +// +// More info: https://en.wikipedia.org/wiki/Subnet +func ParseIPRange(ipOrCIDR string) (string, string, error) { + var ipStart netip.Addr + + if strings.ContainsRune(ipOrCIDR, '/') { + var ipEnd netip.Addr + // CIDR range + prefix, err := netip.ParsePrefix(ipOrCIDR) + if err != nil { + return "", "", err + } + ipStart = prefix.Addr() + ipEnd = prefix.Addr() + var tmp netip.Addr + for { + tmp = ipEnd.Next() + if !prefix.Contains(tmp) { + break + } + ipEnd = tmp + } + return ipStart.String(), ipEnd.String(), nil + } + // single IP + var err error + if ipStart, err = netip.ParseAddr(ipOrCIDR); err != nil { + return "", "", err + } + return ipStart.String(), ipStart.String(), nil +} + +// GetIPRangeString returns an IP address if start == end, or the subnet of all IP +// addresses between start and end +func GetIPRangeString(start string, end string) (string, error) { + if start == end { + return start, nil + } + startIP := net.ParseIP(start) + endIP := net.ParseIP(end) + if startIP == nil { + return "", fmt.Errorf("invalid IP address %s", start) + } + if endIP == nil { + return "", fmt.Errorf("invalid IP address %s", end) + } + if len(startIP) != len(endIP) { + return "", errors.New("ip addresses must both be IPv4 or IPv6") + } + + if startIP.To4() != nil { + startIP = startIP.To4() + endIP = endIP.To4() + } + + bits := 0 + var ipn net.IPNet + for b := range startIP { + if startIP[b] == endIP[b] { + bits += 8 + continue + } + for i := 7; i >= 0; i-- { + if startIP[b]&(1<