From 45ba890280a08a6c37944b57193c3a092b004324 Mon Sep 17 00:00:00 2001 From: Eggbertx Date: Sat, 30 Mar 2024 17:25:19 -0700 Subject: [PATCH] Refactor gcsql tests to clean them up a bit --- pkg/gcsql/appeals_test.go | 261 ++++++++++++++++----------------- pkg/gcsql/provisioning_test.go | 6 +- pkg/gcsql/setup_test.go | 18 ++- 3 files changed, 148 insertions(+), 137 deletions(-) diff --git a/pkg/gcsql/appeals_test.go b/pkg/gcsql/appeals_test.go index a1cd0580..a18fbce3 100644 --- a/pkg/gcsql/appeals_test.go +++ b/pkg/gcsql/appeals_test.go @@ -7,24 +7,11 @@ import ( "testing" "github.com/DATA-DOG/go-sqlmock" - "github.com/gochan-org/gochan/pkg/config" "github.com/stretchr/testify/assert" ) -type argsGetAppeals struct { - banID int - limit int -} - -func TestGetAppeals(t *testing.T) { - config.SetVersion("3.10.1") - config.SetRandomSeed("test") - - testCases := []struct { - name string - args argsGetAppeals - expectReturn []IPBanAppeal - }{ +var ( + testCasesGetAppeals = []testCaseGetAppeals{ { name: "single appeal, no results", args: argsGetAppeals{1, 1}, @@ -48,66 +35,28 @@ func TestGetAppeals(t *testing.T) { expectReturn: []IPBanAppeal{{}, {}, {}}, }, } - var mock sqlmock.Sqlmock - var err error - for _, tC := range testCases { - for _, driver := range testingDBDrivers { - t.Run(fmt.Sprintf("%s (%s)", tC.name, driver), func(t *testing.T) { - gcdb, err = setupDBConn("localhost", driver, "gochan", "gochan", "gochan", "") - if !assert.NoError(t, err) { - return - } - gcdb.db, mock, err = sqlmock.New() - if !assert.NoError(t, err) { - return - } - - query := `SELECT id, staff_id, ip_ban_id, appeal_text, staff_response, is_denied FROM ip_ban_appeals` - if tC.args.banID > 0 { - switch driver { - case "mysql": - query += ` WHERE ip_ban_id = \?` - case "sqlite3": - fallthrough - case "postgres": - query += ` WHERE ip_ban_id = \$1` - } - } - if tC.args.limit > 0 { - query += " LIMIT " + strconv.Itoa(tC.args.limit) - } - expectQuery := mock.ExpectPrepare(query).ExpectQuery() - if tC.args.banID > 0 { - expectQuery.WithArgs(tC.args.banID) - } - - expectedRows := sqlmock.NewRows([]string{"id", "staff_id", "ip_ban_id", "appeal_text", "staff_response", "is_denied"}) - if len(tC.expectReturn) > 0 { - for _, expectedBan := range tC.expectReturn { - expectedRows.AddRow( - expectedBan.ID, expectedBan.StaffID, expectedBan.IPBanID, expectedBan.AppealText, - expectedBan.StaffResponse, expectedBan.IsDenied, - ) - } - } - expectQuery.WillReturnRows(expectedRows) - - got, err := GetAppeals(tC.args.banID, tC.args.limit) - if !assert.NoError(t, err) { - return - } - assert.NoError(t, mock.ExpectationsWereMet()) - - assert.LessOrEqual(t, len(got), tC.args.limit) - assert.Equal(t, tC.expectReturn, got) - if tC.args.banID > 0 && tC.expectReturn != nil { - assert.Equal(t, tC.args.banID, tC.expectReturn[0].ID) - } - assert.NoError(t, mock.ExpectationsWereMet()) - closeMock(t, mock) - }) - } + testCasesApproveAppeals = []testCaseApproveAppeals{ + { + name: "approve nonexistent appeal", + args: argsApproveAppeal{1, 1}, + }, } +) + +type testCaseGetAppeals struct { + name string + args argsGetAppeals + expectReturn []IPBanAppeal +} + +type testCaseApproveAppeals struct { + name string + args argsApproveAppeal +} + +type argsGetAppeals struct { + banID int + limit int } type argsApproveAppeal struct { @@ -115,71 +64,119 @@ type argsApproveAppeal struct { staffID int } -func TestApproveAppeal(t *testing.T) { - tests := []struct { - name string - args argsApproveAppeal - expectsAffectedRows bool - }{ - { - name: "approve nonexistent appeal", - args: argsApproveAppeal{1, 1}, - }, +func testRunnerGetAppeals(t *testing.T, tC *testCaseGetAppeals, driver string) { + t.Helper() + mock, err := setupMockDB(t, driver, "gochan") + if !assert.NoError(t, err) { + return } - var mock sqlmock.Sqlmock - var err error - for _, tC := range tests { - for _, sqlDriver := range testingDBDrivers { - t.Run(fmt.Sprintf("%s (%s)", tC.name, sqlDriver), func(t *testing.T) { - gcdb, err = setupDBConn("localhost", sqlDriver, "gochan", "gochan", "gochan", "") - if !assert.NoError(t, err) { - return - } - gcdb.db, mock, err = sqlmock.New() - if !assert.NoError(t, err) { - return - } - deactivateQuery := `UPDATE ip_ban SET is_active = FALSE WHERE id = \(\s+` + - `SELECT ip_ban_id FROM ip_ban_appeals WHERE id = ` - deactivateAppealQuery := `INSERT INTO ip_ban_audit\s*\(\s*ip_ban_id, timestamp, ` + - `staff_id, is_active, is_thread_ban, permanent, staff_note, message, can_appeal\)\s*VALUES\(\(` + - `SELECT ip_ban_id FROM ip_ban_appeals WHERE id = ` - deleteAppealQuery := `DELETE FROM ip_ban_appeals WHERE id = ` + query := `SELECT id, staff_id, ip_ban_id, appeal_text, staff_response, is_denied FROM ip_ban_appeals` + if tC.args.banID > 0 { + switch driver { + case "mysql": + query += ` WHERE ip_ban_id = \?` + case "sqlite3": + fallthrough + case "postgres": + query += ` WHERE ip_ban_id = \$1` + } + } + if tC.args.limit > 0 { + query += " LIMIT " + strconv.Itoa(tC.args.limit) + } + expectQuery := mock.ExpectPrepare(query).ExpectQuery() + if tC.args.banID > 0 { + expectQuery.WithArgs(tC.args.banID) + } - switch sqlDriver { - case "mysql": - deactivateQuery += `\?\)` - deactivateAppealQuery += `\?\),\s*CURRENT_TIMESTAMP, \?, FALSE, FALSE, FALSE, '', '', TRUE\)` - deleteAppealQuery += `\?` - case "sqlite3": - fallthrough - case "postgres": - deactivateQuery += `\$1\)` - deactivateAppealQuery += `\$1\),\s+CURRENT_TIMESTAMP, \$2, FALSE, FALSE, FALSE, '', '', TRUE\)` - deleteAppealQuery += `\$1` - } - mock.ExpectBegin() - mock.ExpectPrepare(deactivateQuery).ExpectExec(). - WithArgs(tC.args.appealID).WillReturnResult(driver.ResultNoRows) + expectedRows := sqlmock.NewRows([]string{"id", "staff_id", "ip_ban_id", "appeal_text", "staff_response", "is_denied"}) + if len(tC.expectReturn) > 0 { + for _, expectedBan := range tC.expectReturn { + expectedRows.AddRow( + expectedBan.ID, expectedBan.StaffID, expectedBan.IPBanID, expectedBan.AppealText, + expectedBan.StaffResponse, expectedBan.IsDenied, + ) + } + } + expectQuery.WillReturnRows(expectedRows) - mock.ExpectPrepare(deactivateAppealQuery).ExpectExec(). - WithArgs(tC.args.appealID, tC.args.staffID). - WillReturnResult(driver.ResultNoRows) + got, err := GetAppeals(tC.args.banID, tC.args.limit) + if !assert.NoError(t, err) { + return + } + assert.NoError(t, mock.ExpectationsWereMet()) - mock.ExpectPrepare(deleteAppealQuery).ExpectExec(). - WithArgs(tC.args.appealID). - WillReturnResult(driver.ResultNoRows) + assert.LessOrEqual(t, len(got), tC.args.limit) + assert.Equal(t, tC.expectReturn, got) + if tC.args.banID > 0 && tC.expectReturn != nil { + assert.Equal(t, tC.args.banID, tC.expectReturn[0].ID) + } + assert.NoError(t, mock.ExpectationsWereMet()) + closeMock(t, mock) - mock.ExpectCommit() +} - if !assert.NoError(t, ApproveAppeal(tC.args.appealID, tC.args.staffID)) { - return - } - if !assert.NoError(t, mock.ExpectationsWereMet()) { - return - } - closeMock(t, mock) +func TestGetAppeals(t *testing.T) { + for _, tC := range testCasesGetAppeals { + for _, driver := range testingDBDrivers { + t.Run(fmt.Sprintf("%s (%s)", tC.name, driver), func(t *testing.T) { + testRunnerGetAppeals(t, &tC, driver) + }) + } + } +} + +func testRunnerApproveAppeal(t *testing.T, tC *testCaseApproveAppeals, sqlDriver string) { + t.Helper() + mock, err := setupMockDB(t, sqlDriver, "gochan") + if !assert.NoError(t, err) { + return + } + + deactivateQuery := `UPDATE ip_ban SET is_active = FALSE WHERE id = \(\s+` + + `SELECT ip_ban_id FROM ip_ban_appeals WHERE id = ` + deactivateAppealQuery := `INSERT INTO ip_ban_audit\s*\(\s*ip_ban_id, timestamp, ` + + `staff_id, is_active, is_thread_ban, permanent, staff_note, message, can_appeal\)\s*VALUES\(\(` + + `SELECT ip_ban_id FROM ip_ban_appeals WHERE id = ` + deleteAppealQuery := `DELETE FROM ip_ban_appeals WHERE id = ` + + switch sqlDriver { + case "mysql": + deactivateQuery += `\?\)` + deactivateAppealQuery += `\?\),\s*CURRENT_TIMESTAMP, \?, FALSE, FALSE, FALSE, '', '', TRUE\)` + deleteAppealQuery += `\?` + case "sqlite3": + fallthrough + case "postgres": + deactivateQuery += `\$1\)` + deactivateAppealQuery += `\$1\),\s+CURRENT_TIMESTAMP, \$2, FALSE, FALSE, FALSE, '', '', TRUE\)` + deleteAppealQuery += `\$1` + } + mock.ExpectBegin() + mock.ExpectPrepare(deactivateQuery).ExpectExec(). + WithArgs(tC.args.appealID).WillReturnResult(driver.ResultNoRows) + + mock.ExpectPrepare(deactivateAppealQuery).ExpectExec(). + WithArgs(tC.args.appealID, tC.args.staffID). + WillReturnResult(driver.ResultNoRows) + + mock.ExpectPrepare(deleteAppealQuery).ExpectExec(). + WithArgs(tC.args.appealID). + WillReturnResult(driver.ResultNoRows) + mock.ExpectCommit() + + assert.NoError(t, ApproveAppeal(tC.args.appealID, tC.args.staffID)) + assert.NoError(t, mock.ExpectationsWereMet()) + + closeMock(t, mock) +} + +func TestApproveAppeal(t *testing.T) { + for _, tC := range testCasesApproveAppeals { + for _, sqlDriver := range testingDBDrivers { + t.Run(fmt.Sprintf("%s (%s)", tC.name, sqlDriver), func(t *testing.T) { + testRunnerApproveAppeal(t, &tC, sqlDriver) }) } } diff --git a/pkg/gcsql/provisioning_test.go b/pkg/gcsql/provisioning_test.go index 18b2476d..296821a4 100644 --- a/pkg/gcsql/provisioning_test.go +++ b/pkg/gcsql/provisioning_test.go @@ -3,7 +3,6 @@ package gcsql import ( "testing" - "github.com/DATA-DOG/go-sqlmock" "github.com/gochan-org/gochan/pkg/config" "github.com/stretchr/testify/assert" ) @@ -29,13 +28,12 @@ func TestProvision(t *testing.T) { return } - var mock sqlmock.Sqlmock - gcdb.db, mock, err = sqlmock.New() + mock, err := setupMockDB(t, driver, "gochan") if !assert.NoError(t, err) { return } - if !assert.NoError(t, setupGochanMockDB(t, mock, "gochan", driver)) { + if !assert.NoError(t, setupAndProvisionMockDB(t, mock, driver, "gochan")) { return } closeMock(t, mock) diff --git a/pkg/gcsql/setup_test.go b/pkg/gcsql/setup_test.go index c9e10a68..b4f2fde1 100644 --- a/pkg/gcsql/setup_test.go +++ b/pkg/gcsql/setup_test.go @@ -1,3 +1,5 @@ +// this source file contains helper functions for gcsql + package gcsql import ( @@ -9,6 +11,7 @@ import ( "testing" "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" ) var ( @@ -111,8 +114,21 @@ func goToGochanRoot(t *testing.T) (string, error) { return dir, errors.New("test running from unexpected dir, should be in gochan root or the current testing dir") } -func setupGochanMockDB(t *testing.T, mock sqlmock.Sqlmock, dbName string, dbType string) error { +func setupMockDB(t *testing.T, dbType string, dbName string) (mock sqlmock.Sqlmock, err error) { + gcdb, err = setupDBConn("localhost", dbType, dbName, "gochan", "gochan", "") + if !assert.NoError(t, err) { + return + } + gcdb.db, mock, err = sqlmock.New() + assert.NoError(t, err) + return +} + +func setupAndProvisionMockDB(t *testing.T, mock sqlmock.Sqlmock, dbType string, dbName string) error { t.Helper() + if gcdb == nil || gcdb.db == nil { + return ErrNotConnected + } mock.ExpectPrepare("CREATE DATABASE " + dbName). ExpectExec().WillReturnResult(driver.ResultNoRows)