mirror of
https://github.com/Eggbertx/gochan.git
synced 2025-09-05 11:06:23 -07:00
Refactor gcsql tests to clean them up a bit
This commit is contained in:
parent
521296ed29
commit
45ba890280
3 changed files with 148 additions and 137 deletions
|
@ -7,24 +7,11 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gochan-org/gochan/pkg/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type argsGetAppeals struct {
|
||||
banID int
|
||||
limit int
|
||||
}
|
||||
|
||||
func TestGetAppeals(t *testing.T) {
|
||||
config.SetVersion("3.10.1")
|
||||
config.SetRandomSeed("test")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
args argsGetAppeals
|
||||
expectReturn []IPBanAppeal
|
||||
}{
|
||||
var (
|
||||
testCasesGetAppeals = []testCaseGetAppeals{
|
||||
{
|
||||
name: "single appeal, no results",
|
||||
args: argsGetAppeals{1, 1},
|
||||
|
@ -48,66 +35,28 @@ func TestGetAppeals(t *testing.T) {
|
|||
expectReturn: []IPBanAppeal{{}, {}, {}},
|
||||
},
|
||||
}
|
||||
var mock sqlmock.Sqlmock
|
||||
var err error
|
||||
for _, tC := range testCases {
|
||||
for _, driver := range testingDBDrivers {
|
||||
t.Run(fmt.Sprintf("%s (%s)", tC.name, driver), func(t *testing.T) {
|
||||
gcdb, err = setupDBConn("localhost", driver, "gochan", "gochan", "gochan", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
gcdb.db, mock, err = sqlmock.New()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
query := `SELECT id, staff_id, ip_ban_id, appeal_text, staff_response, is_denied FROM ip_ban_appeals`
|
||||
if tC.args.banID > 0 {
|
||||
switch driver {
|
||||
case "mysql":
|
||||
query += ` WHERE ip_ban_id = \?`
|
||||
case "sqlite3":
|
||||
fallthrough
|
||||
case "postgres":
|
||||
query += ` WHERE ip_ban_id = \$1`
|
||||
}
|
||||
}
|
||||
if tC.args.limit > 0 {
|
||||
query += " LIMIT " + strconv.Itoa(tC.args.limit)
|
||||
}
|
||||
expectQuery := mock.ExpectPrepare(query).ExpectQuery()
|
||||
if tC.args.banID > 0 {
|
||||
expectQuery.WithArgs(tC.args.banID)
|
||||
}
|
||||
|
||||
expectedRows := sqlmock.NewRows([]string{"id", "staff_id", "ip_ban_id", "appeal_text", "staff_response", "is_denied"})
|
||||
if len(tC.expectReturn) > 0 {
|
||||
for _, expectedBan := range tC.expectReturn {
|
||||
expectedRows.AddRow(
|
||||
expectedBan.ID, expectedBan.StaffID, expectedBan.IPBanID, expectedBan.AppealText,
|
||||
expectedBan.StaffResponse, expectedBan.IsDenied,
|
||||
)
|
||||
}
|
||||
}
|
||||
expectQuery.WillReturnRows(expectedRows)
|
||||
|
||||
got, err := GetAppeals(tC.args.banID, tC.args.limit)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
|
||||
assert.LessOrEqual(t, len(got), tC.args.limit)
|
||||
assert.Equal(t, tC.expectReturn, got)
|
||||
if tC.args.banID > 0 && tC.expectReturn != nil {
|
||||
assert.Equal(t, tC.args.banID, tC.expectReturn[0].ID)
|
||||
}
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
closeMock(t, mock)
|
||||
})
|
||||
}
|
||||
testCasesApproveAppeals = []testCaseApproveAppeals{
|
||||
{
|
||||
name: "approve nonexistent appeal",
|
||||
args: argsApproveAppeal{1, 1},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type testCaseGetAppeals struct {
|
||||
name string
|
||||
args argsGetAppeals
|
||||
expectReturn []IPBanAppeal
|
||||
}
|
||||
|
||||
type testCaseApproveAppeals struct {
|
||||
name string
|
||||
args argsApproveAppeal
|
||||
}
|
||||
|
||||
type argsGetAppeals struct {
|
||||
banID int
|
||||
limit int
|
||||
}
|
||||
|
||||
type argsApproveAppeal struct {
|
||||
|
@ -115,71 +64,119 @@ type argsApproveAppeal struct {
|
|||
staffID int
|
||||
}
|
||||
|
||||
func TestApproveAppeal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args argsApproveAppeal
|
||||
expectsAffectedRows bool
|
||||
}{
|
||||
{
|
||||
name: "approve nonexistent appeal",
|
||||
args: argsApproveAppeal{1, 1},
|
||||
},
|
||||
func testRunnerGetAppeals(t *testing.T, tC *testCaseGetAppeals, driver string) {
|
||||
t.Helper()
|
||||
mock, err := setupMockDB(t, driver, "gochan")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
var mock sqlmock.Sqlmock
|
||||
var err error
|
||||
for _, tC := range tests {
|
||||
for _, sqlDriver := range testingDBDrivers {
|
||||
t.Run(fmt.Sprintf("%s (%s)", tC.name, sqlDriver), func(t *testing.T) {
|
||||
gcdb, err = setupDBConn("localhost", sqlDriver, "gochan", "gochan", "gochan", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
gcdb.db, mock, err = sqlmock.New()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
deactivateQuery := `UPDATE ip_ban SET is_active = FALSE WHERE id = \(\s+` +
|
||||
`SELECT ip_ban_id FROM ip_ban_appeals WHERE id = `
|
||||
deactivateAppealQuery := `INSERT INTO ip_ban_audit\s*\(\s*ip_ban_id, timestamp, ` +
|
||||
`staff_id, is_active, is_thread_ban, permanent, staff_note, message, can_appeal\)\s*VALUES\(\(` +
|
||||
`SELECT ip_ban_id FROM ip_ban_appeals WHERE id = `
|
||||
deleteAppealQuery := `DELETE FROM ip_ban_appeals WHERE id = `
|
||||
query := `SELECT id, staff_id, ip_ban_id, appeal_text, staff_response, is_denied FROM ip_ban_appeals`
|
||||
if tC.args.banID > 0 {
|
||||
switch driver {
|
||||
case "mysql":
|
||||
query += ` WHERE ip_ban_id = \?`
|
||||
case "sqlite3":
|
||||
fallthrough
|
||||
case "postgres":
|
||||
query += ` WHERE ip_ban_id = \$1`
|
||||
}
|
||||
}
|
||||
if tC.args.limit > 0 {
|
||||
query += " LIMIT " + strconv.Itoa(tC.args.limit)
|
||||
}
|
||||
expectQuery := mock.ExpectPrepare(query).ExpectQuery()
|
||||
if tC.args.banID > 0 {
|
||||
expectQuery.WithArgs(tC.args.banID)
|
||||
}
|
||||
|
||||
switch sqlDriver {
|
||||
case "mysql":
|
||||
deactivateQuery += `\?\)`
|
||||
deactivateAppealQuery += `\?\),\s*CURRENT_TIMESTAMP, \?, FALSE, FALSE, FALSE, '', '', TRUE\)`
|
||||
deleteAppealQuery += `\?`
|
||||
case "sqlite3":
|
||||
fallthrough
|
||||
case "postgres":
|
||||
deactivateQuery += `\$1\)`
|
||||
deactivateAppealQuery += `\$1\),\s+CURRENT_TIMESTAMP, \$2, FALSE, FALSE, FALSE, '', '', TRUE\)`
|
||||
deleteAppealQuery += `\$1`
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare(deactivateQuery).ExpectExec().
|
||||
WithArgs(tC.args.appealID).WillReturnResult(driver.ResultNoRows)
|
||||
expectedRows := sqlmock.NewRows([]string{"id", "staff_id", "ip_ban_id", "appeal_text", "staff_response", "is_denied"})
|
||||
if len(tC.expectReturn) > 0 {
|
||||
for _, expectedBan := range tC.expectReturn {
|
||||
expectedRows.AddRow(
|
||||
expectedBan.ID, expectedBan.StaffID, expectedBan.IPBanID, expectedBan.AppealText,
|
||||
expectedBan.StaffResponse, expectedBan.IsDenied,
|
||||
)
|
||||
}
|
||||
}
|
||||
expectQuery.WillReturnRows(expectedRows)
|
||||
|
||||
mock.ExpectPrepare(deactivateAppealQuery).ExpectExec().
|
||||
WithArgs(tC.args.appealID, tC.args.staffID).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
got, err := GetAppeals(tC.args.banID, tC.args.limit)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
|
||||
mock.ExpectPrepare(deleteAppealQuery).ExpectExec().
|
||||
WithArgs(tC.args.appealID).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
assert.LessOrEqual(t, len(got), tC.args.limit)
|
||||
assert.Equal(t, tC.expectReturn, got)
|
||||
if tC.args.banID > 0 && tC.expectReturn != nil {
|
||||
assert.Equal(t, tC.args.banID, tC.expectReturn[0].ID)
|
||||
}
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
closeMock(t, mock)
|
||||
|
||||
mock.ExpectCommit()
|
||||
}
|
||||
|
||||
if !assert.NoError(t, ApproveAppeal(tC.args.appealID, tC.args.staffID)) {
|
||||
return
|
||||
}
|
||||
if !assert.NoError(t, mock.ExpectationsWereMet()) {
|
||||
return
|
||||
}
|
||||
closeMock(t, mock)
|
||||
func TestGetAppeals(t *testing.T) {
|
||||
for _, tC := range testCasesGetAppeals {
|
||||
for _, driver := range testingDBDrivers {
|
||||
t.Run(fmt.Sprintf("%s (%s)", tC.name, driver), func(t *testing.T) {
|
||||
testRunnerGetAppeals(t, &tC, driver)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testRunnerApproveAppeal(t *testing.T, tC *testCaseApproveAppeals, sqlDriver string) {
|
||||
t.Helper()
|
||||
mock, err := setupMockDB(t, sqlDriver, "gochan")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
deactivateQuery := `UPDATE ip_ban SET is_active = FALSE WHERE id = \(\s+` +
|
||||
`SELECT ip_ban_id FROM ip_ban_appeals WHERE id = `
|
||||
deactivateAppealQuery := `INSERT INTO ip_ban_audit\s*\(\s*ip_ban_id, timestamp, ` +
|
||||
`staff_id, is_active, is_thread_ban, permanent, staff_note, message, can_appeal\)\s*VALUES\(\(` +
|
||||
`SELECT ip_ban_id FROM ip_ban_appeals WHERE id = `
|
||||
deleteAppealQuery := `DELETE FROM ip_ban_appeals WHERE id = `
|
||||
|
||||
switch sqlDriver {
|
||||
case "mysql":
|
||||
deactivateQuery += `\?\)`
|
||||
deactivateAppealQuery += `\?\),\s*CURRENT_TIMESTAMP, \?, FALSE, FALSE, FALSE, '', '', TRUE\)`
|
||||
deleteAppealQuery += `\?`
|
||||
case "sqlite3":
|
||||
fallthrough
|
||||
case "postgres":
|
||||
deactivateQuery += `\$1\)`
|
||||
deactivateAppealQuery += `\$1\),\s+CURRENT_TIMESTAMP, \$2, FALSE, FALSE, FALSE, '', '', TRUE\)`
|
||||
deleteAppealQuery += `\$1`
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare(deactivateQuery).ExpectExec().
|
||||
WithArgs(tC.args.appealID).WillReturnResult(driver.ResultNoRows)
|
||||
|
||||
mock.ExpectPrepare(deactivateAppealQuery).ExpectExec().
|
||||
WithArgs(tC.args.appealID, tC.args.staffID).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
|
||||
mock.ExpectPrepare(deleteAppealQuery).ExpectExec().
|
||||
WithArgs(tC.args.appealID).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
mock.ExpectCommit()
|
||||
|
||||
assert.NoError(t, ApproveAppeal(tC.args.appealID, tC.args.staffID))
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
|
||||
closeMock(t, mock)
|
||||
}
|
||||
|
||||
func TestApproveAppeal(t *testing.T) {
|
||||
for _, tC := range testCasesApproveAppeals {
|
||||
for _, sqlDriver := range testingDBDrivers {
|
||||
t.Run(fmt.Sprintf("%s (%s)", tC.name, sqlDriver), func(t *testing.T) {
|
||||
testRunnerApproveAppeal(t, &tC, sqlDriver)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package gcsql
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gochan-org/gochan/pkg/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -29,13 +28,12 @@ func TestProvision(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
gcdb.db, mock, err = sqlmock.New()
|
||||
mock, err := setupMockDB(t, driver, "gochan")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
if !assert.NoError(t, setupGochanMockDB(t, mock, "gochan", driver)) {
|
||||
if !assert.NoError(t, setupAndProvisionMockDB(t, mock, driver, "gochan")) {
|
||||
return
|
||||
}
|
||||
closeMock(t, mock)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// this source file contains helper functions for gcsql
|
||||
|
||||
package gcsql
|
||||
|
||||
import (
|
||||
|
@ -9,6 +11,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -111,8 +114,21 @@ func goToGochanRoot(t *testing.T) (string, error) {
|
|||
return dir, errors.New("test running from unexpected dir, should be in gochan root or the current testing dir")
|
||||
}
|
||||
|
||||
func setupGochanMockDB(t *testing.T, mock sqlmock.Sqlmock, dbName string, dbType string) error {
|
||||
func setupMockDB(t *testing.T, dbType string, dbName string) (mock sqlmock.Sqlmock, err error) {
|
||||
gcdb, err = setupDBConn("localhost", dbType, dbName, "gochan", "gochan", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
gcdb.db, mock, err = sqlmock.New()
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
func setupAndProvisionMockDB(t *testing.T, mock sqlmock.Sqlmock, dbType string, dbName string) error {
|
||||
t.Helper()
|
||||
if gcdb == nil || gcdb.db == nil {
|
||||
return ErrNotConnected
|
||||
}
|
||||
|
||||
mock.ExpectPrepare("CREATE DATABASE " + dbName).
|
||||
ExpectExec().WillReturnResult(driver.ResultNoRows)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue