Skip to content

Commit

Permalink
test: remove private api usages in tests (#221)
Browse files Browse the repository at this point in the history
refactor tests, so private interfaces are not used:
* split batches in tests
* remove batchSplitter interface, which is not used after refactor
* replace quoteKeyword with simple substitution. it is not needed as
  table names in tests are not problematic

We need this for a lib/tests module split
  • Loading branch information
slsyy authored Oct 19, 2024
1 parent ef93f5c commit 2e229a2
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 120 deletions.
6 changes: 0 additions & 6 deletions clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,6 @@ func (h *clickhouse) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction
return tx.Commit()
}

// splitter is a batchSplitter interface implementation. We need it for
// ClickHouseDB because clickhouse doesn't support multi-statements.
func (*clickhouse) splitter() []byte {
return []byte(";\n")
}

func (h *clickhouse) cleanTableQuery(tableName string) string {
if h.cleanTableFn == nil {
return h.baseHelper.cleanTableQuery(tableName)
Expand Down
10 changes: 3 additions & 7 deletions clickhouse_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//go:build clickhouse
// +build clickhouse

package testfixtures

Expand All @@ -11,10 +10,7 @@ import (
)

func TestClickhouse(t *testing.T) {
testLoader(
t,
"clickhouse",
os.Getenv("CLICKHOUSE_CONN_STRING"),
"testdata/schema/clickhouse.sql",
)
db := openDB(t, "clickhouse", os.Getenv("CLICKHOUSE_CONN_STRING"))
loadSchemaInBatchesBySplitter(t, db, "testdata/schema/clickhouse.sql", []byte(";\n"))
testLoader(t, db, "clickhouse")
}
6 changes: 3 additions & 3 deletions cockroachdb_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//go:build cockroachdb
// +build cockroachdb

package testfixtures

Expand All @@ -13,11 +12,12 @@ import (

func TestCockroachDB(t *testing.T) {
for _, dialect := range []string{"postgres", "pgx"} {
db := openDB(t, dialect, os.Getenv("CRDB_CONN_STRING"))
loadSchemaInOneQuery(t, db, "testdata/schema/cockroachdb.sql")
testLoader(
t,
db,
dialect,
os.Getenv("CRDB_CONN_STRING"),
"testdata/schema/cockroachdb.sql",
DangerousSkipTestDatabaseCheck(),
UseDropConstraint(),
)
Expand Down
10 changes: 0 additions & 10 deletions helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@ type queryable interface {
QueryRow(string, ...interface{}) *sql.Row
}

// batchSplitter is an interface with method which returns byte slice for
// splitting SQL batches. This need to split sql statements and run its
// separately.
//
// For Microsoft SQL Server batch splitter is "GO". For details see
// https://docs.microsoft.com/en-us/sql/t-sql/language-elements/sql-server-utilities-statements-go
type batchSplitter interface { //nolint
splitter() []byte
}

var (
_ helper = &clickhouse{}
_ helper = &spanner{}
Expand Down
11 changes: 4 additions & 7 deletions mysql_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build mysql
//go:build mysql

package testfixtures

Expand All @@ -10,10 +10,7 @@ import (
)

func TestMySQL(t *testing.T) {
testLoader(
t,
"mysql",
os.Getenv("MYSQL_CONN_STRING"),
"testdata/schema/mysql.sql",
)
db := openDB(t, "mysql", os.Getenv("MYSQL_CONN_STRING"))
loadSchemaInOneQuery(t, db, "testdata/schema/mysql.sql")
testLoader(t, db, "mysql")
}
34 changes: 13 additions & 21 deletions postgresql_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build postgresql
//go:build postgresql

package testfixtures

Expand All @@ -11,36 +11,28 @@ import (
)

func TestPostgreSQL(t *testing.T) {
for _, dialect := range []string{"postgres", "pgx"} {
testLoader(
t,
dialect,
os.Getenv("PG_CONN_STRING"),
"testdata/schema/postgresql.sql",
)
}
testPostgreSQL(t)
}

func TestPostgreSQLWithAlterConstraint(t *testing.T) {
for _, dialect := range []string{"postgres", "pgx"} {
testLoader(
t,
dialect,
os.Getenv("PG_CONN_STRING"),
"testdata/schema/postgresql.sql",
UseAlterConstraint(),
)
}
testPostgreSQL(t, UseAlterConstraint())
}

func TestPostgreSQLWithDropConstraint(t *testing.T) {
testPostgreSQL(t, UseDropConstraint())
}

func testPostgreSQL(t *testing.T, additionalOptions ...func(*Loader) error) {
t.Helper()
for _, dialect := range []string{"postgres", "pgx"} {
db := openDB(t, dialect, os.Getenv("PG_CONN_STRING"))
loadSchemaInOneQuery(t, db, "testdata/schema/postgresql.sql")
testLoader(
t,
db,
dialect,
os.Getenv("PG_CONN_STRING"),
"testdata/schema/postgresql.sql",
UseDropConstraint(),
additionalOptions...,
)
}

}
6 changes: 0 additions & 6 deletions spanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ func (h *spanner) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (
return h.dropAndRecreateConstraints(db, loadFn)
}

// splitter is a batchSplitter interface implementation. We need it for
// spanner because spanner doesn't support multi-statements.
func (*spanner) splitter() []byte {
return []byte(";\n")
}

func (h *spanner) cleanTableQuery(tableName string) string {
if h.cleanTableFn == nil {
return h.baseHelper.cleanTableQuery(tableName)
Expand Down
10 changes: 3 additions & 7 deletions spanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,9 @@ import (
func TestSpanner(t *testing.T) {
prepareSpannerDB(t)

testLoader(
t,
"spanner",
os.Getenv("SPANNER_CONN_STRING"),
"testdata/schema/spanner.sql",
DangerousSkipTestDatabaseCheck(),
)
db := openDB(t, "spanner", os.Getenv("SPANNER_CONN_STRING"))
loadSchemaInBatchesBySplitter(t, db, "testdata/schema/spanner.sql", []byte(";\n"))
testLoader(t, db, "spanner", DangerousSkipTestDatabaseCheck())
}

func prepareSpannerDB(t *testing.T) {
Expand Down
11 changes: 4 additions & 7 deletions sqlite_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build sqlite
//go:build sqlite

package testfixtures

Expand All @@ -10,10 +10,7 @@ import (
)

func TestSQLite(t *testing.T) {
testLoader(
t,
"sqlite3",
os.Getenv("SQLITE_CONN_STRING"),
"testdata/schema/sqlite.sql",
)
db := openDB(t, "sqlite3", os.Getenv("SQLITE_CONN_STRING"))
loadSchemaInOneQuery(t, db, "testdata/schema/sqlite.sql")
testLoader(t, db, "sqlite3")
}
8 changes: 0 additions & 8 deletions sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,3 @@ func (h *sqlserver) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction)

return tx.Commit()
}

// splitter is a batchSplitter interface implementation. We need it for
// SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...`
// could not be executed in the same batch.
// See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches
func (*sqlserver) splitter() []byte {
return []byte("GO\n")
}
22 changes: 11 additions & 11 deletions sqlserver_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build sqlserver
//go:build sqlserver

package testfixtures

Expand All @@ -10,21 +10,21 @@ import (
)

func TestSQLServer(t *testing.T) {
testLoader(
t,
"sqlserver",
os.Getenv("SQLSERVER_CONN_STRING"),
"testdata/schema/sqlserver.sql",
DangerousSkipTestDatabaseCheck(),
)
testSQLServer(t, "sqlserver")
}

func TestDeprecatedMssql(t *testing.T) {
testSQLServer(t, "mssql")
}

func testSQLServer(t *testing.T, dialect string) {
t.Helper()
db := openDB(t, dialect, os.Getenv("SQLSERVER_CONN_STRING"))
loadSchemaInBatchesBySplitter(t, db, "testdata/schema/sqlserver.sql", []byte("GO\n"))
testLoader(
t,
"mssql",
os.Getenv("SQLSERVER_CONN_STRING"),
"testdata/schema/sqlserver.sql",
db,
dialect,
DangerousSkipTestDatabaseCheck(),
)
}
61 changes: 34 additions & 27 deletions testfixtures_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"database/sql"
"embed"
"errors"
"fmt"
"os"
"testing"
Expand All @@ -26,64 +27,70 @@ func TestFixtureFile(t *testing.T) {
func TestRequiredOptions(t *testing.T) {
t.Run("DatabaseIsRequired", func(t *testing.T) {
_, err := New()
if err != errDatabaseIsRequired {
if !errors.Is(err, errDatabaseIsRequired) {
t.Error("should return an error if database if not given")
}
})

t.Run("DialectIsRequired", func(t *testing.T) {
_, err := New(Database(&sql.DB{}))
if err != errDialectIsRequired {
if !errors.Is(err, errDialectIsRequired) {
t.Error("should return an error if dialect if not given")
}
})
}

func testLoader(t *testing.T, dialect, connStr, schemaFilePath string, additionalOptions ...func(*Loader) error) { //nolint
func openDB(t *testing.T, dialect, connStr string) *sql.DB { //nolint:unused
t.Helper()
db, err := sql.Open(dialect, connStr)
if err != nil {
t.Errorf("failed to open database: %v", err)
return
}
defer db.Close()
t.Cleanup(func() {
_ = db.Close()
})

if err := db.Ping(); err != nil {
t.Errorf("failed to connect to database: %v", err)
return
}
return db
}

func loadSchemaInOneQuery(t *testing.T, db *sql.DB, schemaFilePath string) { //nolint:unused
t.Helper()
schema, err := os.ReadFile(schemaFilePath)
if err != nil {
t.Errorf("cannot read schema file: %v", err)
return
}
helper, err := helperForDialect(dialect)
loadSchemaInBatches(t, db, [][]byte{schema})
}

func loadSchemaInBatchesBySplitter(t *testing.T, db *sql.DB, schemaFilePath string, splitter []byte) { //nolint:unused
t.Helper()
schema, err := os.ReadFile(schemaFilePath)
if err != nil {
t.Errorf("cannot get helper: %v", err)
return
}
if err := helper.init(db); err != nil {
t.Errorf("cannot init helper: %v", err)
t.Errorf("cannot read schema file: %v", err)
return
}
batches := bytes.Split(schema, splitter)
loadSchemaInBatches(t, db, batches)
}

var batches [][]byte
if h, ok := helper.(batchSplitter); ok {
batches = append(batches, bytes.Split(schema, h.splitter())...)
} else {
batches = append(batches, schema)
}

func loadSchemaInBatches(t *testing.T, db *sql.DB, batches [][]byte) { //nolint:unused
t.Helper()
for _, b := range batches {
if len(b) == 0 {
continue
}
if _, err = db.Exec(string(b)); err != nil {
if _, err := db.Exec(string(b)); err != nil {
t.Errorf("cannot load schema: %v", err)
return
}
}
}

func testLoader(t *testing.T, db *sql.DB, dialect string, additionalOptions ...func(*Loader) error) { //nolint:unused
t.Run("LoadFromDirectory", func(t *testing.T) {
options := append(
[]func(*Loader) error{
Expand Down Expand Up @@ -524,18 +531,18 @@ func testLoader(t *testing.T, dialect, connStr, schemaFilePath string, additiona
// sequence issues.

var sql string
switch helper.paramType() {
case paramTypeDollar:
switch dialect {
case "postgres", "pgx", "clickhouse":
sql = "INSERT INTO posts (title, content, created_at, updated_at) VALUES ($1, $2, $3, $4)"
case paramTypeQuestion:
case "mysql", "sqlite3", "mssql":
sql = "INSERT INTO posts (title, content, created_at, updated_at) VALUES (?, ?, ?, ?)"
case paramTypeAtSign:
case "sqlserver", "spanner":
sql = "INSERT INTO posts (title, content, created_at, updated_at) VALUES (@p1, @p2, @p3, @p4)"
default:
panic("unrecognized param type")
t.Fatalf("undefined param type for %s dialect, modify switch statement", dialect)
}

_, err = db.Exec(sql, "Post title", "Post content", time.Now(), time.Now())
_, err := db.Exec(sql, "Post title", "Post content", time.Now(), time.Now())
if err != nil {
t.Errorf("cannot insert post: %v", err)
}
Expand All @@ -553,7 +560,7 @@ func assertFixturesLoaded(t *testing.T, l *Loader) { //nolint

func assertCount(t *testing.T, l *Loader, table string, expectedCount int) { //nolint
count := 0
sql := fmt.Sprintf("SELECT COUNT(*) FROM %s", l.helper.quoteKeyword(table))
sql := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)

row := l.db.QueryRow(sql)
if err := row.Scan(&count); err != nil {
Expand Down

0 comments on commit 2e229a2

Please sign in to comment.