diff --git a/schema_salad/avro/schema.py b/schema_salad/avro/schema.py index 48e0ef60..c5d33bdf 100644 --- a/schema_salad/avro/schema.py +++ b/schema_salad/avro/schema.py @@ -620,7 +620,7 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) -> raise SchemaParseException(fail_msg) -def is_subtype(existing: PropType, new: PropType) -> bool: +def is_subtype(types: Dict[str, Any], existing: PropType, new: PropType) -> bool: """Check if a new type specification is compatible with an existing type spec.""" if existing == new: return True @@ -632,46 +632,35 @@ def is_subtype(existing: PropType, new: PropType) -> bool: if isinstance(new, list) and "null" in new: return False return True - if ( - isinstance(existing, dict) - and "type" in existing - and existing["type"] == "array" - and isinstance(new, dict) - and "type" in new - and new["type"] == "array" - ): - return is_subtype(existing["items"], new["items"]) - if ( - isinstance(existing, dict) - and "type" in existing - and existing["type"] == "enum" - and isinstance(new, dict) - and "type" in new - and new["type"] == "enum" - ): - return is_subtype(existing["symbols"], new["symbols"]) - if ( - isinstance(existing, dict) - and "type" in existing - and existing["type"] == "record" - and isinstance(new, dict) - and "type" in new - and new["type"] == "record" - ): - for new_field in cast(List[Dict[str, Any]], new["fields"]): - new_field_missing = True - for existing_field in cast(List[Dict[str, Any]], existing["fields"]): - if new_field["name"] == existing_field["name"]: - if not is_subtype(existing_field["type"], new_field["type"]): - return False - new_field_missing = False - if new_field_missing: - return False - return True + if isinstance(existing, str) and existing in types: + return is_subtype(types, types[existing], new) + if isinstance(new, str) and new in types: + return is_subtype(types, existing, types[new]) + if isinstance(existing, dict) and isinstance(new, dict): + if "extends" in new and new["extends"] == existing.get("name"): + return True + if existing.get("type") == "array" and new.get("type") == "array": + return is_subtype(types, existing["items"], new["items"]) + if existing.get("type") == "enum" and new.get("type") == "enum": + return is_subtype(types, existing["symbols"], new["symbols"]) + if existing.get("type") == "record" and new.get("type") == "record": + for new_field in cast(List[Dict[str, Any]], new["fields"]): + new_field_missing = True + for existing_field in cast(List[Dict[str, Any]], existing["fields"]): + if new_field["name"] == existing_field["name"]: + if not is_subtype(types, existing_field["type"], new_field["type"]): + return False + new_field_missing = False + if new_field_missing: + return False + return True if isinstance(existing, list) and isinstance(new, list): missing = False - for _type in new: - if _type not in existing and (not is_subtype(existing, cast(PropType, _type))): + for _type_new in new: + if _type_new not in existing and not any( + is_subtype(types, cast(PropType, _type_existing), cast(PropType, _type_new)) + for _type_existing in existing + ): missing = True return not missing return False diff --git a/schema_salad/schema.py b/schema_salad/schema.py index 8b990420..5f1c7921 100644 --- a/schema_salad/schema.py +++ b/schema_salad/schema.py @@ -594,6 +594,7 @@ def extend_and_specialize(items: List[Dict[str, Any]], loader: Loader) -> List[D """Apply 'extend' and 'specialize' to fully materialize derived record types.""" items2 = deepcopy_strip(items) types = {i["name"]: i for i in items2} # type: Dict[str, Any] + types.update({k[len(saladp) :]: v for k, v in types.items() if k.startswith(saladp)}) results = [] for stype in items2: @@ -654,7 +655,7 @@ def extend_and_specialize(items: List[Dict[str, Any]], loader: Loader) -> List[D field = exfield else: # make sure field name has not been used yet - if not is_subtype(exfield["type"], field["type"]): + if not is_subtype(types, exfield["type"], field["type"]): raise SchemaParseException( f"Field name {field['name']} already in use with " "incompatible type. " diff --git a/schema_salad/tests/test_schema/avro_subtype_nested.yml b/schema_salad/tests/test_schema/avro_subtype_nested.yml new file mode 100644 index 00000000..d335198a --- /dev/null +++ b/schema_salad/tests/test_schema/avro_subtype_nested.yml @@ -0,0 +1,35 @@ +$base: "https://example.com/nested_schema#" + +$namespaces: + bs: "https://example.com/base_schema#" + dv: "https://example.com/derived_schema#" + +$graph: + +- $import: avro_subtype.yml + +- type: record + name: AbstractContainer + abstract: true + doc: | + This is an abstract container thing that includes an AbstractThing field + fields: + override_me: + type: bs:AbstractThing + jsonldPredicate: "bs:override_me" + + +- type: record + name: ExtendedContainer + extends: AbstractContainer + doc: | + An extended version of the abstract container that implements an extra field + and uses an ExtendedThing to override the original field + fields: + extra_field: + type: + type: array + items: [string] + override_me: + type: dv:ExtendedThing + jsonldPredicate: "bs:override_me" diff --git a/schema_salad/tests/test_schema/avro_subtype_nested_bad.yml b/schema_salad/tests/test_schema/avro_subtype_nested_bad.yml new file mode 100644 index 00000000..cac2feda --- /dev/null +++ b/schema_salad/tests/test_schema/avro_subtype_nested_bad.yml @@ -0,0 +1,35 @@ +$base: "https://example.com/nested_schema#" + +$namespaces: + bs: "https://example.com/base_schema#" + dv: "https://example.com/derived_schema#" + +$graph: + +- $import: avro_subtype_bad.yml + +- type: record + name: AbstractContainer + abstract: true + doc: | + This is an abstract container thing that includes an AbstractThing field + fields: + override_me: + type: bs:AbstractThing + jsonldPredicate: "bs:override_me" + + +- type: record + name: ExtendedContainer + extends: AbstractContainer + doc: | + An extended version of the abstract container that implements an extra field + and uses an ExtendedThing to override the original field + fields: + extra_field: + type: + type: array + items: [string] + override_me: + type: dv:ExtendedThing + jsonldPredicate: "bs:override_me" diff --git a/schema_salad/tests/test_schema/avro_subtype_recursive.yml b/schema_salad/tests/test_schema/avro_subtype_recursive.yml new file mode 100644 index 00000000..709a4c58 --- /dev/null +++ b/schema_salad/tests/test_schema/avro_subtype_recursive.yml @@ -0,0 +1,32 @@ +$base: "https://example.com/recursive_schema#" + +$namespaces: + bs: "https://example.com/base_schema#" + +$graph: + +- $import: "metaschema_base.yml" + +- type: record + name: RecursiveThing + doc: | + This is an arbitrary recursive thing that includes itself in its fields + fields: + override_me: + type: RecursiveThing + jsonldPredicate: "bs:override_me" + + +- type: record + name: ExtendedThing + extends: RecursiveThing + doc: | + An extended version of the recursive thing that implements an extra field + fields: + field_one: + type: + type: array + items: [string] + override_me: + type: ExtendedThing + jsonldPredicate: "bs:override_me" diff --git a/schema_salad/tests/test_schema/avro_subtype_union.yml b/schema_salad/tests/test_schema/avro_subtype_union.yml new file mode 100644 index 00000000..c2534c3b --- /dev/null +++ b/schema_salad/tests/test_schema/avro_subtype_union.yml @@ -0,0 +1,36 @@ +$base: "https://example.com/union_schema#" + +$namespaces: + bs: "https://example.com/base_schema#" + dv: "https://example.com/derived_schema#" + +$graph: + +- $import: avro_subtype.yml + +- type: record + name: AbstractContainer + abstract: true + doc: | + This is an abstract container thing that includes an AbstractThing + type in its field types + fields: + override_me: + type: [int, string, bs:AbstractThing] + jsonldPredicate: "bs:override_me" + + +- type: record + name: ExtendedContainer + extends: AbstractContainer + doc: | + An extended version of the abstract container that implements an extra field + and contains an ExtendedThing type in its overridden field types + fields: + extra_field: + type: + type: array + items: [string] + override_me: + type: [int, dv:ExtendedThing] + jsonldPredicate: "bs:override_me" diff --git a/schema_salad/tests/test_schema/avro_subtype_union_bad.yml b/schema_salad/tests/test_schema/avro_subtype_union_bad.yml new file mode 100644 index 00000000..8bfedd08 --- /dev/null +++ b/schema_salad/tests/test_schema/avro_subtype_union_bad.yml @@ -0,0 +1,36 @@ +$base: "https://example.com/union_schema#" + +$namespaces: + bs: "https://example.com/base_schema#" + dv: "https://example.com/derived_schema#" + +$graph: + +- $import: avro_subtype_bad.yml + +- type: record + name: AbstractContainer + abstract: true + doc: | + This is an abstract container thing that includes an AbstractThing + type in its field types + fields: + override_me: + type: [int, string, bs:AbstractThing] + jsonldPredicate: "bs:override_me" + + +- type: record + name: ExtendedContainer + extends: AbstractContainer + doc: | + An extended version of the abstract container that implements an extra field + and contains an ExtendedThing type in its overridden field types + fields: + extra_field: + type: + type: array + items: [string] + override_me: + type: [int, dv:ExtendedThing] + jsonldPredicate: "bs:override_me" diff --git a/schema_salad/tests/test_subtypes.py b/schema_salad/tests/test_subtypes.py index 24803236..17b3ca17 100644 --- a/schema_salad/tests/test_subtypes.py +++ b/schema_salad/tests/test_subtypes.py @@ -1,10 +1,10 @@ """Confirm subtypes.""" + import pytest from schema_salad.avro import schema from schema_salad.avro.schema import Names, SchemaParseException from schema_salad.schema import load_schema - from .util import get_data types = [ @@ -84,7 +84,7 @@ @pytest.mark.parametrize("old,new,result", types) def test_subtypes(old: schema.PropType, new: schema.PropType, result: bool) -> None: """Test is_subtype() function.""" - assert schema.is_subtype(old, new) == result + assert schema.is_subtype({}, old, new) == result def test_avro_loading_subtype() -> None: @@ -105,4 +105,55 @@ def test_avro_loading_subtype_bad() -> None: r"Any vs \['string', 'int'\]\." ) with pytest.raises(SchemaParseException, match=target_error): - document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path) + _ = load_schema(path) + + +def test_subtypes_nested() -> None: + """Confirm correct subtype handling on a nested type definition.""" + path = get_data("tests/test_schema/avro_subtype_nested.yml") + assert path + document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path) + assert isinstance(avsc_names, Names) + assert avsc_names.get_name("com.example.nested_schema.ExtendedContainer", None) + + +def test_subtypes_nested_bad() -> None: + """Confirm subtype error when overriding incorrectly in nested types.""" + path = get_data("tests/test_schema/avro_subtype_nested_bad.yml") + assert path + target_error = ( + r"Field name .*\/override_me already in use with incompatible type. " + r"Any vs \['string', 'int'\]\." + ) + with pytest.raises(SchemaParseException, match=target_error): + _ = load_schema(path) + + +def test_subtypes_recursive() -> None: + """Confirm correct subtype handling on a recursive type definition.""" + path = get_data("tests/test_schema/avro_subtype_recursive.yml") + assert path + document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path) + assert isinstance(avsc_names, Names) + assert avsc_names.get_name("com.example.recursive_schema.RecursiveThing", None) + + +def test_subtypes_union() -> None: + """Confirm correct subtype handling on an union type definition.""" + path = get_data("tests/test_schema/avro_subtype_union.yml") + assert path + document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path) + assert isinstance(avsc_names, Names) + assert avsc_names.get_name("com.example.union_schema.ExtendedContainer", None) + + +def test_subtypes_union_bad() -> None: + """Confirm subtype error when overriding incorrectly in array types.""" + path = get_data("tests/test_schema/avro_subtype_union_bad.yml") + assert path + target_error = ( + r"Field name .*\/override_me already in use with incompatible type. " + r"Any vs \['string', 'int'\]\." + ) + with pytest.raises(SchemaParseException, match=target_error): + _ = load_schema(path)