Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Add pandas index checks #200

Merged
merged 5 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 165 additions & 14 deletions hamilton/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import abc
import collections
import inspect
import typing
import logging
from typing import Any, Dict, List, Tuple, Type, Union

import numpy as np
import pandas as pd
import typing_inspect
from pandas.core.indexes import extension as pd_extension

from . import node

logger = logging.getLogger(__name__)


class ResultMixin(object):
"""Base class housing the static function.
Expand All @@ -23,7 +27,7 @@ class ResultMixin(object):

@staticmethod
@abc.abstractmethod
def build_result(**outputs: typing.Dict[str, typing.Any]) -> typing.Any:
def build_result(**outputs: Dict[str, Any]) -> Any:
"""This function builds the result given the computed values."""
pass

Expand All @@ -32,7 +36,7 @@ class DictResult(ResultMixin):
"""Simple function that returns the dict of column -> value results."""

@staticmethod
def build_result(**outputs: typing.Dict[str, typing.Any]) -> typing.Dict:
def build_result(**outputs: Dict[str, Any]) -> Dict:
"""This function builds a simple dict of output -> computed values."""
return outputs

Expand All @@ -41,9 +45,122 @@ class PandasDataFrameResult(ResultMixin):
"""Mixin for building a pandas dataframe from the result"""

@staticmethod
def build_result(**outputs: typing.Dict[str, typing.Any]) -> pd.DataFrame:
def pandas_index_types(
outputs: Dict[str, Any]
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]], Dict[str, List[str]]]:
"""This function creates three dictionaries according to whether there is an index type or not.

The three dicts we create are:
1. Dict of index type to list of outputs that match it.
2. Dict of time series / categorical index types to list of outputs that match it.
3. Dict of `no-index` key to list of outputs with no index type.

:param outputs: the dict we're trying to create a result from.
:return: dict of all index types, dict of time series/categorical index types, dict if there is no index
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
"""
all_index_types = collections.defaultdict(list)
time_indexes = collections.defaultdict(list)
no_indexes = collections.defaultdict(list)

def index_key_name(pd_object: Union[pd.DataFrame, pd.Series]) -> str:
"""Creates a string helping identify the index and it's type.
Useful for disambiguating time related indexes."""
return f"{pd_object.index.__class__.__name__}:::{pd_object.index.dtype}"

def get_parent_time_index_type():
"""Helper to pull the right time index parent class."""
if hasattr(
pd_extension, "NDArrayBackedExtensionIndex"
): # for python 3.7+ & pandas >= 1.2
index_type = pd_extension.NDArrayBackedExtensionIndex
elif hasattr(pd_extension, "ExtensionIndex"): # for python 3.6 & pandas <= 1.2
index_type = pd_extension.ExtensionIndex
else:
index_type = None # weird case, but not worth breaking for.
return index_type

for output_name, output_value in outputs.items():
if isinstance(
output_value, (pd.DataFrame, pd.Series)
): # if it has an index -- let's grab it's type
dict_key = index_key_name(output_value)
if isinstance(output_value.index, get_parent_time_index_type()):
# it's a time index -- these will produce garbage if not aligned properly.
time_indexes[dict_key].append(output_name)
elif isinstance(
output_value, pd.Index
): # there is no index on this - so it's just an integer one.
int_index = pd.Series(
[1, 2, 3], index=[0, 1, 2]
) # dummy to get right values for string.
dict_key = index_key_name(int_index)
else:
dict_key = "no-index"
no_indexes[dict_key].append(output_name)
all_index_types[dict_key].append(output_name)
return all_index_types, time_indexes, no_indexes

@staticmethod
def check_pandas_index_types_match(
all_index_types: Dict[str, List[str]],
time_indexes: Dict[str, List[str]],
no_indexes: Dict[str, List[str]],
) -> bool:
"""Checks that pandas index types match.

This only logs warning errors, and if debug is enabled, a debug statement to list index types.
"""
no_index_length = len(no_indexes)
time_indexes_length = len(time_indexes)
all_indexes_length = len(all_index_types)
number_with_indexes = all_indexes_length - no_index_length
types_match = True # default to True
# if there is more than one time index
if time_indexes_length > 1:
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
logger.warning(
f"WARNING: Time/Categorical index type mismatches detected - check output to ensure Pandas "
f"is doing what you intend to do. Else change the index types to match. Set logger to debug "
f"to see index types."
)
types_match = False
# if there is more than one index type and it's not explained by the time indexes then
if number_with_indexes > 1 and all_indexes_length > time_indexes_length:
logger.warning(
f"WARNING: Multiple index types detected - check output to ensure Pandas is "
f"doing what you intend to do. Else change the index types to match. Set logger to debug to "
f"see index types."
)
types_match = False
elif number_with_indexes == 1 and no_index_length > 0:
logger.warning(
f"WARNING: a single pandas index was found, but there are also {no_index_length} outputs without "
f"an index. Those values will be made constants throughout the values of the index."
)
# Strictly speaking the index types match -- there is only one -- so setting to True.
types_match = True
# if all indexes matches no indexes
elif no_index_length == all_indexes_length:
logger.warning(
"It appears no Pandas index type was detected. This will likely break when trying to "
"create a DataFrame. E.g. are you requesting all scalar values? Use a different result "
"builder or return at least one Pandas object with an index."
)
types_match = False
if logger.isEnabledFor(logging.DEBUG):
import pprint

pretty_string = pprint.pformat(dict(all_index_types))
logger.debug(f"Index types encountered:\n{pretty_string}.")
return types_match

@staticmethod
def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame:
# TODO check inputs are pd.Series, arrays, or scalars -- else error
# TODO do a basic index check across pd.Series and flag where mismatches occur?
output_index_type_tuple = PandasDataFrameResult.pandas_index_types(outputs)
# this next line just log warnings
# we don't actually care about the result since this is the current default behavior.
PandasDataFrameResult.check_pandas_index_types_match(*output_index_type_tuple)

if len(outputs) == 1:
(value,) = outputs.values() # this works because it's length 1.
if isinstance(value, pd.DataFrame):
Expand All @@ -54,14 +171,48 @@ def build_result(**outputs: typing.Dict[str, typing.Any]) -> pd.DataFrame:
return pd.DataFrame(outputs)


class StrictIndexTypePandasDataFrameResult(PandasDataFrameResult):
"""A ResultBuilder that produces a dataframe only if the index types match exactly.

Note: If there is no index type on some outputs, e.g. the value is a scalar, as long as there exists a single pandas
index type, no error will be thrown, because a dataframe can be easily created.

To use:
from hamilton import base, driver
strict_builder = base.StrictIndexTypePandasDataFrameResult()
adapter = base.SimplePythonGraphAdapter(strict_builder)
...
dr = driver.Driver(config, *modules, adapter=adapter)
df = dr.execute(...) # this will now error if index types mismatch.
"""

@staticmethod
def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame:
# TODO check inputs are pd.Series, arrays, or scalars -- else error
output_index_type_tuple = PandasDataFrameResult.pandas_index_types(outputs)
indexes_match = PandasDataFrameResult.check_pandas_index_types_match(
*output_index_type_tuple
)
if not indexes_match:
import pprint

pretty_string = pprint.pformat(dict(output_index_type_tuple[0]))
raise ValueError(
"Error: pandas index types did not match exactly. "
f"Found the following indexes:\n{pretty_string}"
)

return PandasDataFrameResult.build_result(**outputs)


class NumpyMatrixResult(ResultMixin):
"""Mixin for building a Numpy Matrix from the result of walking the graph.

All inputs to the build_result function are expected to be numpy arrays
"""

@staticmethod
def build_result(**outputs: typing.Dict[str, typing.Any]) -> np.matrix:
def build_result(**outputs: Dict[str, Any]) -> np.matrix:
"""Builds a numpy matrix from the passed in, inputs.

:param outputs: function_name -> np.array.
Expand Down Expand Up @@ -108,7 +259,7 @@ class HamiltonGraphAdapter(ResultMixin):

@staticmethod
@abc.abstractmethod
def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
def check_input_type(node_type: Type, input_value: Any) -> bool:
"""Used to check whether the user inputs match what the execution strategy & functions can handle.

:param node_type: The type of the node.
Expand All @@ -119,7 +270,7 @@ def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:

@staticmethod
@abc.abstractmethod
def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool:
"""Used to check whether two types are equivalent.

This is used when the function graph is being created and we're statically type checking the annotations
Expand All @@ -132,7 +283,7 @@ def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type)
pass

@abc.abstractmethod
def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any:
"""Given a node that represents a hamilton function, execute it.
Note, in some adapters this might just return some type of "future".

Expand All @@ -147,8 +298,8 @@ class SimplePythonDataFrameGraphAdapter(HamiltonGraphAdapter, PandasDataFrameRes
"""This is the default (original Hamilton) graph adapter. It uses plain python and builds a dataframe result."""

@staticmethod
def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
if node_type == typing.Any:
def check_input_type(node_type: Type, input_value: Any) -> bool:
if node_type == Any:
return True
elif inspect.isclass(node_type) and isinstance(input_value, node_type):
return True
Expand All @@ -171,10 +322,10 @@ def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
return False

@staticmethod
def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool:
return node_type == input_type

def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any:
return node.callable(**kwargs)


Expand All @@ -186,6 +337,6 @@ def __init__(self, result_builder: ResultMixin):
if self.result_builder is None:
raise ValueError("You must provide a ResultMixin object for `result_builder`.")

def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
def build_result(self, **outputs: Dict[str, Any]) -> Any:
"""Delegates to the result builder function supplied."""
return self.result_builder.build_result(**outputs)
Loading