add on_main_process decorators #488

Merged: 8 commits, Jul 26, 2022
Changes from all commits
7 changes: 3 additions & 4 deletions examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -110,10 +110,9 @@ def training_function(config, args):
     batch_size = int(config["batch_size"])

     # We need to initialize the trackers we use, and also store our configuration
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            experiment_config = vars(args)
-            accelerator.init_trackers("fsdp_glue_no_trainer", experiment_config)
+    if args.with_tracking and accelerator.is_main_process:
+        experiment_config = vars(args)
+        accelerator.init_trackers("fsdp_glue_no_trainer", experiment_config)

     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     datasets = load_dataset("glue", "mrpc")
7 changes: 3 additions & 4 deletions examples/by_feature/tracking.py
@@ -162,10 +162,9 @@ def training_function(config, args):

     # New Code #
     # We need to initalize the trackers we use. Overall configurations can also be stored
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            run = os.path.split(__file__)[-1].split(".")[0]
-            accelerator.init_trackers(run, config)
+    if args.with_tracking and accelerator.is_main_process:
+        run = os.path.split(__file__)[-1].split(".")[0]
+        accelerator.init_trackers(run, config)

     # Now we train the model
     for epoch in range(num_epochs):
11 changes: 5 additions & 6 deletions examples/complete_cv_example.py
@@ -103,12 +103,11 @@ def training_function(config, args):
     checkpointing_steps = None

     # We need to initialize the trackers we use, and also store our configuration
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            run = os.path.split(__file__)[-1].split(".")[0]
-            if args.logging_dir:
-                run = os.path.join(args.logging_dir, run)
-            accelerator.init_trackers(run, config)
+    if args.with_tracking and accelerator.is_main_process:
+        run = os.path.split(__file__)[-1].split(".")[0]
+        if args.logging_dir:
+            run = os.path.join(args.logging_dir, run)
+        accelerator.init_trackers(run, config)

     # Grab all the image filenames
     file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(".jpg")]
11 changes: 5 additions & 6 deletions examples/complete_nlp_example.py
@@ -75,12 +75,11 @@ def training_function(config, args):
     batch_size = int(config["batch_size"])

     # We need to initialize the trackers we use, and also store our configuration
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            run = os.path.split(__file__)[-1].split(".")[0]
-            if args.logging_dir:
-                run = os.path.join(args.logging_dir, run)
-            accelerator.init_trackers(run, config)
+    if args.with_tracking and accelerator.is_main_process:
+        run = os.path.split(__file__)[-1].split(".")[0]
+        if args.logging_dir:
+            run = os.path.join(args.logging_dir, run)
+        accelerator.init_trackers(run, config)

     tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
     datasets = load_dataset("glue", "mrpc")
87 changes: 71 additions & 16 deletions src/accelerate/accelerator.py
@@ -19,6 +19,7 @@
 import sys
 import warnings
 from contextlib import contextmanager
+from functools import wraps
 from typing import List, Optional, Union

 import torch
@@ -356,23 +357,59 @@ def mixed_precision(self):
         mixed_precision = self.state.mixed_precision
         return mixed_precision

-    @contextmanager
-    def local_main_process_first(self):
-        """
-        Lets the local main process go inside a with block.
-
-        The other processes will enter the with block after the main process exits.
-        """
-        yield from self._goes_first(self.is_local_main_process)
-
-    @contextmanager
-    def main_process_first(self):
-        """
-        Lets the main process go first inside a with block.
-
-        The other processes will enter the with block after the main process exits.
-        """
-        yield from self._goes_first(self.is_main_process)
+    def on_main_process(func):
+        """
+        A decorator that will run the decorated function on the main process only.
+        """
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if self.is_main_process or not self.use_distributed:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    def on_local_main_process(func):
+        """
+        A decorator that will run the decorated function on the local main process only.
+        """
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if self.is_local_main_process or not self.use_distributed:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    def on_process(process_idx):

Collaborator review comment on the line above:
Suggested change
-    def on_process(process_idx):
+    def on_process(process_idx, local=False):
Maybe we could group this one and the next in one decorator, since it's one that takes arguments?
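For illustration, a rough sketch of what that grouping might look like: a single decorator taking the target index plus a hypothetical `local=False` flag. This is not part of the PR; the attribute names simply mirror the ones used in the diff.

    from functools import wraps

    def on_process(process_idx, local=False):
        # Hypothetical merged decorator: compare against the local process index
        # when local=True, otherwise against the global process index.
        def decorator(func):
            @wraps(func)
            def wrapper(self, *args, **kwargs):
                current = self.local_process_idx if local else self.process_idx
                if current == process_idx or not self.use_distributed:
                    return func(self, *args, **kwargs)

            return wrapper

        return decorator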
+        """
+        A decorator that will run the decorated function on a given process index only.
+        """
+
+        def decorator(func):
+            @wraps(func)
+            def wrapper(self, *args, **kwargs):
+                if self.process_idx == process_idx or not self.use_distributed:
+                    return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
+    def on_local_process(local_process_idx):
+        """
+        Run func on certain local process only
+        """
+
+        def decorator(func):
+            @wraps(func)
+            def wrapper(self, *args, **kwargs):
+                if self.local_process_idx == local_process_idx or not self.use_distributed:
+                    return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator

     def _goes_first(self, is_main):
         if not is_main:
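To make the runtime behavior of these decorators concrete, here is a minimal, self-contained sketch (a toy stand-in for Accelerator, reusing the attribute names from the diff): the decorated method runs only where the condition holds and is a silent no-op, returning None, everywhere else.

    from functools import wraps

    def on_main_process(func):
        # Same shape as the decorator added in this PR: the wrapped method only
        # runs on the main process (or when not running distributed).
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.is_main_process or not self.use_distributed:
                return func(self, *args, **kwargs)

        return wrapper

    class ToyAccelerator:
        def __init__(self, is_main_process, use_distributed=True):
            self.is_main_process = is_main_process
            self.use_distributed = use_distributed

        @on_main_process
        def log(self, values):
            print("logged:", values)
            return True

    print(ToyAccelerator(is_main_process=True).log({"loss": 0.1}))   # logs, returns True
    print(ToyAccelerator(is_main_process=False).log({"loss": 0.1}))  # skipped, returns None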
@@ -383,6 +420,24 @@ def _goes_first(self, is_main):
         if is_main:
             self.wait_for_everyone()

+    @contextmanager
+    def main_process_first(self):
+        """
+        Lets the main process go first inside a with block.
+
+        The other processes will enter the with block after the main process exits.
+        """
+        yield from self._goes_first(self.is_main_process)
+
+    @contextmanager
+    def local_main_process_first(self):
+        """
+        Lets the local main process go inside a with block.
+
+        The other processes will enter the with block after the main process exits.
+        """
+        yield from self._goes_first(self.is_local_main_process)
+
     @contextmanager
     def no_sync(self, model):
         """
@@ -991,6 +1046,7 @@ def init_trackers(self, project_name: str, config: Optional[dict] = None, init_k
         for tracker in self.trackers:
             tracker.store_init_configuration(config)

+    @on_main_process
     def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dict] = {}):
         """
         Logs `values` to all stored trackers in `self.trackers`.
@@ -1007,17 +1063,16 @@ def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dic
{"wandb": {"tags": ["tag_a", "tag_b"]}}
         ```
         """
-        if self.is_main_process:
-            for tracker in self.trackers:
-                tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))
+        for tracker in self.trackers:
+            tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))

+    @on_main_process
     def end_training(self):
         """
         Runs any special end training behaviors, such as stopping trackers
         """
-        if self.is_main_process:
-            for tracker in self.trackers:
-                tracker.finish()
+        for tracker in self.trackers:
+            tracker.finish()

     def save(self, obj, f):
         """
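Taken together with the example-script changes above, the practical effect is that tracker calls no longer need an explicit main-process guard at every call site: `log` and `end_training` are now decorated with `@on_main_process` and silently no-op on other processes. A rough sketch of downstream usage, assuming `accelerate` is installed and a tracker backend such as wandb is available via `log_with` (an assumption here, not something shown in this diff); note that `init_trackers` is still guarded at the call site in the updated examples.

    from accelerate import Accelerator

    accelerator = Accelerator(log_with="wandb")  # `log_with` assumed configured for your setup

    # init_trackers is still guarded explicitly, as in the updated example scripts.
    if accelerator.is_main_process:
        accelerator.init_trackers("my_project", config={"lr": 3e-4})

    for step in range(10):
        loss = 0.1 * step  # placeholder for a real training step
        # Safe to call from every process after this PR: the decorator makes it
        # a no-op on non-main processes.
        accelerator.log({"loss": loss}, step=step)

    accelerator.end_training()  # finishes trackers on the main process only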