Skip to content

Commit

Permalink
Add metadata to missing nodes
Browse files Browse the repository at this point in the history
Also fix metadata read write tests
  • Loading branch information
bvogginger committed Aug 15, 2024
1 parent b9b359a commit a53c3e3
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 10 deletions.
1 change: 1 addition & 0 deletions nir/ir/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class Conv2d(NIRNode):
dilation: Union[int, Tuple[int, int]] # Dilation
groups: int # Groups
bias: np.ndarray # Bias C_out
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if isinstance(self.padding, str) and self.padding not in ["same", "valid"]:
Expand Down
2 changes: 2 additions & 0 deletions nir/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ class Input(NIRNode):
# Shape of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
input_type: Types
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = parse_shape_argument(self.input_type, "input")
Expand Down Expand Up @@ -479,6 +480,7 @@ class Output(NIRNode):
# Type of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
output_type: Types
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.output_type = parse_shape_argument(self.output_type, "output")
Expand Down
2 changes: 2 additions & 0 deletions nir/ir/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Linear(NIRNode):
"""

weight: np.ndarray # Weight term
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
Expand All @@ -69,6 +70,7 @@ class Scale(NIRNode):
"""

scale: np.ndarray # Scaling factor
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": np.array(self.scale.shape)}
Expand Down
1 change: 1 addition & 0 deletions nir/ir/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AvgPool2d(NIRNode):
kernel_size: np.ndarray # (Height, Width)
stride: np.ndarray # (Height, width)
padding: np.ndarray # (Height, width)
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": None}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ find={include = ["nir*"]}

[tool.ruff]
line-length = 100
lint.per-file-ignores = {"docs/conf.py" = ["E402"]}
per-file-ignores = {"docs/conf.py" = ["E402"]}
exclude = ["paper/"]
20 changes: 11 additions & 9 deletions tests/test_readwrite.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import tempfile
import inspect
import sys
import tempfile

import numpy as np

import nir
from tests import mock_affine, mock_conv
from tests import mock_affine, mock_conv, mock_linear

ALL_NODES = []
for name, obj in inspect.getmembers(sys.modules["nir.ir"]):
Expand Down Expand Up @@ -47,7 +47,7 @@ def factory_test_graph(ir: nir.NIRGraph):
assert_equivalence(ir, ir2)


def factory_test_metadata(node):
def factory_test_metadata(ir: nir.NIRGraph):
def compare_dicts(d1, d2):
for k, v in d1.items():
if isinstance(v, np.ndarray):
Expand All @@ -58,12 +58,14 @@ def compare_dicts(d1, d2):
assert v == d2[k]

metadata = {"some": "metadata", "with": 2, "data": np.array([1, 2, 3])}
node.metadata = metadata
compare_dicts(node.metadata, metadata)
for node in ir.nodes.values():
node.metadata = metadata
compare_dicts(node.metadata, metadata)
tmp = tempfile.mktemp()
nir.write(tmp, node)
node2 = nir.read(tmp)
compare_dicts(node2.metadata, metadata)
nir.write(tmp, ir)
ir2 = nir.read(tmp)
for node in ir2.nodes.values():
compare_dicts(node.metadata, metadata)


def test_simple():
Expand Down Expand Up @@ -146,7 +148,7 @@ def test_linear():
tau = np.array([1, 1, 1])
r = np.array([1, 1, 1])
v_leak = np.array([1, 1, 1])
ir = nir.NIRGraph.from_list(mock_affine(2, 2), nir.LI(tau, r, v_leak))
ir = nir.NIRGraph.from_list(mock_linear(2, 2), nir.LI(tau, r, v_leak))
factory_test_graph(ir)
factory_test_metadata(ir)

Expand Down

0 comments on commit a53c3e3

Please sign in to comment.