Skip to content

Commit

Permalink
feat: display results in 'not found' error
Browse files Browse the repository at this point in the history
  • Loading branch information
jokasimr committed Oct 30, 2024
1 parent 93af61b commit ef2131e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
12 changes: 8 additions & 4 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._utils import key_name
from .data_graph import DataGraph, to_task_graph
from .display import pipeline_html_repr
from .handler import ErrorHandler, HandleAsComputeTimeException
from .handler import ErrorHandler, HandleAsComputeTimeException, UnsatisfiedRequirement
from .scheduler import Scheduler
from .task_graph import TaskGraph
from .typing import Key
Expand Down Expand Up @@ -107,7 +107,7 @@ def visualize(
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
if tp is None:
tp = self.leafs()
tp = self.final_result_keys()
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)

def get(
Expand Down Expand Up @@ -140,7 +140,11 @@ def get(
targets = tuple(keys) # type: ignore[arg-type]
else:
targets = (keys,) # type: ignore[assignment]
graph = to_task_graph(self, targets=targets, handler=handler)
try:
graph = to_task_graph(self, targets=targets, handler=handler)
except UnsatisfiedRequirement as e:
final_result_keys = ", ".join(map(repr, self.final_result_keys()))
raise type(e)(f'Did you meant one of: {final_result_keys}?') from e
return TaskGraph(
graph=graph,
targets=targets if multi else keys, # type: ignore[arg-type]
Expand Down Expand Up @@ -205,7 +209,7 @@ def _repr_html_(self) -> str:
nodes = ((key, data) for key, data in self.underlying_graph.nodes.items())
return pipeline_html_repr(nodes)

def leafs(self) -> tuple[type, ...]:
def final_result_keys(self) -> tuple[type, ...]:
"""Returns the keys that are not inputs to any other providers."""
sink_nodes = [
cast(type, node)
Expand Down
20 changes: 18 additions & 2 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import sciline as sl
from sciline._utils import key_name


def int_to_float(x: int) -> float:
Expand Down Expand Up @@ -1416,11 +1417,26 @@ def bad(x: int, y: int) -> float:
pipeline.insert(bad)


def test_leafs_method() -> None:
def test_final_result_keys_method() -> None:
def make_float() -> float:
return 1.0

def make_str(x: int) -> str:
return "a string"

assert sl.Pipeline([make_float, make_str]).leafs() == (float, str)
assert sl.Pipeline([make_float, make_str]).final_result_keys() == (float, str)


@pytest.mark.parametrize('get_method', ['get', 'compute'])
def test_final_result_keys_in_not_found_error_message(get_method) -> None:
def make_float() -> float:
return 1.0

def make_str(x: int) -> str:
return "a string"

pl = sl.Pipeline([make_float, make_str])
with pytest.raises(sl.handler.UnsatisfiedRequirement) as info:
getattr(pl, get_method)(int)
for key in pl.final_result_keys():
assert key_name(key) in info.value.args[0]

0 comments on commit ef2131e

Please sign in to comment.