diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index 4c636b1b4aec7..398e997583cd4 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -156,8 +156,11 @@ def copy_from(self, other): Args: other (Field): The source field. """ - assert isinstance(other, Field) - assert len(self.shape) == len(other.shape) + if not isinstance(other, Field): + raise TypeError('Cannot copy from a non-field object') + if self.shape != other.shape: + raise ValueError(f"ti.field shape {self.shape} does not match" + f" the source field shape {other.shape}") taichi.lang.meta.tensor_to_tensor(self, other) @python_scope diff --git a/tests/python/test_field.py b/tests/python/test_field.py index c361657dcfdcf..7013e47d65b13 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -146,3 +146,35 @@ def test_field_name(): for i in range(10): d.append(ti.field(dtype=ti.f32, shape=(2, 3), name=f'd{i}')) assert d[i].name == f'd{i}' + + +@ti.test() +@pytest.mark.parametrize('shape', field_shapes) +@pytest.mark.parametrize('dtype', [ti.i32, ti.f32]) +def test_field_copy_from(shape, dtype): + x = ti.field(dtype=ti.f32, shape=shape) + other = ti.field(dtype=dtype, shape=shape) + other.fill(1) + x.copy_from(other) + convert = lambda arr: arr[0] if len(arr) == 1 else arr + assert (convert(x.shape) == shape) + assert (x.dtype == ti.f32) + assert ((x.to_numpy() == 1).all()) + + +@ti.test() +def test_field_copy_from_with_mismatch_shape(): + x = ti.field(dtype=ti.f32, shape=(2, 3)) + for other_shape in [(2, ), (2, 2), (2, 3, 4)]: + other = ti.field(dtype=ti.f16, shape=other_shape) + with pytest.raises(ValueError): + x.copy_from(other) + + +@ti.test() +def test_field_copy_from_with_non_filed_object(): + import numpy as np + x = ti.field(dtype=ti.f32, shape=(2, 3)) + other = np.zeros((2, 3)) + with pytest.raises(TypeError): + x.copy_from(other)