Skip to content

Commit

Permalink
add distributed property in accelerate_state
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuanChen committed Jul 5, 2022
1 parent 6ebddcd commit c0c83fb
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/by_feature/fsdp_with_peak_mem_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def collate_fn(examples):
predictions, references = accelerator.gather(
(predictions, batch["labels"])
) # If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
Expand Down
2 changes: 1 addition & 1 deletion examples/by_feature/multi_process_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def training_function(config, args):
predictions, references = accelerator.gather((predictions, batch["labels"]))
# New Code #
# First we check if it's a distributed system
if accelerator.num_processes > 1:
if accelerator.use_distributed:
# Then see if we're on the last batch of our eval dataloader
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
Expand Down
2 changes: 1 addition & 1 deletion examples/complete_cv_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def training_function(config, args):
outputs = model(inputs)
predictions = outputs.argmax(dim=-1)
predictions, references = accelerator.gather((predictions, batch["label"]))
if accelerator.num_processes > 1:
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader) - samples_seen]
references = references[: len(eval_dataloader) - samples_seen]
Expand Down
2 changes: 1 addition & 1 deletion examples/complete_nlp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def collate_fn(examples):
predictions, references = accelerator.gather(
(predictions, batch["labels"])
) # If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
Expand Down
6 changes: 5 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def __init__(
if self.rng_types is None:
self.rng_types = ["torch"] if is_torch_version("<=", "1.5.1") else ["generator"]

@property
def use_distributed(self):
    """
    Whether this run is distributed, i.e. its `DistributedType` is anything
    other than `DistributedType.NO` (single-process launch).
    """
    return DistributedType.NO != self.distributed_type

@property
def distributed_type(self):
    """The `DistributedType` of the current run, delegated to the shared accelerator state."""
    state = self.state
    return state.distributed_type
Expand Down Expand Up @@ -361,7 +365,7 @@ def no_sync(self, model):
PyTorch Module that was prepared with `Accelerator.prepare`
"""
context = contextlib.nullcontext
if self.num_processes > 1:
if self.use_distributed:
context = getattr(model, "no_sync", context)

with context():
Expand Down

0 comments on commit c0c83fb

Please sign in to comment.