From dc135e921d7debc07f7f091027997547fe3780cd Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Fri, 25 Nov 2022 14:21:39 +0800 Subject: [PATCH 1/3] [bug] Fix name collision in ti.dataclass --- python/taichi/lang/struct.py | 4 ++-- tests/python/test_custom_struct.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 58c9b370ec22b..192518dca5675 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -97,7 +97,7 @@ def items(self): def _register_members(self): for k in self.keys: - setattr(Struct, k, + setattr(self, k, property( Struct._make_getter(k), Struct._make_setter(k), @@ -769,7 +769,7 @@ def dataclass(cls): and methods from the class attached. """ # save the annotation fields for the struct - fields = cls.__annotations__ + fields = getattr(cls, '__annotations__', {}) # get the class methods to be attached to the struct types fields['__struct_methods'] = { attribute: getattr(cls, attribute) diff --git a/tests/python/test_custom_struct.py b/tests/python/test_custom_struct.py index f6dfa22e26576..094b2987dafa6 100644 --- a/tests/python/test_custom_struct.py +++ b/tests/python/test_custom_struct.py @@ -1,5 +1,6 @@ import numpy as np from pytest import approx +from taichi.lang.misc import get_host_arch_list import taichi as ti from tests import test_utils @@ -443,3 +444,20 @@ def test(): assert A.mass == 2.0 test() + + +@test_utils.test(arch=get_host_arch_list()) +def test_name_collision(): + # https://github.com/taichi-dev/taichi/issues/6652 + @ti.dataclass + class Foo: + zoo: ti.f32 + + @ti.dataclass + class Bar: + @ti.func + def zoo(self): + return 0 + + Foo() # instantiate struct with zoo as member first + Bar() # then instantiate struct with zoo as method From 1d61357521eeca2b525b290e146837e8552e2a1b Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Fri, 2 Dec 2022 17:40:53 +0800 Subject: [PATCH 2/3] Fix --- python/taichi/lang/struct.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 192518dca5675..027258121fa89 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -42,6 +42,7 @@ class Struct(TaichiOperations): dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})]) """ _is_taichi_class = True + _instance_count = 0 def __init__(self, *args, **kwargs): # converts lists to matrices and dicts to structs @@ -96,12 +97,12 @@ def items(self): return self.entries.items() def _register_members(self): - for k in self.keys: - setattr(self, k, - property( - Struct._make_getter(k), - Struct._make_setter(k), - )) + # https://stackoverflow.com/questions/48448074/adding-a-property-to-an-existing-object-instance + cls = self.__class__ + new_cls_name = cls.__name__ + str(cls._instance_count) + cls._instance_count += 1 + properties = {k: property(cls._make_getter(k), cls._make_setter(k)) for k in self.keys} + self.__class__ = type(new_cls_name, (cls, ), properties) def _register_methods(self): for name, method in self.methods.items(): From 7c1054e93e1203dd7694fcc721273658980f512d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Dec 2022 09:42:34 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/struct.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 027258121fa89..1aa8344a149df 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -101,7 +101,10 @@ def _register_members(self): cls = self.__class__ new_cls_name = cls.__name__ + str(cls._instance_count) cls._instance_count += 1 - properties = {k: property(cls._make_getter(k), cls._make_setter(k)) for k in self.keys} + properties = { + k: property(cls._make_getter(k), cls._make_setter(k)) + for k in self.keys + } self.__class__ = type(new_cls_name, (cls, ), properties) def _register_methods(self):