Hashing algorithm to respect row order
hagenw committed Jun 14, 2024
1 parent 327ee95 commit 96a74b6
Showing 2 changed files with 184 additions and 53 deletions.
133 changes: 84 additions & 49 deletions audformat/core/table.py
@@ -1099,61 +1099,36 @@ def _save_csv(self, path: str):
            df.to_csv(fp, encoding="utf-8")

    def _save_parquet(self, path: str):
-        df = self.df.reset_index()
-        table = pa.Table.from_pandas(df, preserve_index=False)
-
-        # Add hash of dataframe
-        # to the metadata,
-        # which pyarrow stores inside the schema.
-        # See https://stackoverflow.com/a/58978449.
-        #
-        # This allows us to track if a PARQUET file changes over time.
-        # We cannot rely on md5 sums of the file,
-        # as the file is written in a non-deterministic way.
-        table_hash = hashlib.md5()
-
-        # Hash of schema (columns + dtypes)
-        schema_str = table.schema.to_string(
-            # schema.metadata contains pandas related information,
-            # and the used pyarrow and pandas version,
-            # and needs to be excluded
-            show_field_metadata=False,
-            show_schema_metadata=False,
-        )
-        schema_hash = hashlib.md5(schema_str.encode())
-        table_hash.update(schema_hash.digest())
-
-        # Hash data
-        try:
-            data_hash = utils.hash(self.df)
-        except TypeError:
-            # Levels/columns with dtype "object" might not be hashable,
-            # e.g. when storing numpy arrays.
-            # We convert them to strings in this case.
-
-            # Index
-            df = self.df.copy()
-            update_index_dtypes = {
-                level: "string"
-                for level, dtype in self._levels_and_dtypes.items()
-                if dtype == define.DataType.OBJECT
-            }
-            df.index = utils.set_index_dtypes(df.index, update_index_dtypes)
-
-            # Columns
-            for column_id, column in self.columns.items():
-                if column.scheme_id is not None:
-                    scheme = self.db.schemes[column.scheme_id]
-                    if scheme.dtype == define.DataType.OBJECT:
-                        df[column_id] = df[column_id].astype("string")
-                else:
-                    # No scheme defaults to `object` dtype
-                    df[column_id] = df[column_id].astype("string")
-            data_hash = utils.hash(df)
-
-        table_hash.update(data_hash.encode())
+        r"""Save table as PARQUET file.
+
+        A PARQUET file is written in a non-deterministic way,
+        and we cannot track changes by its MD5 sum.
+        To make changes trackable,
+        we store a hash in its metadata.
+
+        The hash is calculated from the pyarrow schema
+        (to track column names and data types)
+        and the pandas dataframe
+        (to track values and order of rows),
+        from which the PARQUET file is generated.
+
+        The hash of the PARQUET file can then be read by::
+
+            pyarrow.parquet.read_schema(path).metadata[b"hash"].decode()
+
+        Args:
+            path: path, including file extension
+
+        """
+        table = pa.Table.from_pandas(self.df.reset_index(), preserve_index=False)
+
+        # Create hash of table
+        table_hash = hashlib.md5()
+        table_hash.update(_schema_hash(table))
+        table_hash.update(_dataframe_hash(self.df))

        # Store in metadata of file,
        # see https://stackoverflow.com/a/58978449
        metadata = {"hash": table_hash.hexdigest()}
        table = table.replace_schema_metadata({**metadata, **table.schema.metadata})

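As a usage note on the new docstring: the hash is not an MD5 of the PARQUET file itself but lives in the schema metadata, so a consumer can read it back without loading any row data. A minimal sketch of such a read-back, assuming a hypothetical file name db.table.parquet; only pyarrow.parquet.read_schema from the docstring above is used:

    import pyarrow.parquet as parquet

    # Read only the schema of the PARQUET file; no row data is loaded.
    schema = parquet.read_schema("db.table.parquet")

    # The hash is stored under the b"hash" key of the schema metadata,
    # as a hex digest string.
    stored_hash = schema.metadata[b"hash"].decode()
    print(stored_hash)
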
@@ -1855,6 +1830,46 @@ def _assert_table_index(
)


def _dataframe_hash(df: pd.DataFrame, max_rows: int = None) -> bytes:
    """Hash a dataframe.

    The hash value takes into account:

    * index of dataframe
    * values of the dataframe
    * order of dataframe rows

    It does not consider:

    * column names of dataframe
    * dtypes of dataframe

    Args:
        df: dataframe
        max_rows: if not ``None``,
            the maximum number of rows,
            taken into account for hashing

    Returns:
        MD5 hash in bytes

    """
    # Idea for implementation from
    # https://github.com/streamlit/streamlit/issues/7086#issuecomment-1654504410
    md5 = hashlib.md5()
    if max_rows is not None and len(df) > max_rows:  # pragma: nocover (not yet used)
        df = df.sample(n=max_rows, random_state=0)
    # Hash length, as we have to track if this changes
    md5.update(str(len(df)).encode("utf-8"))
    try:
        md5.update(bytes(str(pd.util.hash_pandas_object(df)), "utf-8"))
    except TypeError:
        # Use pickle if pandas cannot hash the object,
        # e.g. if it contains numpy.arrays.
        md5.update(f"{pickle.dumps(df, pickle.HIGHEST_PROTOCOL)}".encode("utf-8"))
    return md5.digest()
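Why this makes the hash respect row order, as the commit title says: pandas.util.hash_pandas_object returns one hash per row, and its string representation is fed into a single MD5 digest in sequence, so the same rows in a different order yield a different digest. A small standalone illustration with made-up data (not taken from the test suite), mirroring only the core of _dataframe_hash:

    import hashlib

    import pandas as pd

    def demo_hash(df: pd.DataFrame) -> str:
        # Mirrors the core of _dataframe_hash: length first, then per-row hashes.
        md5 = hashlib.md5()
        md5.update(str(len(df)).encode("utf-8"))
        md5.update(bytes(str(pd.util.hash_pandas_object(df)), "utf-8"))
        return md5.hexdigest()

    df = pd.DataFrame({"a": [0, 1]}, index=["f1", "f2"])
    reordered = df.iloc[::-1]  # same rows, reversed order

    assert demo_hash(df) != demo_hash(reordered)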


def _maybe_convert_dtype_to_string(
    index: pd.Index,
) -> pd.Index:
@@ -1877,3 +1892,23 @@ def _maybe_update_scheme(
    for scheme in table.db.schemes.values():
        if table._id == scheme.labels:
            scheme.replace_labels(table._id)


def _schema_hash(table: pa.Table) -> bytes:
    r"""Hash pyarrow table schema.

    Args:
        table: pyarrow table

    Returns:
        MD5 hash in bytes

    """
    schema_str = table.schema.to_string(
        # schema.metadata contains pandas related information,
        # and the used pyarrow and pandas version,
        # and needs to be excluded
        show_field_metadata=False,
        show_schema_metadata=False,
    )
    return hashlib.md5(schema_str.encode()).digest()
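To make explicit what the schema part of the overall hash reacts to: the schema string covers column names and dtypes but not the values, so two tables with the same layout share a schema hash while a dtype change alters it. A rough sketch, assuming _schema_hash as defined above is importable and using made-up dataframes:

    import pandas as pd
    import pyarrow as pa

    t1 = pa.Table.from_pandas(pd.DataFrame({"col": [0, 1]}), preserve_index=False)
    t2 = pa.Table.from_pandas(pd.DataFrame({"col": [2, 3]}), preserve_index=False)
    t3 = pa.Table.from_pandas(pd.DataFrame({"col": [0.5, 1.5]}), preserve_index=False)

    # Same column name and dtype (int64) -> identical schema hash,
    # although the values differ.
    assert _schema_hash(t1) == _schema_hash(t2)

    # Same column name, different dtype (double) -> different schema hash.
    assert _schema_hash(t1) != _schema_hash(t3)
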
104 changes: 100 additions & 4 deletions tests/test_table.py
@@ -1210,24 +1210,120 @@ def test_map(table, map):
    pd.testing.assert_frame_equal(result, expected)


@pytest.mark.parametrize("storage_format", ["csv", "parquet"])
def test_hash(tmpdir, storage_format):
r"""Test if PARQUET file hash changes with table.
We store a MD5 sum associated with the dataframe,
that was used to create the file,
in the metadata of the PARQUET file.
Those MD5 sum is supposed to change,
if any of the table rows, (index) columns changes,
the data type of the entries changes,
or the name of a column changes.
Args:
tmpdir: tmpdir fixture
storage_format: storage format of table file
"""

def get_md5(path: str) -> str:
r"""Get MD5 sum for table file."""
ext = audeer.file_extension(path)
if ext == "csv":
md5 = audeer.md5(path)
elif ext == "parquet":
md5 = parquet.read_schema(path).metadata[b"hash"].decode()
return md5

db_root = audeer.path(tmpdir, "db")
db = audformat.Database("mydb")
db.schemes["int"] = audformat.Scheme("int")
index = audformat.segmented_index(["f1", "f2"], [0, 1], [1, 2])
db["table"] = audformat.Table(index)
db["table"]["column"] = audformat.Column(scheme_id="int")
db["table"]["column"].set([0, 1])
db.save(db_root, storage_format=storage_format)

table_file = audeer.path(db_root, f"db.table.{storage_format}")
assert os.path.exists(table_file)
md5 = get_md5(table_file)

# Replace table with identical copy
table = db["table"].copy()
db["table"] = table
db.save(db_root, storage_format=storage_format)
assert get_md5(table_file) == md5

# Change order of rows
index = audformat.segmented_index(["f2", "f1"], [1, 0], [2, 1])
db["table"] = audformat.Table(index)
db["table"]["column"] = audformat.Column(scheme_id="int")
db["table"]["column"].set([1, 0])
db.save(db_root, storage_format=storage_format)
assert get_md5(table_file) != md5

# Change index entry
index = audformat.segmented_index(["f1", "f1"], [0, 1], [1, 2])
db["table"] = audformat.Table(index)
db["table"]["column"] = audformat.Column(scheme_id="int")
db["table"]["column"].set([0, 1])
db.save(db_root, storage_format=storage_format)
assert get_md5(table_file) != md5

# Change data entry
index = audformat.segmented_index(["f1", "f2"], [0, 1], [1, 2])
db["table"] = audformat.Table(index)
db["table"]["column"] = audformat.Column(scheme_id="int")
db["table"]["column"].set([1, 0])
db.save(db_root, storage_format=storage_format)
assert get_md5(table_file) != md5

# Change column name
index = audformat.segmented_index(["f1", "f2"], [0, 1], [1, 2])
db["table"] = audformat.Table(index)
db["table"]["col"] = audformat.Column(scheme_id="int")
db["table"]["col"].set([0, 1])
db.save(db_root, storage_format=storage_format)
assert get_md5(table_file) != md5

# Change order of columns
index = audformat.segmented_index(["f1", "f2"], [0, 1], [1, 2])
db["table"] = audformat.Table(index)
db["table"]["col1"] = audformat.Column(scheme_id="int")
db["table"]["col1"].set([0, 1])
db["table"]["col2"] = audformat.Column(scheme_id="int")
db["table"]["col2"].set([0, 1])
db.save(db_root, storage_format=storage_format)
md5 = get_md5(table_file)
db["table"] = audformat.Table(index)
db["table"]["col2"] = audformat.Column(scheme_id="int")
db["table"]["col2"].set([0, 1])
db["table"]["col1"] = audformat.Column(scheme_id="int")
db["table"]["col1"].set([0, 1])
db.save(db_root, storage_format=storage_format)
assert get_md5(table_file) != md5


@pytest.mark.parametrize(
    "table_id, expected_hash",
    [
        (
            "files",
-            "4d0295654694751bdcd12be86b89b73e",
+            "9caa6722e65a04ddbce1cda2238c9126",
        ),
        (
            "segments",
-            "d2a9b84d03abde24ae84cf647a019b71",
+            "37c9d9dc4f937a6e97ec72a080055e49",
        ),
        (
            "misc",
-            "6b6faecc836354bd89472095c1fa746a",
+            "3488c007d45b19e04e8fdbf000f0f04d",
        ),
    ],
)
-def test_parquet_reproducibility(tmpdir, table_id, expected_hash):
+def test_parquet_hash_reproducibility(tmpdir, table_id, expected_hash):
r"""Test reproducibility of binary PARQUET files.
When storing the same dataframe
Expand Down
