Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Changes for PR
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Aug 15, 2022
1 parent 88ffd59 commit a055fe8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 33 deletions.
6 changes: 3 additions & 3 deletions hamilton/data_quality/default_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,10 @@ def arg(cls) -> str:
return 'mean_in_range'


class NoNonesAllowedValidatorAllTypes(BaseDefaultValidator):
class AllowNoneValidator(BaseDefaultValidator):

def __init__(self, allow_none: bool, importance: str):
super(NoNonesAllowedValidatorAllTypes, self).__init__(importance)
super(AllowNoneValidator, self).__init__(importance)
self.allow_none = allow_none

@classmethod
Expand Down Expand Up @@ -416,7 +416,7 @@ def arg(cls) -> str:
MaxFractionNansValidatorPandasSeries,
MaxStandardDevValidatorPandasSeries,
MeanInRangeValidatorPandasSeries,
NoNonesAllowedValidatorAllTypes,
AllowNoneValidator,
]


Expand Down
51 changes: 25 additions & 26 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@

import typing_inspect

import hamilton.function_modifiers_base
from hamilton import function_modifiers_base
from hamilton import node
from hamilton.node import NodeSource, DependencyType
from hamilton import base
from hamilton.type_utils import custom_subclass_check
from hamilton import type_utils

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,7 +46,7 @@ def types_match(adapter: base.HamiltonGraphAdapter,
return required_node_type == param_type
elif required_node_type == param_type:
return True
elif custom_subclass_check(required_node_type, param_type):
elif type_utils.custom_subclass_check(required_node_type, param_type):
return True
elif adapter.check_node_type_equivalence(required_node_type, param_type):
return True
Expand Down Expand Up @@ -91,7 +90,7 @@ def add_dependency(
f'{param_name}:{required_node.type}. All names & types must match.')
else:
# this is a user defined var
required_node = node.Node(param_name, param_type, node_source=NodeSource.EXTERNAL)
required_node = node.Node(param_name, param_type, node_source=node.NodeSource.EXTERNAL)
nodes[param_name] = required_node
# add edges
func_node.dependencies.append(required_node)
Expand All @@ -111,7 +110,7 @@ def create_function_graph(*modules: ModuleType, config: Dict[str, Any], adapter:

# create nodes -- easier to just create this in one loop
for func_name, f in functions:
for n in hamilton.function_modifiers_base.resolve_nodes(f, config):
for n in function_modifiers_base.resolve_nodes(f, config):
if n.name in config:
continue # This makes sure we overwrite things if they're in the config...
if n.name in nodes:
Expand All @@ -124,7 +123,7 @@ def create_function_graph(*modules: ModuleType, config: Dict[str, Any], adapter:
add_dependency(n, node_name, nodes, param_name, param_type, adapter)
for key in config.keys():
if key not in nodes:
nodes[key] = node.Node(key, Any, node_source=NodeSource.EXTERNAL)
nodes[key] = node.Node(key, Any, node_source=node.NodeSource.EXTERNAL)
return nodes


Expand Down Expand Up @@ -318,7 +317,7 @@ def next_nodes_function(n: node.Node) -> List[node.Node]:
# If inputs is None, we want to assume its required, as it is a compile-time dependency
if dep.user_defined and dep.name not in runtime_inputs and dep.name not in self.config:
_, dependency_type = n.input_types[dep.name]
if dependency_type == DependencyType.OPTIONAL:
if dependency_type == node.DependencyType.OPTIONAL:
continue
deps.append(dep)
return deps
Expand Down Expand Up @@ -383,41 +382,41 @@ def execute_static(nodes: Collection[node.Node],
if computed is None:
computed = {}

def dfs_traverse(node: node.Node, dependency_type: DependencyType = DependencyType.REQUIRED):
if node.name in computed:
def dfs_traverse(node_: node.Node, dependency_type: node.DependencyType = node.DependencyType.REQUIRED):
if node_.name in computed:
return
if node.name in overrides:
computed[node.name] = overrides[node.name]
if node_.name in overrides:
computed[node_.name] = overrides[node_.name]
return
for n in node.dependencies:
for n in node_.dependencies:
if n.name not in computed:
_, node_dependency_type = node.input_types[n.name]
_, node_dependency_type = node_.input_types[n.name]
dfs_traverse(n, node_dependency_type)

logger.debug(f'Computing {node.name}.')
if node.user_defined:
if node.name not in inputs:
if dependency_type != DependencyType.OPTIONAL:
raise NotImplementedError(f'{node.name} was expected to be passed in but was not.')
logger.debug(f'Computing {node_.name}.')
if node_.user_defined:
if node_.name not in inputs:
if dependency_type != node.DependencyType.OPTIONAL:
raise NotImplementedError(f'{node_.name} was expected to be passed in but was not.')
return
value = inputs[node.name]
value = inputs[node_.name]
else:
kwargs = {} # construct signature
for dependency in node.dependencies:
for dependency in node_.dependencies:
if dependency.name in computed:
kwargs[dependency.name] = computed[dependency.name]
try:
value = adapter.execute_node(node, kwargs)
value = adapter.execute_node(node_, kwargs)
except Exception as e:
logger.exception(f'Node {node.name} encountered an error')
logger.exception(f'Node {node_.name} encountered an error')
raise
computed[node.name] = value
computed[node_.name] = value

for final_var_node in nodes:
dep_type = DependencyType.REQUIRED
dep_type = node.DependencyType.REQUIRED
if final_var_node.user_defined:
# from the top level, we don't know if this UserInput is required. So mark as optional.
dep_type = DependencyType.OPTIONAL
dep_type = node.DependencyType.OPTIONAL
dfs_traverse(final_var_node, dep_type)
return computed

Expand Down
8 changes: 4 additions & 4 deletions tests/test_default_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def test_resolve_default_validators_error(output_type, kwargs, importance):
(default_validators.AllowNaNsValidatorPandasSeries, False, pd.Series([.1, None]), False),
(default_validators.AllowNaNsValidatorPandasSeries, False, pd.Series([.1, .2]), True),
(default_validators.NoNonesAllowedValidatorAllTypes, False, None, False),
(default_validators.NoNonesAllowedValidatorAllTypes, False, 1, True),
(default_validators.NoNonesAllowedValidatorAllTypes, True, None, True),
(default_validators.NoNonesAllowedValidatorAllTypes, True, 1, True),
(default_validators.AllowNoneValidator, False, None, False),
(default_validators.AllowNoneValidator, False, 1, True),
(default_validators.AllowNoneValidator, True, None, True),
(default_validators.AllowNoneValidator, True, 1, True),
]
)
def test_default_data_validators(cls: Type[hamilton.data_quality.base.BaseDefaultValidator], param: Any, data: Any, should_pass: bool):
Expand Down

0 comments on commit a055fe8

Please sign in to comment.