diff --git a/go/vt/mysqlctl/tmutils/schema.go b/go/vt/mysqlctl/tmutils/schema.go index a723e58a17d..41842d40c07 100644 --- a/go/vt/mysqlctl/tmutils/schema.go +++ b/go/vt/mysqlctl/tmutils/schema.go @@ -119,7 +119,7 @@ func NewTableFilter(tables, excludeTables []string, includeViews bool) (*TableFi return nil, fmt.Errorf("cannot compile regexp %v for excludeTable: %v", table, err) } - f.excludeTableREs = append(f.tableREs, re) + f.excludeTableREs = append(f.excludeTableREs, re) } else { f.excludeTableNames = append(f.excludeTableNames, table) } diff --git a/go/vt/mysqlctl/tmutils/schema_test.go b/go/vt/mysqlctl/tmutils/schema_test.go index b355206ff7f..472093cb869 100644 --- a/go/vt/mysqlctl/tmutils/schema_test.go +++ b/go/vt/mysqlctl/tmutils/schema_test.go @@ -19,9 +19,10 @@ package tmutils import ( "errors" "fmt" - "reflect" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" @@ -135,9 +136,7 @@ func TestToSQLStrings(t *testing.T) { for _, tc := range testcases { got := SchemaDefinitionToSQLStrings(tc.input) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("ToSQLStrings() on SchemaDefinition %v returned %v; want %v", tc.input, got, tc.want) - } + assert.Equal(t, tc.want, got) } } @@ -156,12 +155,7 @@ func testDiff(t *testing.T, left, right *tabletmanagerdatapb.SchemaDefinition, l } } } - - if !equal { - t.Logf("Expected: %v", expected) - t.Logf("Actual: %v", actual) - t.Fail() - } + assert.Truef(t, equal, "expected: %v, actual: %v", expected, actual) } func TestSchemaDiff(t *testing.T) { @@ -433,6 +427,26 @@ func TestTableFilter(t *testing.T) { included: false, }, + { + desc: "exclude table list does not list table", + excludeTables: []string{"nomatch1", "nomatch2", "/nomatch3/", "/nomatch4/", "/nomatch5/"}, + includeViews: true, + + tableName: excludedTable, + tableType: TableBaseTable, + + included: true, + }, + { + desc: "exclude table list with re match", + excludeTables: []string{"nomatch1", "nomatch2", "/nomatch3/", "/" + excludedTable + "/", "/nomatch5/"}, + includeViews: true, + + tableName: excludedTable, + tableType: TableBaseTable, + + included: false, + }, { desc: "bad table regexp", tables: []string{"/*/"}, @@ -450,18 +464,16 @@ func TestTableFilter(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { f, err := NewTableFilter(tc.tables, tc.excludeTables, tc.includeViews) - if tc.hasErr != (err != nil) { - t.Fatalf("hasErr not right: %v, tc: %+v", err, tc) - } - if tc.hasErr { + assert.Error(t, err) return } + assert.NoError(t, err) + assert.Equal(t, len(tc.tables), len(f.tableNames)+len(f.tableREs)) + assert.Equal(t, len(tc.excludeTables), len(f.excludeTableNames)+len(f.excludeTableREs)) included := f.Includes(tc.tableName, tc.tableType) - if tc.included != included { - t.Fatalf("included is not right: %v\nfilter: %+v\ntc: %+v", included, f, tc) - } + assert.Equalf(t, tc.included, included, "filter: %v", f) }) } } @@ -638,21 +650,15 @@ func TestFilterTables(t *testing.T) { } for _, tc := range testcases { - got, err := FilterTables(tc.input, tc.tables, tc.excludeTables, tc.includeViews) - if tc.wantError != nil { - if err == nil { - t.Fatalf("FilterTables() test '%v' on SchemaDefinition %v did not return an error (result: %v), but should have, wantError %v", tc.desc, tc.input, got, tc.wantError) - } - if err.Error() != tc.wantError.Error() { - t.Errorf("FilterTables() test '%v' on SchemaDefinition %v returned wrong error '%v'; wanted error '%v'", tc.desc, tc.input, err, tc.wantError) - } - } else { - if err != nil { - t.Errorf("FilterTables() test '%v' on SchemaDefinition %v failed with error %v, want %v", tc.desc, tc.input, err, tc.want) - } - if !proto.Equal(got, tc.want) { - t.Errorf("FilterTables() test '%v' on SchemaDefinition %v returned %v; want %v", tc.desc, tc.input, got, tc.want) + t.Run(tc.desc, func(t *testing.T) { + got, err := FilterTables(tc.input, tc.tables, tc.excludeTables, tc.includeViews) + if tc.wantError != nil { + require.Error(t, err) + require.Equal(t, tc.wantError, err) + } else { + assert.NoError(t, err) + assert.Truef(t, proto.Equal(tc.want, got), "wanted: %v, got: %v", tc.want, got) } - } + }) } }