Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Construct a writer tree #40

Merged
merged 10 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions pyiceberg/avro/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from pyiceberg.avro.decoder import BinaryDecoder, new_decoder
from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.reader import Reader
from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve
from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve_reader, resolve_writer
from pyiceberg.avro.writer import Writer
from pyiceberg.io import InputFile, OutputFile, OutputStream
from pyiceberg.schema import Schema
Expand Down Expand Up @@ -172,7 +172,7 @@ def __enter__(self) -> AvroFile[D]:
if not self.read_schema:
self.read_schema = self.schema

self.reader = resolve(self.schema, self.read_schema, self.read_types, self.read_enums)
self.reader = resolve_reader(self.schema, self.read_schema, self.read_types, self.read_enums)

return self

Expand Down Expand Up @@ -222,18 +222,29 @@ def _read_header(self) -> AvroFileHeader:
class AvroOutputFile(Generic[D]):
output_file: OutputFile
output_stream: OutputStream
schema: Schema
file_schema: Schema
schema_name: str
encoder: BinaryEncoder
sync_bytes: bytes
writer: Writer

def __init__(self, output_file: OutputFile, schema: Schema, schema_name: str, metadata: Dict[str, str] = EMPTY_DICT) -> None:
def __init__(
self,
output_file: OutputFile,
file_schema: Schema,
schema_name: str,
record_schema: Optional[Schema] = None,
metadata: Dict[str, str] = EMPTY_DICT,
) -> None:
self.output_file = output_file
self.schema = schema
self.file_schema = file_schema
self.schema_name = schema_name
self.sync_bytes = os.urandom(SYNC_SIZE)
self.writer = construct_writer(self.schema)
self.writer = (
construct_writer(file_schema=self.file_schema)
if record_schema is None
else resolve_writer(record_schema=record_schema, file_schema=self.file_schema)
)
self.metadata = metadata

def __enter__(self) -> AvroOutputFile[D]:
Expand All @@ -247,7 +258,6 @@ def __enter__(self) -> AvroOutputFile[D]:
self.encoder = BinaryEncoder(self.output_stream)

self._write_header()
self.writer = construct_writer(self.schema)

return self

Expand All @@ -258,7 +268,7 @@ def __exit__(
self.output_stream.close()

def _write_header(self) -> None:
json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.schema, schema_name=self.schema_name))
json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.file_schema, schema_name=self.schema_name))
meta = {**self.metadata, _SCHEMA_KEY: json_schema, _CODEC_KEY: "null"}
header = AvroFileHeader(magic=MAGIC, meta=meta, sync=self.sync_bytes)
construct_writer(META_SCHEMA).write(self.encoder, header)
Expand Down
125 changes: 118 additions & 7 deletions pyiceberg/avro/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
BooleanWriter,
DateWriter,
DecimalWriter,
DefaultWriter,
DoubleWriter,
FixedWriter,
FloatWriter,
Expand Down Expand Up @@ -112,11 +113,12 @@ def construct_reader(

Args:
file_schema (Schema | IcebergType): The schema of the Avro file.
read_types (Dict[int, Callable[..., StructProtocol]]): Constructors for structs for certain field-ids

Raises:
NotImplementedError: If attempting to resolve an unrecognized object type.
"""
return resolve(file_schema, file_schema, read_types)
return resolve_reader(file_schema, file_schema, read_types)


def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer:
Expand All @@ -128,7 +130,7 @@ def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer:
Raises:
NotImplementedError: If attempting to resolve an unrecognized object type.
"""
return visit(file_schema, ConstructWriter())
return visit(file_schema, CONSTRUCT_WRITER_VISITOR)


class ConstructWriter(SchemaVisitorPerPrimitiveType[Writer]):
Expand All @@ -138,7 +140,7 @@ def schema(self, schema: Schema, struct_result: Writer) -> Writer:
return struct_result

def struct(self, struct: StructType, field_results: List[Writer]) -> Writer:
return StructWriter(tuple(field_results))
return StructWriter(tuple((pos, result) for pos, result in enumerate(field_results)))

def field(self, field: NestedField, field_result: Writer) -> Writer:
return field_result if field.required else OptionWriter(field_result)
Expand Down Expand Up @@ -192,7 +194,28 @@ def visit_binary(self, binary_type: BinaryType) -> Writer:
return BinaryWriter()


def resolve(
CONSTRUCT_WRITER_VISITOR = ConstructWriter()


def resolve_writer(
    record_schema: Union[Schema, IcebergType],
    file_schema: Union[Schema, IcebergType],
) -> Writer:
    """Resolve the record and file schema to produce a writer.

    Args:
        record_schema (Schema | IcebergType): The schema of the record in memory.
        file_schema (Schema | IcebergType): The schema of the file that will be written.

    Returns:
        Writer: A writer tree that serializes in-memory records laid out by
            ``record_schema`` into the layout required by ``file_schema``.

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    # Fast path: identical schemas need no field-by-field resolution.
    if record_schema == file_schema:
        return construct_writer(file_schema)
    return visit_with_partner(file_schema, record_schema, WriteSchemaResolver(), SchemaPartnerAccessor())  # type: ignore


def resolve_reader(
file_schema: Union[Schema, IcebergType],
read_schema: Union[Schema, IcebergType],
read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
Expand All @@ -210,7 +233,7 @@ def resolve(
NotImplementedError: If attempting to resolve an unrecognized object type.
"""
return visit_with_partner(
file_schema, read_schema, SchemaResolver(read_types, read_enums), SchemaPartnerAccessor()
file_schema, read_schema, ReadSchemaResolver(read_types, read_enums), SchemaPartnerAccessor()
) # type: ignore


Expand All @@ -233,7 +256,95 @@ def skip(self, decoder: BinaryDecoder) -> None:
pass


class SchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]):
class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]):
    """Builds a ``Writer`` tree that writes in-memory records in the file-schema layout.

    The file schema drives the traversal; the record schema is visited as the
    partner and determines where (if anywhere) each file field can be found on
    the in-memory record. File fields missing from the record are either filled
    from ``write_default`` or, when optional, written as null.
    """

    def schema(self, file_schema: Schema, record_schema: Optional[IcebergType], result: Writer) -> Writer:
        return result

    def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], file_writers: List[Writer]) -> Writer:
        if not isinstance(record_struct, StructType):
            raise ResolveError(f"File/write schema are not aligned for struct, got {record_struct}")

        # Field-id -> position of that field within the in-memory record.
        record_struct_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(record_struct.fields)}
        results: List[Tuple[Optional[int], Writer]] = []

        for writer, file_field in zip(file_writers, file_schema.fields):
            if file_field.field_id in record_struct_positions:
                # The field exists on the record: write the value at that position.
                results.append((record_struct_positions[file_field.field_id], writer))
            elif file_field.required:
                if file_field.write_default is not None:
                    # The field is not in the record, but there is a write default value.
                    results.append((None, DefaultWriter(writer=writer, value=file_field.write_default)))  # type: ignore
                else:
                    # Required, absent from the record, and no default to fall back on.
                    raise ValueError(f"Field is required, and there is no write default: {file_field}")
            else:
                # Optional field absent from the record: position None, writer emits null.
                results.append((None, writer))

        return StructWriter(field_writers=tuple(results))

    def field(self, file_field: NestedField, record_type: Optional[IcebergType], field_writer: Writer) -> Writer:
        return field_writer if file_field.required else OptionWriter(field_writer)

    def list(self, file_list_type: ListType, file_list: Optional[IcebergType], element_writer: Writer) -> Writer:
        return ListWriter(element_writer if file_list_type.element_required else OptionWriter(element_writer))

    def map(
        self, file_map_type: MapType, file_primitive: Optional[IcebergType], key_writer: Writer, value_writer: Writer
    ) -> Writer:
        return MapWriter(key_writer, value_writer if file_map_type.value_required else OptionWriter(value_writer))

    def primitive(self, file_primitive: PrimitiveType, record_primitive: Optional[IcebergType]) -> Writer:
        if record_primitive is not None:
            # Ensure that the record type can be promoted to the file type;
            # promote raises if the projection is not allowed.
            if file_primitive != record_primitive:
                promote(record_primitive, file_primitive)

        # NOTE(review): the partner passed down is the file primitive itself, so the
        # visit_* methods below always construct the writer for the FILE type.
        return super().primitive(file_primitive, file_primitive)

    def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer:
        return BooleanWriter()

    def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Writer:
        return IntegerWriter()

    def visit_long(self, long_type: LongType, partner: Optional[IcebergType]) -> Writer:
        # Ints and longs share one writer (see IntegerWriter's docstring).
        return IntegerWriter()

    def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Writer:
        return FloatWriter()

    def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Writer:
        return DoubleWriter()

    def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Writer:
        return DecimalWriter(decimal_type.precision, decimal_type.scale)

    def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Writer:
        return DateWriter()

    def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Writer:
        return TimeWriter()

    def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Writer:
        return TimestampWriter()

    def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Writer:
        return TimestamptzWriter()

    def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Writer:
        return StringWriter()

    def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Writer:
        return UUIDWriter()

    def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Writer:
        return FixedWriter(len(fixed_type))

    def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Writer:
        return BinaryWriter()


class ReadSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]):
__slots__ = ("read_types", "read_enums", "context")
read_types: Dict[int, Callable[..., StructProtocol]]
read_enums: Dict[int, Callable[..., Enum]]
Expand Down Expand Up @@ -279,7 +390,7 @@ def struct(self, struct: StructType, expected_struct: Optional[IcebergType], fie
for field, result_reader in zip(struct.fields, field_readers)
]

file_fields = {field.field_id: field for field in struct.fields}
file_fields = {field.field_id for field in struct.fields}
for pos, read_field in enumerate(expected_struct.fields):
if read_field.field_id not in file_fields:
if isinstance(read_field, NestedField) and read_field.initial_default is not None:
Expand Down
36 changes: 27 additions & 9 deletions pyiceberg/avro/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Any,
Dict,
List,
Optional,
Tuple,
)
from uuid import UUID
Expand All @@ -39,6 +40,7 @@
from pyiceberg.utils.singleton import Singleton


@dataclass(frozen=True)
class Writer(Singleton):
@abstractmethod
def write(self, encoder: BinaryEncoder, val: Any) -> Any:
Expand All @@ -49,58 +51,63 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


class NoneWriter(Writer):
    """Writer for null values: accepts any value and emits nothing."""

    def write(self, _: BinaryEncoder, __: Any) -> None:
        pass


@dataclass(frozen=True)
class BooleanWriter(Writer):
    """Writes a bool using the encoder's boolean encoding."""

    def write(self, encoder: BinaryEncoder, val: bool) -> None:
        encoder.write_boolean(val)


@dataclass(frozen=True)
class IntegerWriter(Writer):
    """Longs and ints are encoded the same way, and there is no long in Python."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        # Both Avro ints and longs funnel through write_int (see class docstring).
        encoder.write_int(val)


@dataclass(frozen=True)
class FloatWriter(Writer):
    """Writes a float using the encoder's single-precision float encoding."""

    def write(self, encoder: BinaryEncoder, val: float) -> None:
        encoder.write_float(val)


@dataclass(frozen=True)
class DoubleWriter(Writer):
    """Writes a float using the encoder's double-precision encoding."""

    def write(self, encoder: BinaryEncoder, val: float) -> None:
        encoder.write_double(val)


@dataclass(frozen=True)
class DateWriter(Writer):
    """Writes a date as an int (presumably days since epoch — confirm against the Iceberg spec)."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimeWriter(Writer):
    """Writes a time-of-day as an int (presumably microseconds — confirm against the Iceberg spec)."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimestampWriter(Writer):
    """Writes a timestamp as an int (presumably microseconds since epoch — confirm against the Iceberg spec)."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimestamptzWriter(Writer):
    """Writes a tz-aware timestamp as an int; same on-disk encoding as TimestampWriter."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class StringWriter(Writer):
    """Writes a string via the encoder's UTF-8 encoding."""

    def write(self, encoder: BinaryEncoder, val: Any) -> None:
        encoder.write_utf8(val)


@dataclass(frozen=True)
class UUIDWriter(Writer):
    """Writes a UUID as its raw 16 bytes (big-endian, per uuid.UUID.bytes)."""

    def write(self, encoder: BinaryEncoder, val: UUID) -> None:
        encoder.write(val.bytes)
Expand All @@ -124,6 +131,7 @@ def __repr__(self) -> str:
return f"FixedWriter({self._len})"


@dataclass(frozen=True)
class BinaryWriter(Writer):
"""Variable byte length writer."""

Expand Down Expand Up @@ -158,19 +166,20 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None:

@dataclass(frozen=True)
class StructWriter(Writer):
field_writers: Tuple[Writer, ...] = dataclassfield()
field_writers: Tuple[Tuple[Optional[int], Writer], ...] = dataclassfield()

def write(self, encoder: BinaryEncoder, val: Record) -> None:
for writer, value in zip(self.field_writers, val.record_fields()):
writer.write(encoder, value)
for pos, writer in self.field_writers:
# When pos is None, then it is a default value
writer.write(encoder, val[pos] if pos is not None else None)

def __eq__(self, other: Any) -> bool:
"""Implement the equality operator for this object."""
return self.field_writers == other.field_writers if isinstance(other, StructWriter) else False

def __repr__(self) -> str:
"""Return string representation of this object."""
return f"StructWriter({','.join(repr(field) for field in self.field_writers)})"
return f"StructWriter(tuple(({','.join(repr(field) for field in self.field_writers)})))"

def __hash__(self) -> int:
"""Return the hash of the writer as hash of this object."""
Expand Down Expand Up @@ -201,3 +210,12 @@ def write(self, encoder: BinaryEncoder, val: Dict[Any, Any]) -> None:
self.value_writer.write(encoder, v)
if len(val) > 0:
encoder.write_int(0)


@dataclass(frozen=True)
class DefaultWriter(Writer):
    """Ignores the incoming value and always writes a fixed default.

    Used for file fields that are absent from the in-memory record but carry a
    write-default (see WriteSchemaResolver.struct).
    """

    writer: Writer  # delegate that knows how to encode ``value``
    value: Any  # the constant written for every record

    def write(self, encoder: BinaryEncoder, _: Any) -> None:
        self.writer.write(encoder, self.value)
Loading