Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Aot] Deprecate element shape and field dim for AOT symbolic args #7100

Merged
merged 13 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions c_api/tests/c_api_cgraph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,34 @@ void graph_aot_test(TiArch arch) {
ti::AotModule aot_mod = runtime.load_aot_module(aot_mod_ss.str().c_str());
ti::ComputeGraph run_graph = aot_mod.get_compute_graph("run_graph");

ti::NdArray<int32_t> arr_array =
ti::NdArray<int32_t> arr_array_0 =
runtime.allocate_ndarray<int32_t>({kArrLen}, {}, true);
ti::NdArray<int32_t> arr_array_1 =
runtime.allocate_ndarray<int32_t>({kArrLen}, {1}, true);

run_graph["base0"] = base0_val;
run_graph["base1"] = base1_val;
run_graph["base2"] = base2_val;
run_graph["arr"] = arr_array;
run_graph["arr0"] = arr_array_0;
run_graph["arr1"] = arr_array_1;
run_graph.launch();
runtime.wait();

// Check Results
auto *data = reinterpret_cast<int32_t *>(arr_array.map());
auto *data = reinterpret_cast<int32_t *>(arr_array_0.map());

for (int i = 0; i < kArrLen; i++) {
EXPECT_EQ(data[i], 3 * i + base0_val + base1_val + base2_val);
}
arr_array.unmap();

data = reinterpret_cast<int32_t *>(arr_array_1.map());

for (int i = 0; i < kArrLen; i++) {
EXPECT_EQ(data[i], 3 * i + base0_val + base1_val + base2_val);
}

arr_array_0.unmap();
arr_array_1.unmap();
}

void texture_aot_test(TiArch arch) {
Expand Down
29 changes: 11 additions & 18 deletions python/taichi/examples/graph/mpm88_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,25 @@ def main():
# Build graph
sym_x = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'x',
ti.f32,
field_dim=1,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=1)
sym_v = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'v',
ti.f32,
field_dim=1,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=1)
sym_C = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'C',
ti.f32,
field_dim=1,
element_shape=(2, 2))
sym_J = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'J',
ti.f32,
field_dim=1)
dtype=ti.math.mat2,
ndim=1)
sym_J = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'J', ti.f32, ndim=1)
sym_grid_v = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'grid_v',
ti.f32,
field_dim=2,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=2)
sym_grid_m = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'grid_m',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
g_init_builder = ti.graph.GraphBuilder()
g_init_builder.dispatch(init_particles, sym_x, sym_v, sym_J)

Expand Down
36 changes: 16 additions & 20 deletions python/taichi/examples/graph/stable_fluid_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,40 +240,36 @@ def main():
print('running in graph mode')
velocities_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'velocities_pair_cur',
ti.f32,
field_dim=2,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=2)
velocities_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'velocities_pair_nxt',
ti.f32,
field_dim=2,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=2)
dyes_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'dyes_pair_cur',
ti.f32,
field_dim=2,
element_shape=(3, ))
dtype=ti.math.vec3,
ndim=2)
dyes_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'dyes_pair_nxt',
ti.f32,
field_dim=2,
element_shape=(3, ))
dtype=ti.math.vec3,
ndim=2)
pressures_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'pressures_pair_cur',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
pressures_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'pressures_pair_nxt',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
velocity_divs = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'velocity_divs',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
mouse_data = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'mouse_data',
ti.f32,
field_dim=1)
dtype=ti.f32,
ndim=1)

g1_builder = ti.graph.GraphBuilder()
g1_builder.dispatch(advect, velocities_pair_cur, velocities_pair_cur,
Expand Down
5 changes: 2 additions & 3 deletions python/taichi/examples/graph/texture_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def main():
_t = ti.graph.Arg(ti.graph.ArgKind.SCALAR, 't', ti.f32)
_pixels_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'pixels_arr',
ti.f32,
field_dim=2,
element_shape=(4, ))
dtype=ti.math.vec4,
ndim=2)

_rw_tex = ti.graph.Arg(ti.graph.ArgKind.RWTEXTURE,
'rw_tex',
Expand Down
46 changes: 41 additions & 5 deletions python/taichi/graph/_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from taichi._lib import core as _ti_core
from taichi.aot.utils import produce_injected_args
from taichi.lang import kernel_impl
Expand Down Expand Up @@ -97,12 +99,46 @@ def run(self, args):
def Arg(tag,
name,
dtype=None,
field_dim=0,
ndim=0,
field_dim=None,
element_shape=(),
channel_format=None,
shape=(),
num_channels=None):
if isinstance(dtype, MatrixType):
if field_dim is not None:
if ndim != 0:
raise TaichiRuntimeError(
'field_dim is deprecated, please do not specify field_dim and ndim at the same time.'
)
warnings.warn(
"The field_dim argument for ndarray will be deprecated in v1.5.0, use ndim instead.",
lin-hitonami marked this conversation as resolved.
Show resolved Hide resolved
DeprecationWarning)
ndim = field_dim

if tag == ArgKind.SCALAR:
# The scalar tag should never work with array-like parameters
if ndim > 0 or isinstance(dtype, MatrixType) or len(element_shape) > 0:
raise TaichiRuntimeError(
f'Illegal Arg parameter (dtype={dtype}, ndim={ndim}, element_shape={element_shape}) for Scalar tag.'
)
return _ti_core.Arg(tag, name, dtype, ndim, element_shape)

if tag == ArgKind.NDARRAY:
# Ndarray with matrix data type
if isinstance(dtype, MatrixType):
return _ti_core.Arg(tag, name, dtype.dtype, ndim,
dtype.get_shape())
# Ndarray with scalar data type
if len(element_shape) > 0:
warnings.warn(
"The element_shape argument for ndarray will be deprecated in v1.5.0, use vector or matrix data type instead.",
DeprecationWarning)
return _ti_core.Arg(tag, name, dtype, ndim, element_shape)

if tag == ArgKind.MATRIX:
if not isinstance(dtype, MatrixType):
raise TaichiRuntimeError(
f'Tag {tag} must specify matrix data type, but got {dtype}.')
if len(element_shape) > 0:
raise TaichiRuntimeError(
f'Element shape for MatrixType argument "{name}" is not supported.'
Expand All @@ -114,8 +150,8 @@ def Arg(tag,
arg_sublist = []
for _ in range(mat_type.m):
arg_sublist.append(
_ti_core.Arg(tag, f'{name}_mat_arg_{i}', dtype.dtype,
field_dim, element_shape))
_ti_core.Arg(tag, f'{name}_mat_arg_{i}', dtype.dtype, ndim,
element_shape))
i += 1
arg_list.append(arg_sublist)
return arg_list
Expand All @@ -130,7 +166,7 @@ def Arg(tag,
channel_format=channel_format,
num_channels=num_channels,
shape=shape)
return _ti_core.Arg(tag, name, dtype, field_dim, element_shape)
raise TaichiRuntimeError(f'Unknown tag {tag} for graph Arg {name}.')


__all__ = ['GraphBuilder', 'Graph', 'Arg', 'ArgKind']
18 changes: 13 additions & 5 deletions tests/cpp/aot/gfx_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,18 @@ void run_cgraph1(Arch arch, taichi::lang::Device *device_) {
alloc_params.host_read = true;
alloc_params.size = size * sizeof(int);
alloc_params.usage = taichi::lang::AllocUsage::Storage;
DeviceAllocation devalloc_arr_ = device_->allocate_memory(alloc_params);
Ndarray arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size}, {1});
DeviceAllocation devalloc_arr_0 = device_->allocate_memory(alloc_params);
DeviceAllocation devalloc_arr_1 = device_->allocate_memory(alloc_params);
Ndarray arr0 = Ndarray(devalloc_arr_0, PrimitiveType::i32, {size});
Ndarray arr1 = Ndarray(devalloc_arr_1, PrimitiveType::i32, {size}, {1});

int base0 = 10;
int base1 = 20;
int base2 = 30;

std::unordered_map<std::string, taichi::lang::aot::IValue> args;
args.insert({"arr", taichi::lang::aot::IValue::create(arr)});
args.insert({"arr0", taichi::lang::aot::IValue::create(arr0)});
args.insert({"arr1", taichi::lang::aot::IValue::create(arr1)});
args.insert({"base0", taichi::lang::aot::IValue::create(base0)});
args.insert({"base1", taichi::lang::aot::IValue::create(base1)});
args.insert({"base2", taichi::lang::aot::IValue::create(base2)});
Expand All @@ -298,13 +301,18 @@ void run_cgraph1(Arch arch, taichi::lang::Device *device_) {
gfx_runtime->synchronize();

int dst[size] = {0};
load_devalloc(devalloc_arr_, dst, sizeof(dst));
load_devalloc(devalloc_arr_0, dst, sizeof(dst));
for (int i = 0; i < size; i++) {
EXPECT_EQ(dst[i], 3 * i + base0 + base1 + base2);
}
load_devalloc(devalloc_arr_1, dst, sizeof(dst));
for (int i = 0; i < size; i++) {
EXPECT_EQ(dst[i], 3 * i + base0 + base1 + base2);
}

// Deallocate
device_->dealloc_memory(devalloc_arr_);
device_->dealloc_memory(devalloc_arr_0);
device_->dealloc_memory(devalloc_arr_1);
}

void run_cgraph2(Arch arch, taichi::lang::Device *device_) {
Expand Down
61 changes: 46 additions & 15 deletions tests/cpp/aot/llvm/graph_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,45 @@ TEST(LlvmCGraph, RunGraphCpu) {

constexpr int ArrLength = 100;
constexpr int kArrBytes_arr = ArrLength * 1 * sizeof(int32_t);
auto devalloc_arr =
auto devalloc_arr_0 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);
auto devalloc_arr_1 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);

/* Test with Graph */
// Prepare & Run "init" Graph
auto run_graph = mod->get_graph("run_graph");

auto arr = taichi::lang::Ndarray(
devalloc_arr, taichi::lang::PrimitiveType::i32, {ArrLength}, {1});
auto arr0 = taichi::lang::Ndarray(
devalloc_arr_0, taichi::lang::PrimitiveType::i32, {ArrLength});
auto arr1 = taichi::lang::Ndarray(
devalloc_arr_1, taichi::lang::PrimitiveType::i32, {ArrLength},
{
1,
});

int base0 = 10;
int base1 = 20;
int base2 = 30;
std::unordered_map<std::string, taichi::lang::aot::IValue> args;
args.insert({"arr", taichi::lang::aot::IValue::create(arr)});
args.insert({"arr0", taichi::lang::aot::IValue::create(arr0)});
args.insert({"arr1", taichi::lang::aot::IValue::create(arr1)});
args.insert({"base0", taichi::lang::aot::IValue::create(base0)});
args.insert({"base1", taichi::lang::aot::IValue::create(base1)});
args.insert({"base2", taichi::lang::aot::IValue::create(base2)});

run_graph->run(args);
exec.synchronize();

auto *data = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr));
auto *data_0 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_0));
auto *data_1 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_1));
for (int i = 0; i < ArrLength; i++) {
EXPECT_EQ(data_0[i], 3 * i + base0 + base1 + base2);
}
for (int i = 0; i < ArrLength; i++) {
EXPECT_EQ(data[i], 3 * i + base0 + base1 + base2);
EXPECT_EQ(data_1[i], 3 * i + base0 + base1 + base2);
}
}

Expand Down Expand Up @@ -99,34 +112,52 @@ TEST(LlvmCGraph, RunGraphCuda) {

constexpr int ArrLength = 100;
constexpr int kArrBytes_arr = ArrLength * 1 * sizeof(int32_t);
auto devalloc_arr =
auto devalloc_arr_0 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);

auto devalloc_arr_1 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);

/* Test with Graph */
// Prepare & Run "init" Graph
auto run_graph = mod->get_graph("run_graph");

auto arr = taichi::lang::Ndarray(
devalloc_arr, taichi::lang::PrimitiveType::i32, {ArrLength}, {1});
auto arr0 = taichi::lang::Ndarray(
devalloc_arr_0, taichi::lang::PrimitiveType::i32, {ArrLength});

auto arr1 = taichi::lang::Ndarray(
devalloc_arr_1, taichi::lang::PrimitiveType::i32, {ArrLength}, {1});

int base0 = 10;
int base1 = 20;
int base2 = 30;
std::unordered_map<std::string, taichi::lang::aot::IValue> args;
args.insert({"arr", taichi::lang::aot::IValue::create(arr)});
args.insert({"arr0", taichi::lang::aot::IValue::create(arr0)});
args.insert({"arr1", taichi::lang::aot::IValue::create(arr1)});
args.insert({"base0", taichi::lang::aot::IValue::create(base0)});
args.insert({"base1", taichi::lang::aot::IValue::create(base1)});
args.insert({"base2", taichi::lang::aot::IValue::create(base2)});

run_graph->run(args);
exec.synchronize();

auto *data = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr));

std::vector<int32_t> cpu_data(ArrLength);

auto *data_0 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_0));

CUDADriver::get_instance().memcpy_device_to_host(
(void *)cpu_data.data(), (void *)data_0, ArrLength * sizeof(int32_t));

for (int i = 0; i < ArrLength; ++i) {
EXPECT_EQ(cpu_data[i], 3 * i + base0 + base1 + base2);
}

auto *data_1 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_1));

CUDADriver::get_instance().memcpy_device_to_host(
(void *)cpu_data.data(), (void *)data, ArrLength * sizeof(int32_t));
(void *)cpu_data.data(), (void *)data_1, ArrLength * sizeof(int32_t));

for (int i = 0; i < ArrLength; ++i) {
EXPECT_EQ(cpu_data[i], 3 * i + base0 + base1 + base2);
Expand Down
6 changes: 1 addition & 5 deletions tests/cpp/aot/python_scripts/comet_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@
count = ti.field(ti.i32, ())
img = ti.field(ti.f32, (res, res))

sym_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'arr',
ti.f32,
field_dim=3,
element_shape=())
sym_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'arr', dtype=ti.f32, ndim=3)
img_c = 4


Expand Down
Loading