From c1b98ee55064d534899e143d0e3cf92de3c0228e Mon Sep 17 00:00:00 2001 From: matt Date: Thu, 8 Sep 2022 17:41:14 +0100 Subject: [PATCH] Remove Sequence-specific code --- src/datasets/arrow_dataset.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 58a4e96796c..8d0496dc418 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -299,24 +299,17 @@ def _get_output_signature( f"Unrecognized array dtype {np_arrays[0].dtype}. \n" "Nested types and image/audio types are not supported yet." ) - if ( - column in dataset - and isinstance(dataset.features[column], Sequence) - and dataset.features[column].length != -1 - ): - static_shape = [batch_size, dataset.features[column].length] - else: - shapes = [array.shape for array in np_arrays] - static_shape = [] - for dim in range(len(shapes[0])): - sizes = set([shape[dim] for shape in shapes]) - if dim == 0: - static_shape.append(batch_size) - continue - if len(sizes) == 1: # This dimension looks constant - static_shape.append(sizes.pop()) - else: # Use None for variable dimensions - static_shape.append(None) + shapes = [array.shape for array in np_arrays] + static_shape = [] + for dim in range(len(shapes[0])): + sizes = set([shape[dim] for shape in shapes]) + if dim == 0: + static_shape.append(batch_size) + continue + if len(sizes) == 1: # This dimension looks constant + static_shape.append(sizes.pop()) + else: # Use None for variable dimensions + static_shape.append(None) tf_columns_to_signatures[column] = tf.TensorSpec(shape=static_shape, dtype=tf_dtype) np_columns_to_dtypes[column] = np_dtype