Skip to content

Commit

Permalink
[misc] Cherry pick struct return related commits (#7575)
Browse files Browse the repository at this point in the history
Co-authored-by: Lin Jiang <linjiang@taichi.graphics>
  • Loading branch information
turbo0628 and lin-hitonami authored Mar 17, 2023
1 parent 10b062d commit 9fa455c
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/cover-in-ci.lst
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
docs/lang/articles/basic
docs/lang/articles/advanced
docs/lang/articles/kernels
34 changes: 25 additions & 9 deletions docs/lang/articles/kernels/kernel_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,23 @@ print(x) # Prints [5, 7, 9]

### Return value

In Taichi, a kernel can have at most one return value, which can be a scalar, `ti.Matrix`, or `ti.Vector`. Here are the rules to follow when defining the return value of a kernel:
In Taichi, a kernel is allowed to have a maximum of one return value, which could either be a scalar, `ti.Matrix`, or `ti.Vector`.
Moreover, in the LLVM-based backends (CPU and CUDA backends), a return value could also be a `ti.Struct`.

Here is an example of a kernel that returns a ti.Struct:

```python
s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16)
s1 = ti.types.struct(a=ti.f32, b=s0)

@ti.kernel
def foo() -> s1:
return s1(a=1, b=s0(a=ti.math.vec3(100, 0.2, 3), b=1))

print(foo()) # {'a': 1.0, 'b': {'a': [100.0, 0.2, 3.0], 'b': 1}}
```

When defining the return value of a kernel in Taichi, it is important to follow these rules:

- Use type hint to specify the return value of a kernel.
- Make sure that you have at most one return value in a kernel.
Expand Down Expand Up @@ -276,14 +292,14 @@ Return values of a Taichi function can be scalars, `ti.Matrix`, `ti.Vector`, `ti

## A recap: Taichi kernel vs. Taichi function

| | **Kernel** | **Taichi Function** |
| ----------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Call scope | Python scope | Taichi scope |
| Type hint arguments | Mandatory | Recommended |
| Type hint return values | Mandatory | Recommended |
| Return type | <ul><li>Scalar</li><li>`ti.Vector`</li><li>`ti.Matrix`</li></ul> | <ul><li>Scalar</li><li>`ti.Vector`</li><li>`ti.Matrix`</li><li>`ti.Struct`</li><li>...</li></ul> |
| Maximum number of elements in arguments | <ul><li>32 (OpenGL)</li><li>64 (otherwise)</li></ul> | Unlimited |
| Maximum number of return values in a return statement | 1 | Unlimited |
| | **Kernel** | **Taichi Function** |
| ----------------------------------------------------- |-------------------------------------------------------------------------------------------------------------------| ------------------------------------------------------------ |
| Call scope | Python scope | Taichi scope |
| Type hint arguments | Mandatory | Recommended |
| Type hint return values | Mandatory | Recommended |
| Return type | <ul><li>Scalar</li><li>`ti.Vector`</li><li>`ti.Matrix`</li><li>`ti.Struct`(Only on LLVM-based backends)</li></ul> | <ul><li>Scalar</li><li>`ti.Vector`</li><li>`ti.Matrix`</li><li>`ti.Struct`</li><li>...</li></ul> |
| Maximum number of elements in arguments | <ul><li>32 (OpenGL)</li><li>64 (otherwise)</li></ul> | Unlimited |
| Maximum number of return values in a return statement | 1 | Unlimited |


## Key terms
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,11 +902,11 @@ def construct_kernel_ret(self, launch_ctx, ret_type, index=()):
]
if isinstance(ret_type, CompoundType):
return ret_type.from_kernel_struct_ret(launch_ctx, index)
if id(ret_type) in primitive_types.integer_type_ids:
if ret_type in primitive_types.integer_types:
if is_signed(cook_dtype(ret_type)):
return launch_ctx.get_struct_ret_int(index)
return launch_ctx.get_struct_ret_uint(index)
if id(ret_type) in primitive_types.real_type_ids:
if ret_type in primitive_types.real_types:
return launch_ctx.get_struct_ret_float(index)
raise TaichiRuntimeTypeError(f"Invalid return type on index={index}")

Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,12 +1511,12 @@ def from_real_func_ret(self, func_ret, ret_index=()):
])

def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
if id(self.dtype) in primitive_types.integer_type_ids:
if self.dtype in primitive_types.integer_types:
if is_signed(cook_dtype(self.dtype)):
get_ret_func = launch_ctx.get_struct_ret_int
else:
get_ret_func = launch_ctx.get_struct_ret_uint
elif id(self.dtype) in primitive_types.real_type_ids:
elif self.dtype in primitive_types.real_types:
get_ret_func = launch_ctx.get_struct_ret_float
else:
raise TaichiRuntimeTypeError(
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,14 +746,14 @@ def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
d[name] = dtype.from_kernel_struct_ret(launch_ctx,
ret_index + (index, ))
else:
if id(dtype) in primitive_types.integer_type_ids:
if dtype in primitive_types.integer_types:
if is_signed(cook_dtype(dtype)):
d[name] = launch_ctx.get_struct_ret_int(ret_index +
(index, ))
else:
d[name] = launch_ctx.get_struct_ret_uint(ret_index +
(index, ))
elif id(dtype) in primitive_types.real_type_ids:
elif dtype in primitive_types.real_types:
d[name] = launch_ctx.get_struct_ret_float(ret_index +
(index, ))
else:
Expand Down
18 changes: 18 additions & 0 deletions tests/python/test_return.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from pytest import approx

import taichi as ti
from tests import test_utils
Expand Down Expand Up @@ -181,3 +182,20 @@ def foo() -> ti.types.vector(2, ti.u64):
return ti.Vector([ti.u64(2**64 - 1), ti.u64(2**64 - 1)])

assert (foo()[0] == 2**64 - 1)


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_struct_ret_with_matrix():
s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16)
s1 = ti.types.struct(a=ti.f32, b=s0)

@ti.kernel
def foo() -> s1:
return s1(a=1, b=s0(a=ti.math.vec3([100, 0.2, 3]), b=65537))

ret = foo()
assert (ret.a == approx(1))
assert (ret.b.a[0] == approx(100))
assert (ret.b.a[1] == approx(0.2))
assert (ret.b.a[2] == approx(3))
assert (ret.b.b == 1)

0 comments on commit 9fa455c

Please sign in to comment.