Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make records hashable #107

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flow/record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

from flow.record.base import (
IGNORE_FIELDS_FOR_COMPARISON,
RECORD_VERSION,
RECORDSTREAM_MAGIC,
DynamicDescriptor,
Expand All @@ -21,6 +22,7 @@
open_path_or_stream,
open_stream,
stream,
update_ignored_fields_for_comparison,
JSCU-CNI marked this conversation as resolved.
Show resolved Hide resolved
)
from flow.record.jsonpacker import JsonRecordPacker
from flow.record.stream import (
Expand All @@ -35,6 +37,7 @@
)

__all__ = [
"IGNORE_FIELDS_FOR_COMPARISON",
"RECORD_VERSION",
"RECORDSTREAM_MAGIC",
"FieldType",
Expand All @@ -55,6 +58,7 @@
"open_path",
"open_stream",
"stream",
"update_ignored_fields_for_comparison",
JSCU-CNI marked this conversation as resolved.
Show resolved Hide resolved
"dynamic_fieldtype",
"DynamicDescriptor",
"PathTemplateWriter",
Expand Down
25 changes: 23 additions & 2 deletions flow/record/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ def _unpack(__cls, {args}):
"""


# Names of record fields that are left out when records are compared or hashed.
# Seeded from the comma-separated FLOW_RECORD_IGNORE environment variable. Surrounding
# whitespace and empty entries (e.g. from "a, b,") are dropped, because a padded or empty
# name can never match a field name and would be silently ineffective.
IGNORE_FIELDS_FOR_COMPARISON = [
    stripped for name in os.environ.get("FLOW_RECORD_IGNORE", "").split(",") if (stripped := name.strip())
]
JSCU-CNI marked this conversation as resolved.
Show resolved Hide resolved


def update_ignored_fields_for_comparison(ignored_fields: list[str]) -> None:
    """Replace the list of field names that are ignored when comparing records.

    Can be used to update IGNORE_FIELDS_FOR_COMPARISON from outside the flow.record package scope.

    Args:
        ignored_fields: Field names to ignore. A bare string is treated as a single
            field name rather than as a sequence of characters.
    """
    global IGNORE_FIELDS_FOR_COMPARISON
    if isinstance(ignored_fields, str):
        ignored_fields = [ignored_fields]
    # Keep a private copy so that later mutation of the caller's list cannot silently
    # change how records compare.
    # NOTE: this rebinds the module global; a name obtained earlier through
    # ``from ... import IGNORE_FIELDS_FOR_COMPARISON`` keeps referring to the old list.
    IGNORE_FIELDS_FOR_COMPARISON = list(ignored_fields)
JSCU-CNI marked this conversation as resolved.
Show resolved Hide resolved


class FieldType:
def _typename(self):
t = type(self)
Expand All @@ -117,14 +129,20 @@ class Record:
def __eq__(self, other):
    """Compare two records by their packed form, leaving out the ignored fields.

    Returns False when ``other`` is not a Record. Fields named in
    IGNORE_FIELDS_FOR_COMPARISON are excluded from both sides before comparing.
    """
    if not isinstance(other, Record):
        return False
    # Read the module-level setting once so both sides are packed with the same list.
    ignored = IGNORE_FIELDS_FOR_COMPARISON
    return self._pack(excluded_fields=ignored) == other._pack(excluded_fields=ignored)

def _pack(self, unversioned=False, excluded_fields: list = None):
values = []
for k in self.__slots__:
v = getattr(self, k)
v = v._pack() if isinstance(v, FieldType) else v

if excluded_fields and k in excluded_fields:
continue

# Skip version field if requested (only for compatibility reasons)
if unversioned and k == "_version" and v == 1:
continue
Expand Down Expand Up @@ -160,6 +178,9 @@ def _replace(self, **kwds):
raise ValueError("Got unexpected field names: {kwds!r}".format(kwds=list(kwds)))
return result

def __hash__(self) -> int:
    """Hash the record's packed form with the ignored fields left out.

    Uses the same ``_pack(excluded_fields=IGNORE_FIELDS_FOR_COMPARISON)`` form that
    ``__eq__`` compares. Raises TypeError if the packed form is not hashable.
    """
    # Use the hash() builtin rather than calling the dunder directly: it is the
    # idiomatic spelling and reports unhashable values with a clear error message.
    return hash(self._pack(excluded_fields=IGNORE_FIELDS_FOR_COMPARISON))
JSCU-CNI marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
return "<{} {}>".format(
self._desc.name, " ".join("{}={!r}".format(k, getattr(self, k)) for k in self._desc.fields)
Expand Down
66 changes: 65 additions & 1 deletion tests/test_record.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import importlib
import os
import sys
from unittest.mock import patch

import pytest

Expand All @@ -15,7 +18,11 @@
fieldtypes,
record_stream,
)
from flow.record.base import merge_record_descriptors, normalize_fieldname
from flow.record.base import (
merge_record_descriptors,
normalize_fieldname,
update_ignored_fields_for_comparison,
)
from flow.record.exceptions import RecordDescriptorError
from flow.record.stream import RecordFieldRewriter

Expand Down Expand Up @@ -792,3 +799,60 @@ def test_normalize_fieldname():
assert normalize_fieldname("my name (with) parentheses") == "my_name__with__parentheses"
assert normalize_fieldname("_generated") == "_generated"
assert normalize_fieldname("_source") == "_source"


def test_compare_global_variable():
    """Records compare and hash equal once the differing fields are put on the ignore list."""
    # Imported locally so the test can snapshot the module-level setting it is about to change.
    from flow.record import base as record_base

    TestRecord = RecordDescriptor(
        "test/record",
        [
            ("string", "firstname"),
            ("string", "lastname"),
        ],
    )

    same_same = TestRecord(firstname="James", lastname="Bond")
    but_different = TestRecord(firstname="Ethan", lastname="Hunt")
    but_still_same = TestRecord(firstname="Andrew", lastname="Bond")

    records = [same_same, but_different, but_still_same]

    assert same_same != but_still_same

    previous = list(record_base.IGNORE_FIELDS_FOR_COMPARISON)
    try:
        # NOTE(review): "_generated" is ignored as well, presumably because it differs
        # per record instance - confirm against the Record implementation.
        update_ignored_fields_for_comparison(["_generated", "firstname"])
        assert same_same == but_still_same
        assert same_same != but_different
        assert len(set(records)) == 2
    finally:
        # The ignore list is process-wide state; restore it so later tests are unaffected.
        update_ignored_fields_for_comparison(previous)


def test_compare_environment_variable():
    """FLOW_RECORD_IGNORE seeds the ignore list when flow.record is imported."""
    # patch.dict restores both os.environ and sys.modules when the block exits.
    with patch.dict(os.environ, {"FLOW_RECORD_IGNORE": "_generated,lastname"}), patch.dict(sys.modules):
        # Force a re-import of flow.record so the global variable gets re-initialized based on the
        # environment variable. Match on the package prefix only: a substring test such as
        # `"flow." in key` would also evict unrelated modules like "mlflow.x" or "tensorflow.x".
        stale = [name for name in sys.modules if name == "flow" or name.startswith("flow.")]
        for name in stale:
            del sys.modules[name]

        importlib.import_module("flow.record")

        from flow.record import IGNORE_FIELDS_FOR_COMPARISON, RecordDescriptor

        assert IGNORE_FIELDS_FOR_COMPARISON == ["_generated", "lastname"]

        TestRecord = RecordDescriptor(
            "test/record",
            [
                ("string", "firstname"),
                ("string", "lastname"),
            ],
        )

        same_same = TestRecord(firstname="John", lastname="Rambo")
        but_different = TestRecord(firstname="Johnny", lastname="English")
        but_still_same = TestRecord(firstname="John", lastname="McClane")

        records = [same_same, but_different, but_still_same]

        assert same_same == but_still_same
        assert same_same != but_different
        assert len(set(records)) == 2
Loading