Skip to content

Commit

Permalink
Node check restructuring (#2518)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored Jan 26, 2024
1 parent f0947e5 commit 36c01c1
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 104 deletions.
1 change: 1 addition & 0 deletions backend/src/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .group import *
from .input import *
from .node_context import *
from .node_data import *
from .output import *
from .settings import *
from .types import *
103 changes: 11 additions & 92 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,11 @@
Callable,
Generic,
Iterable,
TypedDict,
TypeVar,
)

from sanic.log import logger

import navi

from .group import Group, GroupId, NestedGroup, NestedIdGroup
from .input import BaseInput
from .node_check import (
Expand All @@ -26,6 +23,7 @@
check_naming_conventions,
check_schema_types,
)
from .node_data import IteratorInputInfo, IteratorOutputInfo, NodeData
from .output import BaseOutput
from .settings import Setting
from .types import FeatureId, InputId, NodeId, NodeKind, OutputId, RunFn
Expand Down Expand Up @@ -72,84 +70,6 @@ def _process_outputs(base_outputs: Iterable[BaseOutput]):
return outputs


class DefaultNode(TypedDict):
schemaId: str


class IteratorInputInfo:
def __init__(
self,
inputs: int | InputId | list[int] | list[InputId] | list[int | InputId],
length_type: navi.ExpressionJson = "uint",
) -> None:
self.inputs: list[InputId] = (
[InputId(x) for x in inputs]
if isinstance(inputs, list)
else [InputId(inputs)]
)
self.length_type: navi.ExpressionJson = length_type

def to_dict(self):
return {
"inputs": self.inputs,
"lengthType": self.length_type,
}


class IteratorOutputInfo:
def __init__(
self,
outputs: int | OutputId | list[int] | list[OutputId] | list[int | OutputId],
length_type: navi.ExpressionJson = "uint",
) -> None:
self.outputs: list[OutputId] = (
[OutputId(x) for x in outputs]
if isinstance(outputs, list)
else [OutputId(outputs)]
)
self.length_type: navi.ExpressionJson = length_type

def to_dict(self):
return {
"outputs": self.outputs,
"lengthType": self.length_type,
}


@dataclass(frozen=True)
class NodeData:
schema_id: str
description: str
see_also: list[str]
name: str
icon: str
kind: NodeKind

inputs: list[BaseInput]
outputs: list[BaseOutput]
group_layout: list[InputId | NestedIdGroup]

iterator_inputs: list[IteratorInputInfo]
iterator_outputs: list[IteratorOutputInfo]

side_effects: bool
deprecated: bool
node_context: bool
features: list[FeatureId]

run: RunFn

@property
def single_iterator_input(self) -> IteratorInputInfo:
assert len(self.iterator_inputs) == 1
return self.iterator_inputs[0]

@property
def single_iterator_output(self) -> IteratorOutputInfo:
assert len(self.iterator_outputs) == 1
return self.iterator_outputs[0]


T = TypeVar("T", bound=RunFn)
S = TypeVar("S")

Expand Down Expand Up @@ -242,17 +162,7 @@ def inner_wrapper(wrapped_func: T) -> T:
p_inputs, group_layout = _process_inputs(inputs)
p_outputs = _process_outputs(outputs)

if kind == "regularNode":
run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(
wrapped_func, p_inputs, p_outputs, node_context
),
)
run_check(
NAME_CHECK_LEVEL,
lambda fix: check_naming_conventions(wrapped_func, name, fix),
)
original_fn = wrapped_func

if decorators is not None:
for decorator in decorators:
Expand All @@ -277,6 +187,15 @@ def inner_wrapper(wrapped_func: T) -> T:
run=wrapped_func,
)

run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(original_fn, node),
)
run_check(
NAME_CHECK_LEVEL,
lambda fix: check_naming_conventions(original_fn, name, fix),
)

self.add_node(node)
return wrapped_func

Expand Down
22 changes: 10 additions & 12 deletions backend/src/api/node_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from enum import Enum
from typing import Any, Callable, NewType, Tuple, Union, cast, get_args

from .input import BaseInput
from .node_context import NodeContext
from .output import BaseOutput
from .node_data import NodeData

_Ty = NewType("_Ty", object)

Expand Down Expand Up @@ -148,11 +147,8 @@ def get_type_annotations(fn: Callable) -> dict[str, _Ty]:
return type_annotations


def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]):
if str(return_type).startswith("api.Iterator["):
return_type = get_args(return_type)[0]
elif str(return_type).startswith("api.Collector["):
return_type = get_args(return_type)[1]
def validate_return_type(return_type: _Ty, node: NodeData):
outputs = node.outputs

if len(outputs) == 0:
if return_type is not None and return_type is not type(None): # type: ignore
Expand Down Expand Up @@ -190,27 +186,28 @@ def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]):

def check_schema_types(
wrapped_func: Callable,
inputs: list[BaseInput],
outputs: list[BaseOutput],
node_context: bool,
node: NodeData,
):
"""
Runtime validation for the number of inputs/outputs compared to the type args
"""

if node.kind != "regularNode":
return

ann = OrderedDict(get_type_annotations(wrapped_func))

# check return type
if "return" in ann:
validate_return_type(ann.pop("return"), outputs)
validate_return_type(ann.pop("return"), node)

# check arguments
arg_spec = inspect.getfullargspec(wrapped_func)
for arg in arg_spec.args:
if arg not in ann:
raise CheckFailedError(f"Missing type annotation for '{arg}'")

if node_context:
if node.node_context:
first = arg_spec.args[0]
if first != "context":
raise CheckFailedError(
Expand All @@ -223,6 +220,7 @@ def check_schema_types(
)

# check inputs
inputs = node.inputs

if arg_spec.varargs is not None:
if arg_spec.varargs not in ann:
Expand Down
84 changes: 84 additions & 0 deletions backend/src/api/node_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from dataclasses import dataclass

import navi

from .group import NestedIdGroup
from .input import BaseInput
from .output import BaseOutput
from .types import FeatureId, InputId, NodeKind, OutputId, RunFn


class IteratorInputInfo:
def __init__(
self,
inputs: int | InputId | list[int] | list[InputId] | list[int | InputId],
length_type: navi.ExpressionJson = "uint",
) -> None:
self.inputs: list[InputId] = (
[InputId(x) for x in inputs]
if isinstance(inputs, list)
else [InputId(inputs)]
)
self.length_type: navi.ExpressionJson = length_type

def to_dict(self):
return {
"inputs": self.inputs,
"lengthType": self.length_type,
}


class IteratorOutputInfo:
def __init__(
self,
outputs: int | OutputId | list[int] | list[OutputId] | list[int | OutputId],
length_type: navi.ExpressionJson = "uint",
) -> None:
self.outputs: list[OutputId] = (
[OutputId(x) for x in outputs]
if isinstance(outputs, list)
else [OutputId(outputs)]
)
self.length_type: navi.ExpressionJson = length_type

def to_dict(self):
return {
"outputs": self.outputs,
"lengthType": self.length_type,
}


@dataclass(frozen=True)
class NodeData:
schema_id: str
description: str
see_also: list[str]
name: str
icon: str
kind: NodeKind

inputs: list[BaseInput]
outputs: list[BaseOutput]
group_layout: list[InputId | NestedIdGroup]

iterator_inputs: list[IteratorInputInfo]
iterator_outputs: list[IteratorOutputInfo]

side_effects: bool
deprecated: bool
node_context: bool
features: list[FeatureId]

run: RunFn

@property
def single_iterator_input(self) -> IteratorInputInfo:
assert len(self.iterator_inputs) == 1
return self.iterator_inputs[0]

@property
def single_iterator_output(self) -> IteratorOutputInfo:
assert len(self.iterator_outputs) == 1
return self.iterator_outputs[0]

0 comments on commit 36c01c1

Please sign in to comment.