diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 893c9f9a1b0..a76150cd777 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1250,6 +1250,8 @@ def encode_nested_example(schema, obj, level=0): sub_schema = schema[0] if obj is None: return None + elif isinstance(obj, np.ndarray): + return encode_nested_example(schema, obj.tolist()) else: if len(obj) > 0: for first_elmt in obj: diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py index 3ac1e86fc41..408f6bc5a12 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -31,8 +31,8 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator): current_example["__key__"] = example_key current_example["__url__"] = tar_path current_example[field_name.lower()] = f.read() - if field_name in cls.DECODERS: - current_example[field_name] = cls.DECODERS[field_name](current_example[field_name]) + if field_name.split(".")[-1] in cls.DECODERS: + current_example[field_name] = cls.DECODERS[field_name.split(".")[-1]](current_example[field_name]) if current_example: yield current_example