From 86ecb00e540d73bc8c6d02858b47c243e8db1c2f Mon Sep 17 00:00:00 2001 From: Jiajie Li Date: Wed, 10 Nov 2021 02:03:47 -0500 Subject: [PATCH 1/4] [Lang] User-friendly exception when copying between ti.field --- python/taichi/lang/field.py | 8 ++++++-- tests/python/test_field.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index 19f6e0ab0b939..f1a3fb601840f 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -1,3 +1,4 @@ +from typing import Type import taichi.lang from taichi.core.util import ti_core as _ti_core from taichi.lang.util import python_scope, to_numpy_type, to_pytorch_type @@ -156,8 +157,11 @@ def copy_from(self, other): Args: other (Field): The source field. """ - assert isinstance(other, Field) - assert len(self.shape) == len(other.shape) + if not isinstance(other, Field): + raise TypeError('Cannot copy from a non-field object') + if self.shape != other.shape: + raise ValueError(f"ti.field shape {self.shape} does not match" + f" the source field shape {other.shape}") taichi.lang.meta.tensor_to_tensor(self, other) @python_scope diff --git a/tests/python/test_field.py b/tests/python/test_field.py index c361657dcfdcf..a936de7b850ee 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -146,3 +146,33 @@ def test_field_name(): for i in range(10): d.append(ti.field(dtype=ti.f32, shape=(2, 3), name=f'd{i}')) assert d[i].name == f'd{i}' + +@ti.test() +@pytest.mark.parametrize('shape', field_shapes) +def test_field_copy_from(shape): + shapes = [ti.i32, ti.f32] # Metal kernel only supports <= 32-bit data + x = ti.field(dtype=ti.f32, shape=shape) + for other_dtype in shapes: + other = ti.field(dtype=other_dtype, shape=shape) + other.fill(1) + x.copy_from(other) + convert = lambda arr: arr[0] if len(arr)==1 else arr + assert(convert(x.shape) == shape) + assert(x.dtype == ti.f32) + assert((x.to_numpy() == 1).all()) + +@ti.test() +def test_field_copy_from_with_mismatch_shape(): + x = ti.field(dtype=ti.f32, shape=(2, 3)) + for other_shape in [(2,), (2, 2), (2, 3, 4)]: + other = ti.field(dtype=ti.f16, shape=other_shape) + with pytest.raises(ValueError): + x.copy_from(other) + +@ti.test() +def test_field_copy_from_with_non_filed_object(): + import numpy as np + x = ti.field(dtype=ti.f32, shape=(2, 3)) + other = np.zeros((2, 3)) + with pytest.raises(TypeError): + x.copy_from(other) \ No newline at end of file From 39807e1d68ca1102a1c1310845e9d15aeb6bc5cf Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Wed, 10 Nov 2021 07:32:06 +0000 Subject: [PATCH 2/4] Auto Format --- python/taichi/lang/field.py | 1 + tests/python/test_field.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index f1a3fb601840f..974bf63301fdc 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -1,4 +1,5 @@ from typing import Type + import taichi.lang from taichi.core.util import ti_core as _ti_core from taichi.lang.util import python_scope, to_numpy_type, to_pytorch_type diff --git a/tests/python/test_field.py b/tests/python/test_field.py index a936de7b850ee..a37875f2ebe3f 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -147,32 +147,35 @@ def test_field_name(): d.append(ti.field(dtype=ti.f32, shape=(2, 3), name=f'd{i}')) assert d[i].name == f'd{i}' + @ti.test() @pytest.mark.parametrize('shape', field_shapes) def test_field_copy_from(shape): - shapes = [ti.i32, ti.f32] # Metal kernel only supports <= 32-bit data + shapes = [ti.i32, ti.f32] # Metal kernel only supports <= 32-bit data x = ti.field(dtype=ti.f32, shape=shape) for other_dtype in shapes: other = ti.field(dtype=other_dtype, shape=shape) other.fill(1) x.copy_from(other) - convert = lambda arr: arr[0] if len(arr)==1 else arr - assert(convert(x.shape) == shape) - assert(x.dtype == ti.f32) - assert((x.to_numpy() == 1).all()) + convert = lambda arr: arr[0] if len(arr) == 1 else arr + assert (convert(x.shape) == shape) + assert (x.dtype == ti.f32) + assert ((x.to_numpy() == 1).all()) + @ti.test() def test_field_copy_from_with_mismatch_shape(): x = ti.field(dtype=ti.f32, shape=(2, 3)) - for other_shape in [(2,), (2, 2), (2, 3, 4)]: - other = ti.field(dtype=ti.f16, shape=other_shape) + for other_shape in [(2, ), (2, 2), (2, 3, 4)]: + other = ti.field(dtype=ti.f16, shape=other_shape) with pytest.raises(ValueError): x.copy_from(other) - + + @ti.test() def test_field_copy_from_with_non_filed_object(): import numpy as np x = ti.field(dtype=ti.f32, shape=(2, 3)) other = np.zeros((2, 3)) with pytest.raises(TypeError): - x.copy_from(other) \ No newline at end of file + x.copy_from(other) From 1ff10ae689f52a3fd5f3d37261f39e3389be8971 Mon Sep 17 00:00:00 2001 From: Jiajie Li Date: Wed, 10 Nov 2021 16:44:58 -0500 Subject: [PATCH 3/4] fix: removed unused import --- python/taichi/lang/field.py | 1 - tests/python/test_field.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index f1a3fb601840f..d2d1a0f7416d6 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -1,4 +1,3 @@ -from typing import Type import taichi.lang from taichi.core.util import ti_core as _ti_core from taichi.lang.util import python_scope, to_numpy_type, to_pytorch_type diff --git a/tests/python/test_field.py b/tests/python/test_field.py index a936de7b850ee..f9360ed0c978f 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -175,4 +175,4 @@ def test_field_copy_from_with_non_filed_object(): x = ti.field(dtype=ti.f32, shape=(2, 3)) other = np.zeros((2, 3)) with pytest.raises(TypeError): - x.copy_from(other) \ No newline at end of file + x.copy_from(other) From 97b782606b769896493330ad04c03a74b050ecf5 Mon Sep 17 00:00:00 2001 From: Jiajie Li Date: Sun, 19 Dec 2021 00:53:08 -0500 Subject: [PATCH 4/4] [fix] Use pytest way to test different dtypes --- tests/python/test_field.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/python/test_field.py b/tests/python/test_field.py index a37875f2ebe3f..7013e47d65b13 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -150,17 +150,16 @@ def test_field_name(): @ti.test() @pytest.mark.parametrize('shape', field_shapes) -def test_field_copy_from(shape): - shapes = [ti.i32, ti.f32] # Metal kernel only supports <= 32-bit data +@pytest.mark.parametrize('dtype', [ti.i32, ti.f32]) +def test_field_copy_from(shape, dtype): x = ti.field(dtype=ti.f32, shape=shape) - for other_dtype in shapes: - other = ti.field(dtype=other_dtype, shape=shape) - other.fill(1) - x.copy_from(other) - convert = lambda arr: arr[0] if len(arr) == 1 else arr - assert (convert(x.shape) == shape) - assert (x.dtype == ti.f32) - assert ((x.to_numpy() == 1).all()) + other = ti.field(dtype=dtype, shape=shape) + other.fill(1) + x.copy_from(other) + convert = lambda arr: arr[0] if len(arr) == 1 else arr + assert (convert(x.shape) == shape) + assert (x.dtype == ti.f32) + assert ((x.to_numpy() == 1).all()) @ti.test()