From e1cfc45915be490c9164573fb881f3f2e9ad88fd Mon Sep 17 00:00:00 2001 From: Freddy Heppell Date: Tue, 13 Dec 2022 23:02:36 +0000 Subject: [PATCH 1/3] Raise error if ClassLabel names is not python list --- src/datasets/features/features.py | 2 ++ tests/features/test_features.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 6fd1353ccc3..35e596c815f 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -944,6 +944,8 @@ def __post_init__(self, num_classes, names_file): self.names = [str(i) for i in range(self.num_classes)] else: raise ValueError("Please provide either num_classes, names or names_file.") + elif not isinstance(self.names, list): + raise ValueError(f"Please provide names as a list, is {type(self.names)}") # Set self.num_classes if self.num_classes is None: self.num_classes = len(self.names) diff --git a/tests/features/test_features.py b/tests/features/test_features.py index e0803949032..48ce1061e36 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -287,6 +287,8 @@ def test_classlabel_init(tmp_path_factory): classlabel = ClassLabel(names=names, names_file=names_file) with pytest.raises(ValueError): classlabel = ClassLabel() + with pytest.raises(ValueError): + classlabel = ClassLabel(names=np.array(names)) def test_classlabel_str2int(): From 17c8b8d57dafab9b5b00b0d93f8e8396d5be4e1b Mon Sep 17 00:00:00 2001 From: Freddy Heppell Date: Wed, 14 Dec 2022 12:22:33 +0000 Subject: [PATCH 2/3] Change to accepting Sequence for names --- src/datasets/features/features.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 35e596c815f..31ccaf92480 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -19,6 +19,7 @@ import re import sys from collections.abc import Iterable, Mapping +from collections.abc import Sequence as SequenceABC from dataclasses import InitVar, dataclass, field, fields from functools import reduce, wraps from operator import mul @@ -944,7 +945,7 @@ def __post_init__(self, num_classes, names_file): self.names = [str(i) for i in range(self.num_classes)] else: raise ValueError("Please provide either num_classes, names or names_file.") - elif not isinstance(self.names, list): + elif not isinstance(self.names, SequenceABC): raise ValueError(f"Please provide names as a list, is {type(self.names)}") # Set self.num_classes if self.num_classes is None: From f36448832ac8c2ed94dcda4367750b1a1bfe8751 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Wed, 21 Dec 2022 19:35:01 +0100 Subject: [PATCH 3/3] Replace `ValueError` with `TypeError` --- src/datasets/features/features.py | 2 +- tests/features/test_features.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 31ccaf92480..045b3dae39b 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -946,7 +946,7 @@ def __post_init__(self, num_classes, names_file): else: raise ValueError("Please provide either num_classes, names or names_file.") elif not isinstance(self.names, SequenceABC): - raise ValueError(f"Please provide names as a list, is {type(self.names)}") + raise TypeError(f"Please provide names as a list, is {type(self.names)}") # Set self.num_classes if self.num_classes is None: self.num_classes = len(self.names) diff --git a/tests/features/test_features.py b/tests/features/test_features.py index 48ce1061e36..d036e3295c7 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -287,7 +287,7 @@ def test_classlabel_init(tmp_path_factory): classlabel = ClassLabel(names=names, names_file=names_file) with pytest.raises(ValueError): classlabel = ClassLabel() - with pytest.raises(ValueError): + with pytest.raises(TypeError): classlabel = ClassLabel(names=np.array(names))