From 354e7de55c8b51cc3f7bf10490e604470c29b40c Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Mon, 5 Dec 2022 18:42:12 +0800 Subject: [PATCH] [bug] MatrixType bug fix: Add additional restrictions for unpacking a Matrix (#6795) Issue: https://github.com/taichi-dev/taichi/issues/5819 ### Brief Summary --- python/taichi/lang/ast/ast_transformer.py | 4 ++++ tests/python/test_tuple_assign.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index d8af51533e0d3..8bcbfacbbdd5d 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -153,6 +153,10 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign): # Unpack: a, b, c = ti.Vector([1., 2., 3.]) if isinstance(values, impl.Expr) and values.ptr.is_tensor(): + if len(values.get_shape()) > 1: + raise ValueError( + 'Matrices with more than one columns cannot be unpacked') + values = ctx.ast_builder.expand_expr([values.ptr]) if len(values) == 1: values = values[0] diff --git a/tests/python/test_tuple_assign.py b/tests/python/test_tuple_assign.py index a4dc7e317d3c2..83e547887b57a 100644 --- a/tests/python/test_tuple_assign.py +++ b/tests/python/test_tuple_assign.py @@ -207,8 +207,7 @@ def func(): func() -@test_utils.test(arch=get_host_arch_list()) -def test_unpack_mismatch_matrix(): +def _test_unpack_mismatch_matrix(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) c = ti.field(ti.f32, ()) @@ -223,6 +222,18 @@ def func(): func() +@test_utils.test(arch=get_host_arch_list()) +def test_unpack_mismatch_matrix(): + _test_unpack_mismatch_matrix() + + +@test_utils.test(arch=get_host_arch_list(), + real_matrix=True, + real_matrix_scalarize=True) +def test_unpack_mismatch_matrix_scalarize(): + _test_unpack_mismatch_matrix() + + @test_utils.test(arch=get_host_arch_list()) def test_unpack_from_shape(): a = ti.field(ti.f32, ())