Skip to content

Commit

Permalink
Support file-like inputs for RecordReader (#59)
Browse files Browse the repository at this point in the history
Peek into the file to find the right adapter by checking the file magic

---------

Co-authored-by: Max Groot <max.groot@fox-it.com>
Co-authored-by: Yun Zheng Hu <hu@fox-it.com>
Co-authored-by: Erik Schamper <1254028+Schamper@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 13, 2023
1 parent 6358ba3 commit 2e2eb62
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 114 deletions.
6 changes: 6 additions & 0 deletions flow/record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from flow.record.base import (
RECORD_VERSION,
RECORDSTREAM_MAGIC,
DynamicDescriptor,
FieldType,
GroupedRecord,
Expand All @@ -16,7 +17,9 @@
dynamic_fieldtype,
extend_record,
iter_timestamped_records,
open_file,
open_path,
open_stream,
stream,
)
from flow.record.jsonpacker import JsonRecordPacker
Expand All @@ -33,6 +36,7 @@

__all__ = [
"RECORD_VERSION",
"RECORDSTREAM_MAGIC",
"FieldType",
"Record",
"GroupedRecord",
Expand All @@ -47,7 +51,9 @@
"JsonRecordPacker",
"RecordStreamWriter",
"RecordStreamReader",
"open_file",
"open_path",
"open_stream",
"stream",
"dynamic_fieldtype",
"DynamicDescriptor",
Expand Down
21 changes: 12 additions & 9 deletions flow/record/adapter/avro.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import json
from datetime import datetime, timedelta, timezone
from importlib.util import find_spec
from typing import Any, Iterator

import fastavro

Expand Down Expand Up @@ -50,15 +53,15 @@ class AvroWriter(AbstractWriter):
writer = None

def __init__(self, path, key=None, **kwargs):
self.fp = record.open_path(path, "wb")
self.fp = record.open_file(path, "wb")

self.desc = None
self.schema = None
self.parsed_schema = None
self.writer = None
self.codec = "snappy" if find_spec("snappy") else "deflate"

def write(self, r):
def write(self, r: record.Record) -> None:
if not self.desc:
self.desc = r._desc
self.schema = descriptor_to_schema(self.desc)
Expand All @@ -79,7 +82,7 @@ def flush(self):
)
self.writer.flush()

def close(self):
def close(self) -> None:
if self.fp and not is_stdout(self.fp):
self.fp.close()
self.fp = None
Expand All @@ -90,7 +93,7 @@ class AvroReader(AbstractReader):
fp = None

def __init__(self, path, selector=None, **kwargs):
self.fp = record.open_path(path, "rb")
self.fp = record.open_file(path, "rb")
self.selector = make_selector(selector)

self.reader = fastavro.reader(self.fp)
Expand All @@ -105,7 +108,7 @@ def __init__(self, path, selector=None, **kwargs):
name for name, field in self.desc.get_all_fields().items() if field.typename == "datetime"
)

def __iter__(self):
def __iter__(self) -> Iterator[record.Record]:
for obj in self.reader:
# Convert timestamp-micros fields back to datetime fields
for field_name in self.datetime_fields:
Expand All @@ -117,13 +120,13 @@ def __iter__(self):
if not self.selector or self.selector.match(rec):
yield rec

def close(self):
def close(self) -> None:
if self.fp:
self.fp.close()
self.fp = None


def descriptor_to_schema(desc):
def descriptor_to_schema(desc: record.RecordDescriptor) -> dict[str, Any]:
namespace, _, name = desc.name.rpartition("/")
schema = {
"type": "record",
Expand Down Expand Up @@ -156,7 +159,7 @@ def descriptor_to_schema(desc):
return schema


def schema_to_descriptor(schema):
def schema_to_descriptor(schema: dict) -> record.RecordDescriptor:
doc = schema.get("doc")

# Sketchy record descriptor detection
Expand All @@ -178,7 +181,7 @@ def schema_to_descriptor(schema):
return record.RecordDescriptor(name, fields)


def avro_type_to_flow_type(ftype):
def avro_type_to_flow_type(ftype: list) -> str:
ftypes = [ftype] if not isinstance(ftype, list) else ftype

# If a field can be null, it has an additional type of "null"
Expand Down
29 changes: 16 additions & 13 deletions flow/record/adapter/stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from flow import record
from typing import Iterator, Union

from flow.record import Record, RecordOutput, RecordStreamReader, open_file, open_path
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.selector import Selector
from flow.record.utils import is_stdout

__usage__ = """
Expand All @@ -15,20 +18,20 @@ class StreamWriter(AbstractWriter):
fp = None
stream = None

def __init__(self, path, clobber=True, **kwargs):
self.fp = record.open_path(path, "wb", clobber=clobber)
self.stream = record.RecordOutput(self.fp)
def __init__(self, path: str, clobber=True, **kwargs):
self.fp = open_path(path, "wb", clobber=clobber)
self.stream = RecordOutput(self.fp)

def write(self, r):
self.stream.write(r)
def write(self, record: Record) -> None:
self.stream.write(record)

def flush(self):
def flush(self) -> None:
if self.stream and hasattr(self.stream, "flush"):
self.stream.flush()
if self.fp:
self.fp.flush()

def close(self):
def close(self) -> None:
if self.stream:
self.stream.close()
self.stream = None
Expand All @@ -42,14 +45,14 @@ class StreamReader(AbstractReader):
fp = None
stream = None

def __init__(self, path, selector=None, **kwargs):
self.fp = record.open_path(path, "rb")
self.stream = record.RecordStreamReader(self.fp, selector=selector)
def __init__(self, path: str, selector: Union[str, Selector] = None, **kwargs):
self.fp = open_file(path, "rb")
self.stream = RecordStreamReader(self.fp, selector=selector)

def __iter__(self):
def __iter__(self) -> Iterator[Record]:
return iter(self.stream)

def close(self):
def close(self) -> None:
if self.stream:
self.stream.close()
self.stream = None
Expand Down
Loading

0 comments on commit 2e2eb62

Please sign in to comment.