Skip to content

Commit

Permalink
fix: field id in name mapping should be optional
Browse files Browse the repository at this point in the history
  • Loading branch information
barronw committed Dec 16, 2024
1 parent ede363b commit afc36f8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
18 changes: 12 additions & 6 deletions pyiceberg/table/name_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


class MappedField(IcebergBaseModel):
field_id: int = Field(alias="field-id")
field_id: Optional[int] = Field(alias="field-id", default=None)
names: List[str] = conlist(str)
fields: List[MappedField] = Field(default_factory=list)

Expand All @@ -49,9 +49,10 @@ def convert_null_to_empty_List(cls, v: Any) -> Any:
@model_serializer
def ser_model(self) -> Dict[str, Any]:
"""Set custom serializer to leave out the field when it is empty."""
field_id = {"field-id": self.field_id} if self.field_id is not None else {}
fields = {"fields": self.fields} if len(self.fields) > 0 else {}
return {
"field-id": self.field_id,
**field_id,
"names": self.names,
**fields,
}
Expand All @@ -65,7 +66,8 @@ def __str__(self) -> str:
# Otherwise the UTs fail because the order of the set can change
fields_str = ", ".join([str(e) for e in self.fields]) or ""
fields_str = " " + fields_str if fields_str else ""
return "([" + ", ".join(self.names) + "] -> " + (str(self.field_id) or "?") + fields_str + ")"
field_id = "?" if self.field_id is None else (str(self.field_id) or "?")
return "([" + ", ".join(self.names) + "] -> " + field_id + fields_str + ")"


class NameMapping(IcebergRootModel[List[MappedField]]):
Expand Down Expand Up @@ -232,7 +234,9 @@ def mapping(self, nm: NameMapping, field_results: List[MappedField]) -> List[Map

def fields(self, struct: List[MappedField], field_results: List[MappedField]) -> List[MappedField]:
reassignments: Dict[str, int] = {
update.name: update.field_id for f in field_results if (update := self._updates.get(f.field_id))
update.name: update.field_id
for f in field_results
if f.field_id is not None and (update := self._updates.get(f.field_id))
}
return [
updated_field
Expand All @@ -241,6 +245,8 @@ def fields(self, struct: List[MappedField], field_results: List[MappedField]) ->
]

def field(self, field: MappedField, field_result: List[MappedField]) -> MappedField:
if field.field_id is None:
return field
field_names = field.names
if (update := self._updates.get(field.field_id)) is not None and update.name not in field_names:
field_names.append(update.name)
Expand Down Expand Up @@ -333,8 +339,8 @@ def struct(self, struct: StructType, struct_partner: Optional[MappedField], fiel
return StructType(*field_results)

def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
if field_partner is None:
raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}")
if field_partner is None or field_partner.field_id is None:
raise ValueError(f"Field or field ID missing from NameMapping: {'.'.join(self.current_path)}")

return NestedField(
field_id=field_partner.field_id,
Expand Down
28 changes: 28 additions & 0 deletions tests/table/test_name_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ def test_json_mapped_field_no_names_deserialization() -> None:
assert MappedField(field_id=1, names=[]) == MappedField.model_validate_json(mapped_field_with_null_fields)


def test_json_mapped_field_no_field_id_deserialization() -> None:
    """Check that JSON without a "field-id" key deserializes into a MappedField.

    Two payloads are parsed: one carrying only "names", and one that also
    carries an explicit ``null`` for "fields". Each parsed result is compared
    against a MappedField constructed directly in Python.
    """
    # Payload with nothing but an empty "names" list.
    names_only_json = '{\n"names": []\n}\n'
    expected_names_only = MappedField(field_id=None, names=[])
    parsed_names_only = MappedField.model_validate_json(names_only_json)
    assert expected_names_only == parsed_names_only

    # Payload that additionally sets "fields" to JSON null.
    null_fields_json = '{\n"names": [],\n"fields": null\n}\n'
    expected_null_fields = MappedField(names=[])
    parsed_null_fields = MappedField.model_validate_json(null_fields_json)
    assert expected_null_fields == parsed_null_fields


def test_json_name_mapping_deserialization() -> None:
name_mapping = """
[
Expand Down Expand Up @@ -164,6 +179,19 @@ def test_json_name_mapping_deserialization() -> None:
])


def test_json_mapped_field_no_field_id_serialization() -> None:
    """Check the JSON dump of a NameMapping containing fields without an id.

    The expected output has no "field-id" key for the entries built with
    ``field_id=None``, both at the top level and nested under "fields".
    """
    top_level_with_id = MappedField(field_id=1, names=["foo"])
    top_level_without_id = MappedField(field_id=None, names=["bar"])
    nested_without_id = MappedField(field_id=None, names=["element"])
    parent_with_id = MappedField(field_id=2, names=["qux"], fields=[nested_without_id])

    mapping = NameMapping([top_level_with_id, top_level_without_id, parent_with_id])

    expected_json = (
        '[{"field-id":1,"names":["foo"]},'
        '{"names":["bar"]},'
        '{"field-id":2,"names":["qux"],"fields":[{"names":["element"]}]}]'
    )
    assert mapping.model_dump_json() == expected_json


def test_json_serialization(table_name_mapping_nested: NameMapping) -> None:
assert (
table_name_mapping_nested.model_dump_json()
Expand Down

0 comments on commit afc36f8

Please sign in to comment.