Fix ArrowWriter closes stream at exit #1971

Merged · 17 commits · Mar 10, 2021
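
This PR makes `ArrowWriter` a context manager: the writer gains `__enter__`, `__exit__`, and `close` methods, and the call sites in the benchmarks, `arrow_dataset.py`, and `builder.py` switch to `with` blocks so the underlying Arrow stream is closed when the block exits, even if an exception interrupts writing. A minimal sketch of the resulting usage pattern, assuming the API shown in the diff below (`my_features` and `dummy_data` are placeholders, not part of this PR):

from datasets import Features, Value
from datasets.arrow_writer import ArrowWriter

my_features = Features({"text": Value("string")})                   # placeholder schema
dummy_data = [("0", {"text": "hello"}), ("1", {"text": "world"})]   # placeholder rows

# The writer opens the file itself, so it owns the stream and closes it on exit.
with ArrowWriter(features=my_features, path="beta.arrow") as writer:
    for key, record in dummy_data:
        example = my_features.encode_example(record)
        writer.write(example)
    num_examples, num_bytes = writer.finalize()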
10 changes: 5 additions & 5 deletions benchmarks/benchmark_array_xd.py
@@ -24,11 +24,11 @@

@get_duration
def write(my_features, dummy_data, tmp_dir):
writer = ArrowWriter(features=my_features, path=os.path.join(tmp_dir, "beta.arrow"))
for key, record in dummy_data:
example = my_features.encode_example(record)
writer.write(example)
num_examples, num_bytes = writer.finalize()
with ArrowWriter(features=my_features, path=os.path.join(tmp_dir, "beta.arrow")) as writer:
for key, record in dummy_data:
example = my_features.encode_example(record)
writer.write(example)
num_examples, num_bytes = writer.finalize()


@get_duration
10 changes: 5 additions & 5 deletions benchmarks/utils.py
@@ -46,12 +46,12 @@ def generate_examples(features: dict, num_examples=100, seq_shapes=None):
def generate_example_dataset(dataset_path, features, num_examples=100, seq_shapes=None):
dummy_data = generate_examples(features, num_examples=num_examples, seq_shapes=seq_shapes)

writer = datasets.ArrowWriter(features=features, path=dataset_path)
for key, record in dummy_data:
example = features.encode_example(record)
writer.write(example)
with datasets.ArrowWriter(features=features, path=dataset_path) as writer:
for key, record in dummy_data:
example = features.encode_example(record)
writer.write(example)

num_final_examples, num_bytes = writer.finalize()
num_final_examples, num_bytes = writer.finalize()

assert (
num_final_examples == num_examples
110 changes: 56 additions & 54 deletions src/datasets/arrow_dataset.py
@@ -507,16 +507,16 @@ def save_to_disk(self, dataset_path: str, fs=None):
if self._indices is not None:
if not self._indices_data_files:
cache_file_name = os.path.join(temp_dataset_path, "indices.arrow")
writer = ArrowWriter(path=cache_file_name)
writer.write_table(self._indices)
writer.finalize()
with ArrowWriter(path=cache_file_name) as writer:
writer.write_table(self._indices)
writer.finalize()
self._indices_data_files = [{"filename": cache_file_name}]
# Write dataset if needed
if not self._data_files or any(len(h["transforms"]) > 0 for h in self._inplace_history):
cache_file_name = os.path.join(temp_dataset_path, "dataset.arrow")
writer = ArrowWriter(path=cache_file_name)
writer.write_table(self._data)
writer.finalize()
with ArrowWriter(path=cache_file_name) as writer:
writer.write_table(self._data)
writer.finalize()
self._data_files = [{"filename": cache_file_name}]
self._inplace_history = [{"transforms": []}]
# Copy all files into the dataset directory
@@ -1551,45 +1551,46 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
disable_nullable=disable_nullable,
)

try:
# Loop over single examples or batches and write to buffer/file if examples are to be updated
pbar_iterable = self if not batched else range(0, len(self), batch_size)
pbar_unit = "ex" if not batched else "ba"
pbar_desc = "#" + str(rank) if rank is not None else None
pbar = tqdm(pbar_iterable, disable=not_verbose, position=rank, unit=pbar_unit, desc=pbar_desc)
if not batched:
for i, example in enumerate(pbar):
example = apply_function_on_filtered_inputs(example, i, offset=offset)
if update_data:
example = cast_to_python_objects(example)
writer.write(example)
else:
for i in pbar:
if drop_last_batch and i + batch_size > self.num_rows:
continue
batch = self[i : i + batch_size]
indices = list(range(*(slice(i, i + batch_size).indices(self.num_rows)))) # Something simpler?
try:
batch = apply_function_on_filtered_inputs(
batch, indices, check_same_num_examples=len(self.list_indexes()) > 0, offset=offset
)
except NumExamplesMismatch:
raise DatasetTransformationNotAllowedError(
"Using `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn't create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it."
)
if update_data:
batch = cast_to_python_objects(batch)
writer.write_batch(batch)
if update_data:
writer.finalize() # close_stream=bool(buf_writer is None)) # We only close if we are writing in a file
except (Exception, KeyboardInterrupt):
if update_data:
writer.finalize()
if update_data and tmp_file is not None:
tmp_file.close()
if os.path.exists(tmp_file.name):
os.remove(tmp_file.name)
raise
with writer:
try:
# Loop over single examples or batches and write to buffer/file if examples are to be updated
pbar_iterable = self if not batched else range(0, len(self), batch_size)
pbar_unit = "ex" if not batched else "ba"
pbar_desc = "#" + str(rank) if rank is not None else None
pbar = tqdm(pbar_iterable, disable=not_verbose, position=rank, unit=pbar_unit, desc=pbar_desc)
if not batched:
for i, example in enumerate(pbar):
example = apply_function_on_filtered_inputs(example, i, offset=offset)
if update_data:
example = cast_to_python_objects(example)
writer.write(example)
else:
for i in pbar:
if drop_last_batch and i + batch_size > self.num_rows:
continue
batch = self[i : i + batch_size]
indices = list(range(*(slice(i, i + batch_size).indices(self.num_rows)))) # Something simpler?
try:
batch = apply_function_on_filtered_inputs(
batch, indices, check_same_num_examples=len(self.list_indexes()) > 0, offset=offset
)
except NumExamplesMismatch:
raise DatasetTransformationNotAllowedError(
"Using `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn't create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it."
)
if update_data:
batch = cast_to_python_objects(batch)
writer.write_batch(batch)
if update_data:
writer.finalize() # close_stream=bool(buf_writer is None)) # We only close if we are writing in a file
except (Exception, KeyboardInterrupt):
if update_data:
writer.finalize()
if update_data and tmp_file is not None:
tmp_file.close()
if os.path.exists(tmp_file.name):
os.remove(tmp_file.name)
raise

if update_data and tmp_file is not None:
tmp_file.close()
@@ -1833,15 +1834,16 @@ def select(

indices_table = pa.Table.from_arrays([indices_array], names=["indices"])

try:
writer.write_table(indices_table)
writer.finalize() # close_stream=bool(buf_writer is None)) # We only close if we are writing in a file
except (Exception, KeyboardInterrupt):
if tmp_file is not None:
tmp_file.close()
if os.path.exists(tmp_file.name):
os.remove(tmp_file.name)
raise
with writer:
try:
writer.write_table(indices_table)
writer.finalize() # close_stream=bool(buf_writer is None)) We only close if we are writing in a file
except (Exception, KeyboardInterrupt):
if tmp_file is not None:
tmp_file.close()
if os.path.exists(tmp_file.name):
os.remove(tmp_file.name)
raise

if tmp_file is not None:
tmp_file.close()
38 changes: 28 additions & 10 deletions src/datasets/arrow_writer.py
@@ -123,7 +123,7 @@ def __arrow_array__(self, type=None):
raise


class ArrowWriter(object):
class ArrowWriter:
"""Shuffles and writes Examples to Arrow files."""

def __init__(
@@ -157,8 +157,10 @@ def __init__(
self._path = path
if stream is None:
self.stream = pa.OSFile(self._path, "wb")
self._closable_stream = True
else:
self.stream = stream
self._closable_stream = False

self.fingerprint = fingerprint
self.disable_nullable = disable_nullable
@@ -176,6 +178,22 @@ def __len__(self):
""" Return the number of writed and staged examples """
return self._num_examples + len(self.current_rows)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def close(self):
# Try closing if opened; if closed: pyarrow.lib.ArrowInvalid: Invalid operation on closed file
if self.pa_writer: # it might be None
try:
self.pa_writer.close()
except Exception: # pyarrow.lib.ArrowInvalid, OSError
pass
if self._closable_stream and not self.stream.closed:
self.stream.close() # This also closes self.pa_writer if it is opened

def _build_writer(self, inferred_schema: pa.Schema):
inferred_features = Features.from_arrow_schema(inferred_schema)
if self._features is not None:
@@ -416,14 +434,14 @@ def finalize(self, metrics_query_result: dict):
def parquet_to_arrow(sources, destination):
"""Convert parquet files to arrow file. Inputs can be str paths or file-like objects"""
stream = None if isinstance(destination, str) else destination
writer = ArrowWriter(path=destination, stream=stream)
not_verbose = bool(logger.getEffectiveLevel() > WARNING)
for source in tqdm(sources, unit="sources", disable=not_verbose):
pf = pa.parquet.ParquetFile(source)
for i in tqdm(range(pf.num_row_groups), unit="row_groups", leave=False, disable=not_verbose):
df = pf.read_row_group(i).to_pandas()
for col in df.columns:
df[col] = df[col].apply(json.loads)
reconstructed_table = pa.Table.from_pandas(df)
writer.write_table(reconstructed_table)
with ArrowWriter(path=destination, stream=stream) as writer:
for source in tqdm(sources, unit="sources", disable=not_verbose):
pf = pa.parquet.ParquetFile(source)
for i in tqdm(range(pf.num_row_groups), unit="row_groups", leave=False, disable=not_verbose):
df = pf.read_row_group(i).to_pandas()
for col in df.columns:
df[col] = df[col].apply(json.loads)
reconstructed_table = pa.Table.from_pandas(df)
writer.write_table(reconstructed_table)
return destination
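
In the new `close()` above, the writer only closes `self.stream` when it opened the file itself: `_closable_stream` is set to True in `__init__` when a `path` is given and to False when a caller-supplied `stream` is passed, so external streams stay under the caller's control. Closing `self.pa_writer` is wrapped in a broad `except` because closing an already-closed pyarrow file raises `ArrowInvalid`. A small sketch, under those assumptions, of what the `with` block now guarantees when writing is interrupted (schema and path are placeholders):

import os
import tempfile

from datasets import Features, Value
from datasets.arrow_writer import ArrowWriter

features = Features({"text": Value("string")})                 # placeholder schema
path = os.path.join(tempfile.mkdtemp(), "interrupted.arrow")   # placeholder path

try:
    with ArrowWriter(features=features, path=path) as writer:
        writer.write(features.encode_example({"text": "ok"}))
        raise RuntimeError("simulated failure before finalize()")
except RuntimeError:
    pass
# __exit__ has already called writer.close(), so the file handle the writer
# opened for `path` is released even though finalize() was never reached.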
27 changes: 13 additions & 14 deletions src/datasets/builder.py
@@ -968,18 +968,18 @@ def _prepare_split(self, split_generator):

fname = "{}-{}.arrow".format(self.name, split_generator.name)
fpath = os.path.join(self._cache_dir, fname)
writer = ArrowWriter(features=self.info.features, path=fpath, writer_batch_size=self._writer_batch_size)

generator = self._generate_examples(**split_generator.gen_kwargs)
not_verbose = bool(logger.getEffectiveLevel() > WARNING)
try:
for key, record in utils.tqdm(
generator, unit=" examples", total=split_info.num_examples, leave=False, disable=not_verbose
):
example = self.info.features.encode_example(record)
writer.write(example)
finally:
num_examples, num_bytes = writer.finalize()
with ArrowWriter(features=self.info.features, path=fpath, writer_batch_size=self._writer_batch_size) as writer:
try:
for key, record in utils.tqdm(
generator, unit=" examples", total=split_info.num_examples, leave=False, disable=not_verbose
):
example = self.info.features.encode_example(record)
writer.write(example)
finally:
num_examples, num_bytes = writer.finalize()

assert num_examples == split_info.num_examples, f"Expected to write {split_info.num_examples} but wrote {num_examples}"
split_generator.split_info.num_examples = num_examples
@@ -1026,13 +1026,12 @@ def _prepare_split(self, split_generator):
fname = "{}-{}.arrow".format(self.name, split_generator.name)
fpath = os.path.join(self._cache_dir, fname)

writer = ArrowWriter(path=fpath)

generator = self._generate_tables(**split_generator.gen_kwargs)
not_verbose = bool(logger.getEffectiveLevel() > WARNING)
for key, table in utils.tqdm(generator, unit=" tables", leave=False, disable=not_verbose):
writer.write_table(table)
num_examples, num_bytes = writer.finalize()
with ArrowWriter(path=fpath) as writer:
for key, table in utils.tqdm(generator, unit=" tables", leave=False, disable=not_verbose):
writer.write_table(table)
num_examples, num_bytes = writer.finalize()

split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes