From 4c182f6114a260390f5712ae43b8435768ec41f0 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 18 Apr 2023 13:33:05 +0800 Subject: [PATCH] fix(nyz): fix to_item compatibility bug (#646) --- ding/torch_utils/data_helper.py | 15 +++++++++++++-- ding/torch_utils/tests/test_data_helper.py | 5 ++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 2411069b14..a5804a5f4b 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -280,12 +280,14 @@ def tensor_to_list(item): raise TypeError("not support item type: {}".format(type(item))) -def to_item(data): +def to_item(data: Any, ignore_error: bool = True) -> Any: """ Overview: Transform data into python native scalar (i.e. data item), keep other data types unchanged. Arguments: - data (:obj:`Any`): The data that needs to be transformed. + - ignore_error (:obj:`bool`): Whether to ignore the error when the data type is not supported. That is to \ + say, only the data can be transformed into a python native scalar will be returned. Returns: - data (:obj:`Any`): Transformed data. """ @@ -300,7 +302,16 @@ def to_item(data): elif isinstance(data, list) or isinstance(data, tuple): return [to_item(d) for d in data] elif isinstance(data, dict): - return {k: to_item(v) for k, v in data.items()} + new_data = {} + for k, v in data.items(): + if ignore_error: + try: + new_data[k] = to_item(v) + except ValueError: + pass + else: + new_data[k] = to_item(v) + return new_data else: raise TypeError("not support data type: {}".format(data)) diff --git a/ding/torch_utils/tests/test_data_helper.py b/ding/torch_utils/tests/test_data_helper.py index ef0964fa4c..840e99a414 100644 --- a/ding/torch_utils/tests/test_data_helper.py +++ b/ding/torch_utils/tests/test_data_helper.py @@ -139,7 +139,10 @@ def test_to_item(self): assert np.isscalar(new_data.a) with pytest.raises(ValueError): - to_item(torch.randn(4)) + to_item({'a': torch.randn(4), 'b': torch.rand(1)}, ignore_error=False) + output = to_item({'a': torch.randn(4), 'b': torch.rand(1)}, ignore_error=True) + assert 'a' not in output + assert 'b' in output def test_same_shape(self): tlist = [torch.randn(3, 5) for i in range(5)]