diff --git a/internal/loader.go b/internal/loader.go index a89bab8..78a017d 100644 --- a/internal/loader.go +++ b/internal/loader.go @@ -137,6 +137,16 @@ func (tl *TypeLoader) LoadTable(args *ArgType) (map[string]*Type, error) { tableMap[ti.TableName] = typeTpl } + // validate custom type tables + if tl.CustomTypes != nil { + for _, customTable := range tl.CustomTypes.Tables { + _, ok := tableMap[customTable.Name] + if !ok { + return nil, fmt.Errorf("unknown custom type table: %s", customTable.Name) + } + } + } + return tableMap, nil } @@ -197,6 +207,21 @@ func (tl *TypeLoader) LoadColumns(args *ArgType, typeTpl *Type) error { } columnTypes := tl.tableCustomTypes(typeTpl.Table.TableName) + + // validate custom type columns + if columnTypes != nil { + columnSet := map[string]struct{}{} + for _, column := range columnList { + columnSet[column.ColumnName] = struct{}{} + } + + for k, _ := range columnTypes { + if _, ok := columnSet[k]; !ok { + return fmt.Errorf("unknown custom type column %s in the table %s", k, typeTpl.Table.TableName) + } + } + } + // process columns for _, c := range columnList { ignore := false diff --git a/test/testdata/custom_column_types.yml b/test/testdata/custom_column_types.yml index 6f2bc52..a0ad380 100644 --- a/test/testdata/custom_column_types.yml +++ b/test/testdata/custom_column_types.yml @@ -26,5 +26,4 @@ tables: - name: "FullTypes" columns: FTInt: "int32" - FTInitNull: "uint64" FTFloat: "float32" diff --git a/v2/loader/loader.go b/v2/loader/loader.go index 7b0c368..8098544 100644 --- a/v2/loader/loader.go +++ b/v2/loader/loader.go @@ -164,6 +164,14 @@ func (tl *TypeLoader) LoadTable() (map[string]*models.Type, error) { tableMap[ti.TableName] = typeTpl } + // validate custom type tables + for _, customTable := range tl.config.Tables { + _, ok := tableMap[customTable.Name] + if !ok { + return nil, fmt.Errorf("unknown custom type table %s", customTable.Name) + } + } + return tableMap, nil } @@ -225,6 +233,21 @@ func (tl *TypeLoader) LoadColumns(typeTpl *models.Type) error { } columnTypes := tl.tableCustomTypes(typeTpl.TableName) + + // validate custom type columns + if columnTypes != nil { + columnSet := map[string]struct{}{} + for _, column := range columnList { + columnSet[column.ColumnName] = struct{}{} + } + + for k, _ := range columnTypes { + if _, ok := columnSet[k]; !ok { + return fmt.Errorf("unknown custom type column %s in the table %s", k, typeTpl.TableName) + } + } + } + // process columns for _, c := range columnList { ignore := false diff --git a/v2/loader/loader_test.go b/v2/loader/loader_test.go index 92bfbfd..cd43413 100644 --- a/v2/loader/loader_test.go +++ b/v2/loader/loader_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "go.mercari.io/yo/v2/config" "go.mercari.io/yo/v2/internal" "go.mercari.io/yo/v2/models" ) @@ -78,8 +79,6 @@ ALTER TABLE Foo ADD CONSTRAINT FK_CustomerOrder FOREIGN KEY (CustomerID) REFEREN ) func TestLoader(t *testing.T) { - dir := t.TempDir() - table := []struct { name string opt Option @@ -345,51 +344,158 @@ func TestLoader(t *testing.T) { for _, tc := range table { t.Run(tc.name, func(t *testing.T) { - f, err := ioutil.TempFile(dir, "") - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - _, _ = f.Write([]byte(tc.schema)) - _ = f.Close() - path := f.Name() + l := setUpTypeLoader(t, tc.schema, tc.opt) - source, err := NewSchemaParserSource(path) + schema, err := l.LoadSchema() if err != nil { - t.Fatalf("failed to create schema parser source: %v", err) + t.Fatalf("failed to load schema: %v", err) } - inflector, err := internal.NewInflector(nil) - if err != nil { - t.Fatalf("failed to create inflector: %v", err) - } + compareSchemas(t, schema, tc.expectedSchema) + }) + } +} + +func TestLoader_CustomTypes(t *testing.T) { + table := []struct { + name string + opt Option + schema string + expectedSchema *models.Schema + expectedErr string + }{ + { + name: "Custom type table does not exist", + opt: Option{ + Config: &config.Config{ + Tables: []config.Table{ + { + Name: "UnknownTable", + }, + }, + }, + }, + schema: simpleSchema, + expectedErr: "unknown custom type table UnknownTable", + }, + { + name: "Custom type column does not exist", + opt: Option{ + Config: &config.Config{ + Tables: []config.Table{ + { + Name: "Simple", + Columns: []config.Column{ + { + Name: "UnknownColumn", + CustomType: "UnknownCustomColumn", + }, + }, + }, + }, + }, + }, + schema: simpleSchema, + expectedErr: "unknown custom type column UnknownColumn in the table Simple", + }, + { + name: "Success", + opt: Option{ + Config: &config.Config{ + Tables: []config.Table{ + { + Name: "Simple", + Columns: []config.Column{ + { + Name: "Value", + CustomType: "Value", + }, + }, + }, + }, + }, + }, + schema: simpleSchema, + expectedSchema: &models.Schema{ + Types: []*models.Type{ + { + Name: "Simple", + PrimaryKeyFields: []*models.Field{ + {ColumnName: "Id"}, + }, + Fields: []*models.Field{ + { + Name: "ID", + Type: "int64", + OriginalType: "int64", + NullValue: "0", + Len: -1, + ColumnName: "Id", + SpannerDataType: "INT64", + IsNotNull: true, + IsPrimaryKey: true, + }, + { + Name: "Value", + Type: "Value", + OriginalType: "string", + NullValue: `""`, + Len: 32, + ColumnName: "Value", + SpannerDataType: "STRING(32)", + IsNotNull: true, + IsPrimaryKey: false, + }, + }, + TableName: "Simple", + Indexes: []*models.Index{ + { + Name: "SimpleIndex", + FuncName: "SimplesBySimpleIndex", + LegacyFuncName: "SimplesByValue", + Fields: []*models.Field{ + {ColumnName: "Value"}, + }, + IndexName: "SimpleIndex", + }, + { + Name: "SimpleIndex2", + FuncName: "SimpleBySimpleIndex2", + LegacyFuncName: "SimpleByIDValue", + Fields: []*models.Field{ + {ColumnName: "Id"}, + {ColumnName: "Value"}, + }, + IndexName: "SimpleIndex2", + IsUnique: true, + }, + }, + }, + }, + }, + }, + } + + for _, tc := range table { + t.Run(tc.name, func(t *testing.T) { + l := setUpTypeLoader(t, tc.schema, tc.opt) - l := NewTypeLoader(source, inflector, tc.opt) schema, err := l.LoadSchema() - if err != nil { - t.Fatalf("failed to load schema: %v", err) - } - if diff := cmp.Diff(schema, tc.expectedSchema, - cmp.Transformer("FilterInTypePrimaryKeyFields", func(in *models.Type) *models.Type { - if in == nil { - return in - } - for i := range in.PrimaryKeyFields { - f := in.PrimaryKeyFields[i] - in.PrimaryKeyFields[i] = &models.Field{ColumnName: f.ColumnName} - } - return in - }), - cmp.Transformer("FilterInIndexFields", func(in *models.Index) *models.Index { - for i := range in.Fields { - f := in.Fields[i] - in.Fields[i] = &models.Field{ColumnName: f.ColumnName} - } - return in - }), - cmpopts.IgnoreFields(models.Index{}, "Type"), - ); diff != "" { - t.Errorf("(-got, +want)\n%s", diff) + if tc.expectedErr != "" { + if err == nil { + t.Fatal("expected to load schema failure") + } + + if err.Error() != tc.expectedErr { + t.Fatalf("unexpected error: expected: %s, actual: %s", tc.expectedErr, err.Error()) + } + } else { + if err != nil { + t.Fatalf("failed to load schema: %v", err) + } + + compareSchemas(t, schema, tc.expectedSchema) } }) } @@ -506,3 +612,56 @@ func Test_setIndexesToTables(t *testing.T) { }) } } + +func setUpTypeLoader(t *testing.T, schema string, opt Option) *TypeLoader { + t.Helper() + + dir := t.TempDir() + + f, err := ioutil.TempFile(dir, "") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + _, _ = f.Write([]byte(schema)) + _ = f.Close() + path := f.Name() + + source, err := NewSchemaParserSource(path) + if err != nil { + t.Fatalf("failed to create schema parser source: %v", err) + } + + inflector, err := internal.NewInflector(nil) + if err != nil { + t.Fatalf("failed to create inflector: %v", err) + } + + return NewTypeLoader(source, inflector, opt) +} + +func compareSchemas(t *testing.T, actual *models.Schema, expected *models.Schema) { + t.Helper() + + if diff := cmp.Diff(actual, expected, + cmp.Transformer("FilterInTypePrimaryKeyFields", func(in *models.Type) *models.Type { + if in == nil { + return in + } + for i := range in.PrimaryKeyFields { + f := in.PrimaryKeyFields[i] + in.PrimaryKeyFields[i] = &models.Field{ColumnName: f.ColumnName} + } + return in + }), + cmp.Transformer("FilterInIndexFields", func(in *models.Index) *models.Index { + for i := range in.Fields { + f := in.Fields[i] + in.Fields[i] = &models.Field{ColumnName: f.ColumnName} + } + return in + }), + cmpopts.IgnoreFields(models.Index{}, "Type"), + ); diff != "" { + t.Errorf("(-got, +want)\n%s", diff) + } +}