Skip to content

Commit

Permalink
fix(nyz): fix to_item compatibility bug (#646)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Apr 18, 2023
1 parent dd00ebf commit 4c182f6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
15 changes: 13 additions & 2 deletions ding/torch_utils/data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,14 @@ def tensor_to_list(item):
raise TypeError("not support item type: {}".format(type(item)))


def to_item(data):
def to_item(data: Any, ignore_error: bool = True) -> Any:
"""
Overview:
Transform data into python native scalar (i.e. data item), keep other data types unchanged.
Arguments:
- data (:obj:`Any`): The data that needs to be transformed.
- ignore_error (:obj:`bool`): Whether to ignore the error when the data type is not supported. That is to \
say, only the data can be transformed into a python native scalar will be returned.
Returns:
- data (:obj:`Any`): Transformed data.
"""
Expand All @@ -300,7 +302,16 @@ def to_item(data):
elif isinstance(data, list) or isinstance(data, tuple):
return [to_item(d) for d in data]
elif isinstance(data, dict):
return {k: to_item(v) for k, v in data.items()}
new_data = {}
for k, v in data.items():
if ignore_error:
try:
new_data[k] = to_item(v)
except ValueError:
pass
else:
new_data[k] = to_item(v)
return new_data
else:
raise TypeError("not support data type: {}".format(data))

Expand Down
5 changes: 4 additions & 1 deletion ding/torch_utils/tests/test_data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def test_to_item(self):
assert np.isscalar(new_data.a)

with pytest.raises(ValueError):
to_item(torch.randn(4))
to_item({'a': torch.randn(4), 'b': torch.rand(1)}, ignore_error=False)
output = to_item({'a': torch.randn(4), 'b': torch.rand(1)}, ignore_error=True)
assert 'a' not in output
assert 'b' in output

def test_same_shape(self):
tlist = [torch.randn(3, 5) for i in range(5)]
Expand Down

0 comments on commit 4c182f6

Please sign in to comment.