Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding fill library node #1664

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@ Yihang Luo
Alexandru Calotoiu
Phillip Lane
Samuel Martin
Krutarth Patel

and other contributors listed in https://github.com/spcl/dace/graphs/contributors
54 changes: 54 additions & 0 deletions dace/libraries/standard/nodes/fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dace
from dace import library, nodes, properties
from dace.transformation.transformation import ExpandTransformation
from numbers import Number

@library.expansion
class ExpandPure(ExpandTransformation):
"""Implements pure expansion of the Fill library node."""

environments = []

@staticmethod
def expansion(node, parent_state, parent_sdfg):
output = None
for e in parent_state.out_edges(node):
if e.src_conn == "_output":
output = parent_sdfg.arrays[e.data.data]
sdfg = dace.SDFG(f"{node.label}_sdfg")
_, out_arr = sdfg.add_array(
"_output",
output.shape,
output.dtype,
output.storage,
strides=output.strides,
)

state = sdfg.add_state(f"{node.label}_state")
map_params = [f"__i{i}" for i in range(len(out_arr.shape))]
map_rng = {i: f"0:{s}" for i, s in zip(map_params, out_arr.shape)}
out_mem = dace.Memlet(expr=f"_output[{','.join(map_params)}]")
inputs = {}
outputs = {"_out": out_mem}
code = f"_out = {node.value}"
state.add_mapped_tasklet(
f"{node.label}_tasklet", map_rng, inputs, code, outputs, external_edges=True
)

return sdfg


@library.node
class Fill(nodes.LibraryNode):
"""Implements filling data containers with a single value"""

implementations = {"pure": ExpandPure}
default_implementation = "pure"
value = properties.SymbolicProperty(
dtype=Number, default=0, desc="value to fill data container"
)

def __init__(self, name, value=0):
super().__init__(name, outputs={"_output"})
self.value = value
self.name = name
2 changes: 1 addition & 1 deletion dace/viewer/webclient
Submodule webclient updated 264 files
57 changes: 57 additions & 0 deletions tests/library/fill_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np
from dace.memlet import Memlet
from dace.libraries.standard.nodes import fill


def pure_graph(implementation, dtype, size):
sdfg_name = f"fill_{implementation}_{dtype.ctype}_w{size}"
sdfg = dace.SDFG(sdfg_name)

state = sdfg.add_state("fill")

value = dace.symbol("value")
sdfg.add_array("r", [size], dtype)
result = state.add_write("r")

fill_node = fill.Fill("fill")
fill_node.implementation = implementation
fill_node.value = value

# how to initialize memlet here?
state.add_memlet_path(fill_node, result, src_conn="_output", memlet=Memlet())

return sdfg


def run_test(target, size, value):
if target == "pure":
sdfg = pure_graph("pure", dace.float32, size)
# expand the nested sdfg returned by fill node
sdfg.expand_library_nodes()
else:
print(f"Unsupported target: {target}")
exit(-1)

# we get the function we can call
fill = sdfg.compile()

# supposed to be filled
result = np.ndarray(size, dtype=np.float32)

# the parameters are all the symbols defined in the sdfg
fill(value=value, r=result)
for val in result:
if val != value:
raise ValueError(f"expected {value}, found {val}")
return sdfg


def test_fill_pure():
# should not return a value error
assert isinstance(run_test("pure", 64, 1), dace.SDFG)


if __name__ == "__main__":
test_fill_pure()