Skip to content

Commit

Permalink
ENH: Implement broadcast_to function
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Oct 2, 2024
1 parent df50a8d commit ca55bf7
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sparse/mlir_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
)
from ._ops import (
add,
broadcast_to,
reshape,
)

__all__ = [
"add",
"broadcast_to",
"asarray",
"asdtype",
"reshape",
Expand Down
55 changes: 55 additions & 0 deletions sparse/mlir_backend/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,32 @@ def reshape(a, shape):
return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])


@fn_cache
def get_broadcast_to_module(
in_tensor_type: ir.RankedTensorType,
out_tensor_type: ir.RankedTensorType,
dimensions: tuple[int, ...],
) -> ir.Module:
with ir.Location.unknown(ctx):
module = ir.Module.create()

with ir.InsertionPoint(module.body):

@func.FuncOp.from_py_func(in_tensor_type)
def broadcast_to(in_tensor):
out = tensor.empty(out_tensor_type, [])
return linalg.broadcast(in_tensor, outs=[out], dimensions=dimensions)

broadcast_to.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
if DEBUG:
(CWD / "broadcast_to_module.mlir").write_text(str(module))
pm.run(module.operation)
if DEBUG:
(CWD / "broadcast_to_module_opt.mlir").write_text(str(module))

return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])


def add(x1: Tensor, x2: Tensor) -> Tensor:
ret_obj = x1._format_class()
out_tensor_type = x1._obj.get_tensor_definition(x1.shape)
Expand Down Expand Up @@ -152,3 +178,32 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
)

return Tensor(ret_obj, shape=out_tensor_type.shape)


def _infer_format_class(rank: int, values_dtype: type[DType], index_dtype: type[DType]) -> type[ctypes.Structure]:
from ._constructors import get_csf_class, get_csx_class, get_dense_class

if rank == 1:
return get_dense_class(values_dtype, index_dtype)
if rank == 2:
return get_csx_class(values_dtype, index_dtype, order="r")
if rank == 3:
return get_csf_class(values_dtype, index_dtype)
raise Exception(f"Rank not supported to infer format: {rank}")


def broadcast_to(x: Tensor, /, shape: tuple[int, ...], dimensions: list[int]) -> Tensor:
x_tensor_type = x._obj.get_tensor_definition(x.shape)
format_class = _infer_format_class(len(shape), x._values_dtype, x._index_dtype)
out_tensor_type = format_class.get_tensor_definition(shape)
ret_obj = format_class()

broadcast_to_module = get_broadcast_to_module(x_tensor_type, out_tensor_type, tuple(dimensions))

broadcast_to_module.invoke(
"broadcast_to",
ctypes.pointer(ctypes.pointer(ret_obj)),
*x._obj.to_module_arg(),
)

return Tensor(ret_obj, shape=shape)
52 changes: 51 additions & 1 deletion sparse/mlir_backend/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,5 +289,55 @@ def test_reshape(rng, dtype):
np.testing.assert_array_equal(actual, expected)

# DENSE
# NOTE: dense reshape is probably broken in MLIR
# NOTE: dense reshape is probably broken in MLIR in 19.x branch
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)


@parametrize_dtypes
def test_broadcast_to(dtype):
# CSR, CSC, COO
for shape, new_shape, dimensions, input_arr, expected_arrs in [
(
(3, 4),
(2, 3, 4),
[0],
np.array([[0, 1, 0, 3], [0, 0, 4, 5], [6, 7, 0, 0]]),
[
np.array([0, 3, 6]),
np.array([0, 1, 2, 0, 1, 2]),
np.array([0, 2, 4, 6, 8, 10, 12]),
np.array([1, 3, 2, 3, 0, 1, 1, 3, 2, 3, 0, 1]),
np.array([1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0]),
],
),
(
(4, 2),
(4, 2, 2),
[1],
np.array([[0, 1], [0, 0], [2, 3], [4, 0]]),
[
np.array([0, 2, 2, 4, 6]),
np.array([0, 1, 0, 1, 0, 1]),
np.array([0, 1, 2, 4, 6, 7, 8]),
np.array([1, 1, 0, 1, 0, 1, 0, 0]),
np.array([1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 4.0]),
],
),
]:
for fn_format in [sps.csr_array, sps.csc_array, sps.coo_array]:
arr = fn_format(input_arr, shape=shape, dtype=dtype)
arr.sum_duplicates()
tensor = sparse.asarray(arr)
result = sparse.broadcast_to(tensor, new_shape, dimensions=dimensions).to_scipy_sparse()

for actual, expected in zip(result, expected_arrs, strict=False):
np.testing.assert_allclose(actual, expected)

# DENSE
np_arr = np.array([0, 0, 2, 3, 0, 1])
arr = np.asarray(np_arr, dtype=dtype)
tensor = sparse.asarray(arr)
result = sparse.broadcast_to(tensor, (3, 6), dimensions=[0]).to_scipy_sparse()

assert result.format == "csr"
np.testing.assert_allclose(result.todense(), np.repeat(np_arr[np.newaxis], 3, axis=0))

0 comments on commit ca55bf7

Please sign in to comment.