Skip to content

Commit

Permalink
Merge pull request #84 from shuheiktgw/custom_type_validation
Browse files Browse the repository at this point in the history
Fails if invalid custom types are specified
  • Loading branch information
kazegusuri authored Jan 12, 2022
2 parents 74a022c + 3d86460 commit f04de87
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 41 deletions.
25 changes: 25 additions & 0 deletions internal/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ func (tl *TypeLoader) LoadTable(args *ArgType) (map[string]*Type, error) {
tableMap[ti.TableName] = typeTpl
}

// validate custom type tables
if tl.CustomTypes != nil {
for _, customTable := range tl.CustomTypes.Tables {
_, ok := tableMap[customTable.Name]
if !ok {
return nil, fmt.Errorf("unknown custom type table: %s", customTable.Name)
}
}
}

return tableMap, nil
}

Expand Down Expand Up @@ -197,6 +207,21 @@ func (tl *TypeLoader) LoadColumns(args *ArgType, typeTpl *Type) error {
}

columnTypes := tl.tableCustomTypes(typeTpl.Table.TableName)

// validate custom type columns
if columnTypes != nil {
columnSet := map[string]struct{}{}
for _, column := range columnList {
columnSet[column.ColumnName] = struct{}{}
}

for k, _ := range columnTypes {
if _, ok := columnSet[k]; !ok {
return fmt.Errorf("unknown custom type column %s in the table %s", k, typeTpl.Table.TableName)
}
}
}

// process columns
for _, c := range columnList {
ignore := false
Expand Down
1 change: 0 additions & 1 deletion test/testdata/custom_column_types.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@ tables:
- name: "FullTypes"
columns:
FTInt: "int32"
FTInitNull: "uint64"
FTFloat: "float32"
23 changes: 23 additions & 0 deletions v2/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ func (tl *TypeLoader) LoadTable() (map[string]*models.Type, error) {
tableMap[ti.TableName] = typeTpl
}

// validate custom type tables
for _, customTable := range tl.config.Tables {
_, ok := tableMap[customTable.Name]
if !ok {
return nil, fmt.Errorf("unknown custom type table %s", customTable.Name)
}
}

return tableMap, nil
}

Expand Down Expand Up @@ -225,6 +233,21 @@ func (tl *TypeLoader) LoadColumns(typeTpl *models.Type) error {
}

columnTypes := tl.tableCustomTypes(typeTpl.TableName)

// validate custom type columns
if columnTypes != nil {
columnSet := map[string]struct{}{}
for _, column := range columnList {
columnSet[column.ColumnName] = struct{}{}
}

for k, _ := range columnTypes {
if _, ok := columnSet[k]; !ok {
return fmt.Errorf("unknown custom type column %s in the table %s", k, typeTpl.TableName)
}
}
}

// process columns
for _, c := range columnList {
ignore := false
Expand Down
239 changes: 199 additions & 40 deletions v2/loader/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go.mercari.io/yo/v2/config"
"go.mercari.io/yo/v2/internal"
"go.mercari.io/yo/v2/models"
)
Expand Down Expand Up @@ -78,8 +79,6 @@ ALTER TABLE Foo ADD CONSTRAINT FK_CustomerOrder FOREIGN KEY (CustomerID) REFEREN
)

func TestLoader(t *testing.T) {
dir := t.TempDir()

table := []struct {
name string
opt Option
Expand Down Expand Up @@ -345,51 +344,158 @@ func TestLoader(t *testing.T) {

for _, tc := range table {
t.Run(tc.name, func(t *testing.T) {
f, err := ioutil.TempFile(dir, "")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
_, _ = f.Write([]byte(tc.schema))
_ = f.Close()
path := f.Name()
l := setUpTypeLoader(t, tc.schema, tc.opt)

source, err := NewSchemaParserSource(path)
schema, err := l.LoadSchema()
if err != nil {
t.Fatalf("failed to create schema parser source: %v", err)
t.Fatalf("failed to load schema: %v", err)
}

inflector, err := internal.NewInflector(nil)
if err != nil {
t.Fatalf("failed to create inflector: %v", err)
}
compareSchemas(t, schema, tc.expectedSchema)
})
}
}

func TestLoader_CustomTypes(t *testing.T) {
table := []struct {
name string
opt Option
schema string
expectedSchema *models.Schema
expectedErr string
}{
{
name: "Custom type table does not exist",
opt: Option{
Config: &config.Config{
Tables: []config.Table{
{
Name: "UnknownTable",
},
},
},
},
schema: simpleSchema,
expectedErr: "unknown custom type table UnknownTable",
},
{
name: "Custom type column does not exist",
opt: Option{
Config: &config.Config{
Tables: []config.Table{
{
Name: "Simple",
Columns: []config.Column{
{
Name: "UnknownColumn",
CustomType: "UnknownCustomColumn",
},
},
},
},
},
},
schema: simpleSchema,
expectedErr: "unknown custom type column UnknownColumn in the table Simple",
},
{
name: "Success",
opt: Option{
Config: &config.Config{
Tables: []config.Table{
{
Name: "Simple",
Columns: []config.Column{
{
Name: "Value",
CustomType: "Value",
},
},
},
},
},
},
schema: simpleSchema,
expectedSchema: &models.Schema{
Types: []*models.Type{
{
Name: "Simple",
PrimaryKeyFields: []*models.Field{
{ColumnName: "Id"},
},
Fields: []*models.Field{
{
Name: "ID",
Type: "int64",
OriginalType: "int64",
NullValue: "0",
Len: -1,
ColumnName: "Id",
SpannerDataType: "INT64",
IsNotNull: true,
IsPrimaryKey: true,
},
{
Name: "Value",
Type: "Value",
OriginalType: "string",
NullValue: `""`,
Len: 32,
ColumnName: "Value",
SpannerDataType: "STRING(32)",
IsNotNull: true,
IsPrimaryKey: false,
},
},
TableName: "Simple",
Indexes: []*models.Index{
{
Name: "SimpleIndex",
FuncName: "SimplesBySimpleIndex",
LegacyFuncName: "SimplesByValue",
Fields: []*models.Field{
{ColumnName: "Value"},
},
IndexName: "SimpleIndex",
},
{
Name: "SimpleIndex2",
FuncName: "SimpleBySimpleIndex2",
LegacyFuncName: "SimpleByIDValue",
Fields: []*models.Field{
{ColumnName: "Id"},
{ColumnName: "Value"},
},
IndexName: "SimpleIndex2",
IsUnique: true,
},
},
},
},
},
},
}

for _, tc := range table {
t.Run(tc.name, func(t *testing.T) {
l := setUpTypeLoader(t, tc.schema, tc.opt)

l := NewTypeLoader(source, inflector, tc.opt)
schema, err := l.LoadSchema()
if err != nil {
t.Fatalf("failed to load schema: %v", err)
}

if diff := cmp.Diff(schema, tc.expectedSchema,
cmp.Transformer("FilterInTypePrimaryKeyFields", func(in *models.Type) *models.Type {
if in == nil {
return in
}
for i := range in.PrimaryKeyFields {
f := in.PrimaryKeyFields[i]
in.PrimaryKeyFields[i] = &models.Field{ColumnName: f.ColumnName}
}
return in
}),
cmp.Transformer("FilterInIndexFields", func(in *models.Index) *models.Index {
for i := range in.Fields {
f := in.Fields[i]
in.Fields[i] = &models.Field{ColumnName: f.ColumnName}
}
return in
}),
cmpopts.IgnoreFields(models.Index{}, "Type"),
); diff != "" {
t.Errorf("(-got, +want)\n%s", diff)
if tc.expectedErr != "" {
if err == nil {
t.Fatal("expected to load schema failure")
}

if err.Error() != tc.expectedErr {
t.Fatalf("unexpected error: expected: %s, actual: %s", tc.expectedErr, err.Error())
}
} else {
if err != nil {
t.Fatalf("failed to load schema: %v", err)
}

compareSchemas(t, schema, tc.expectedSchema)
}
})
}
Expand Down Expand Up @@ -506,3 +612,56 @@ func Test_setIndexesToTables(t *testing.T) {
})
}
}

func setUpTypeLoader(t *testing.T, schema string, opt Option) *TypeLoader {
t.Helper()

dir := t.TempDir()

f, err := ioutil.TempFile(dir, "")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
_, _ = f.Write([]byte(schema))
_ = f.Close()
path := f.Name()

source, err := NewSchemaParserSource(path)
if err != nil {
t.Fatalf("failed to create schema parser source: %v", err)
}

inflector, err := internal.NewInflector(nil)
if err != nil {
t.Fatalf("failed to create inflector: %v", err)
}

return NewTypeLoader(source, inflector, opt)
}

func compareSchemas(t *testing.T, actual *models.Schema, expected *models.Schema) {
t.Helper()

if diff := cmp.Diff(actual, expected,
cmp.Transformer("FilterInTypePrimaryKeyFields", func(in *models.Type) *models.Type {
if in == nil {
return in
}
for i := range in.PrimaryKeyFields {
f := in.PrimaryKeyFields[i]
in.PrimaryKeyFields[i] = &models.Field{ColumnName: f.ColumnName}
}
return in
}),
cmp.Transformer("FilterInIndexFields", func(in *models.Index) *models.Index {
for i := range in.Fields {
f := in.Fields[i]
in.Fields[i] = &models.Field{ColumnName: f.ColumnName}
}
return in
}),
cmpopts.IgnoreFields(models.Index{}, "Type"),
); diff != "" {
t.Errorf("(-got, +want)\n%s", diff)
}
}

0 comments on commit f04de87

Please sign in to comment.