Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make records hashable #107

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flow/record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

from flow.record.base import (
IGNORE_FIELDS_FOR_COMPARISON,
RECORD_VERSION,
RECORDSTREAM_MAGIC,
DynamicDescriptor,
Expand All @@ -20,6 +21,7 @@
open_path,
open_path_or_stream,
open_stream,
set_ignored_fields_for_comparison,
stream,
)
from flow.record.jsonpacker import JsonRecordPacker
Expand All @@ -35,6 +37,7 @@
)

__all__ = [
"IGNORE_FIELDS_FOR_COMPARISON",
"RECORD_VERSION",
"RECORDSTREAM_MAGIC",
"FieldType",
Expand All @@ -54,6 +57,7 @@
"open_path_or_stream",
"open_path",
"open_stream",
"set_ignored_fields_for_comparison",
"stream",
"dynamic_fieldtype",
"DynamicDescriptor",
Expand Down
37 changes: 34 additions & 3 deletions flow/record/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
from datetime import datetime, timezone
from itertools import zip_longest
from pathlib import Path
from typing import IO, Any, BinaryIO, Iterator, Mapping, Optional, Sequence, Union
from typing import (
IO,
Any,
BinaryIO,
Iterable,
Iterator,
Mapping,
Optional,
Sequence,
Union,
)
from urllib.parse import parse_qsl, urlparse

from flow.record.adapter import AbstractReader, AbstractWriter
Expand Down Expand Up @@ -96,6 +106,18 @@ def _unpack(__cls, {args}):
"""


def _parse_ignored_fields(raw: str) -> set:
    """Split a comma-separated list of field names into a set.

    Surrounding whitespace is stripped and empty entries are dropped, so
    ``"_generated, lastname,"`` yields ``{"_generated", "lastname"}``.
    """
    return {name.strip() for name in raw.split(",") if name.strip()}


# Field names that are left out when records are compared or hashed.
# Seeded once, at import time, from the comma-separated FLOW_RECORD_IGNORE
# environment variable; an unset or blank variable results in an empty set.
IGNORE_FIELDS_FOR_COMPARISON = _parse_ignored_fields(os.environ.get("FLOW_RECORD_IGNORE", ""))


def set_ignored_fields_for_comparison(ignored_fields: Iterable[str]) -> None:
    """Replace the contents of IGNORE_FIELDS_FOR_COMPARISON from outside the flow.record package scope.

    The existing set object is updated in place instead of being rebound, so
    references obtained earlier through ``from ... import
    IGNORE_FIELDS_FOR_COMPARISON`` keep reflecting the current configuration.

    Args:
        ignored_fields: The field names to ignore from now on. Any previously
            configured names are discarded.
    """
    # Materialise the new names first: the caller may hand us the global set
    # itself (or an iterator over it), which must not be emptied while reading.
    replacement = set(ignored_fields)
    IGNORE_FIELDS_FOR_COMPARISON.clear()
    IGNORE_FIELDS_FOR_COMPARISON.update(replacement)


class FieldType:
def _typename(self):
t = type(self)
Expand All @@ -117,14 +139,20 @@ class Record:
def __eq__(self, other):
if not isinstance(other, Record):
return False
return self._pack() == other._pack()

def _pack(self, unversioned=False):
return self._pack(excluded_fields=IGNORE_FIELDS_FOR_COMPARISON) == other._pack(
excluded_fields=IGNORE_FIELDS_FOR_COMPARISON
)

def _pack(self, unversioned=False, excluded_fields: list = None):
values = []
for k in self.__slots__:
v = getattr(self, k)
v = v._pack() if isinstance(v, FieldType) else v

if excluded_fields and k in excluded_fields:
continue

# Skip version field if requested (only for compatibility reasons)
if unversioned and k == "_version" and v == 1:
continue
Expand Down Expand Up @@ -160,6 +188,9 @@ def _replace(self, **kwds):
raise ValueError("Got unexpected field names: {kwds!r}".format(kwds=list(kwds)))
return result

def __hash__(self) -> int:
    """Return a hash of the record's packed form.

    The record is packed with the fields named in
    IGNORE_FIELDS_FOR_COMPARISON left out, which is the same packed form that
    ``__eq__`` compares.
    """
    comparable = self._pack(excluded_fields=IGNORE_FIELDS_FOR_COMPARISON)
    return hash(comparable)

def __repr__(self):
return "<{} {}>".format(
self._desc.name, " ".join("{}={!r}".format(k, getattr(self, k)) for k in self._desc.fields)
Expand Down
66 changes: 65 additions & 1 deletion tests/test_record.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import importlib
import os
import sys
from unittest.mock import patch

import pytest

Expand All @@ -15,7 +18,11 @@
fieldtypes,
record_stream,
)
from flow.record.base import merge_record_descriptors, normalize_fieldname
from flow.record.base import (
merge_record_descriptors,
normalize_fieldname,
set_ignored_fields_for_comparison,
)
from flow.record.exceptions import RecordDescriptorError
from flow.record.stream import RecordFieldRewriter

Expand Down Expand Up @@ -792,3 +799,60 @@ def test_normalize_fieldname():
assert normalize_fieldname("my name (with) parentheses") == "my_name__with__parentheses"
assert normalize_fieldname("_generated") == "_generated"
assert normalize_fieldname("_source") == "_source"


def test_compare_global_variable():
    """Records compare and hash equal once their differing fields are ignored via the setter."""
    from flow.record import base as record_base

    TestRecord = RecordDescriptor(
        "test/record",
        [
            ("string", "firstname"),
            ("string", "lastname"),
        ],
    )

    same_same = TestRecord(firstname="James", lastname="Bond")
    but_different = TestRecord(firstname="Ethan", lastname="Hunt")
    but_still_same = TestRecord(firstname="Andrew", lastname="Bond")

    records = [same_same, but_different, but_still_same]

    assert same_same != but_still_same

    # The ignore set is process-wide state: snapshot it and restore it afterwards
    # so this test cannot change the outcome of tests that run after it.
    previous = set(record_base.IGNORE_FIELDS_FOR_COMPARISON)
    try:
        set_ignored_fields_for_comparison({"_generated", "firstname"})
        assert same_same == but_still_same
        assert same_same != but_different
        assert len(set(records)) == 2
    finally:
        set_ignored_fields_for_comparison(previous)


def test_compare_environment_variable():
    """FLOW_RECORD_IGNORE is read at import time and drives record comparison and hashing."""
    # patch.dict restores both os.environ and sys.modules when the block exits,
    # so neither the variable nor the re-imported modules leak into other tests.
    with patch.dict(os.environ), patch.dict(sys.modules):
        os.environ["FLOW_RECORD_IGNORE"] = "_generated,lastname"

        # Force a re-import of flow.record so the global variable gets re-initialized based on the environment variable.
        # Match on the package prefix only: a substring test would also evict
        # unrelated modules whose name merely contains "flow.".
        keys = [key for key in sys.modules if key == "flow" or key.startswith("flow.")]
        for key in keys:
            del sys.modules[key]

        importlib.import_module("flow.record")

        # Import inside the block so the names come from the freshly imported modules.
        from flow.record import IGNORE_FIELDS_FOR_COMPARISON, RecordDescriptor

        assert IGNORE_FIELDS_FOR_COMPARISON == {"_generated", "lastname"}

        TestRecord = RecordDescriptor(
            "test/record",
            [
                ("string", "firstname"),
                ("string", "lastname"),
            ],
        )

        same_same = TestRecord(firstname="John", lastname="Rambo")
        but_different = TestRecord(firstname="Johnny", lastname="English")
        but_still_same = TestRecord(firstname="John", lastname="McClane")

        records = [same_same, but_different, but_still_same]

        assert same_same == but_still_same
        assert same_same != but_different
        assert len(set(records)) == 2
Loading