From c0de4539c7a382af39d2545450b3cc1e5ce946c0 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Tue, 5 Jul 2022 18:17:02 +0800
Subject: [PATCH 1/8] add some useful decorators

---
 src/accelerate/decorator.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 src/accelerate/decorator.py

diff --git a/src/accelerate/decorator.py b/src/accelerate/decorator.py
new file mode 100644
index 00000000000..9c1d36e38a0
--- /dev/null
+++ b/src/accelerate/decorator.py
@@ -0,0 +1,27 @@
+from functools import wraps
+
+
+def on_main_process(func):
+    """
+    Run func on main process only
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if self.is_main_process or not self.use_distributed:
+            return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+def on_local_main_process(func):
+    """
+    Run func on local main process only
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if self.is_local_main_process or not self.use_distributed:
+            return func(self, *args, **kwargs)
+
+    return wrapper
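A minimal usage sketch of the new decorators, assuming the `accelerate.decorator` module added above is importable. The `Trainer` class is a hypothetical stand-in for any object that exposes the `is_main_process` and `use_distributed` attributes the wrappers check:

```python
from accelerate.decorator import on_main_process


class Trainer:
    """Hypothetical host class; any object with these two attributes works."""

    def __init__(self, is_main_process=True, use_distributed=False):
        self.is_main_process = is_main_process
        self.use_distributed = use_distributed

    @on_main_process
    def log(self, message):
        print(message)


Trainer().log("printed: single-process runs fall through the guard")
Trainer(is_main_process=False, use_distributed=True).log("never printed")
```

On a non-matching process the wrapper simply returns `None`, so these decorators suit side-effecting methods such as logging, not functions whose return value is needed on every process.
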
From 09dbdc4e4ff9457650e479bb9b18f64f11378aeb Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Thu, 21 Jul 2022 20:44:11 +0800
Subject: [PATCH 2/8] make on_(local_)main_process member of Accelerator

---
 src/accelerate/accelerator.py | 61 ++++++++++++++++++++++++-----------
 src/accelerate/decorator.py   | 27 ----------------
 2 files changed, 43 insertions(+), 45 deletions(-)
 delete mode 100644 src/accelerate/decorator.py

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 35afd6c5d9a..23787b9aca6 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -19,6 +19,7 @@
 import sys
 import warnings
 from contextlib import contextmanager
+from functools import wraps
 from typing import List, Optional, Union
 
 import torch
@@ -356,23 +357,29 @@ def mixed_precision(self):
             mixed_precision = self.state.mixed_precision
         return mixed_precision
 
-    @contextmanager
-    def local_main_process_first(self):
+    def on_main_process(func):
         """
-        Lets the local main process go inside a with block.
-
-        The other processes will enter the with block after the main process exits.
+        Run func on main process only
         """
-        yield from self._goes_first(self.is_local_main_process)
 
-    @contextmanager
-    def main_process_first(self):
-        """
-        Lets the main process go first inside a with block.
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if self.is_main_process or not self.use_distributed:
+                return func(self, *args, **kwargs)
 
-        The other processes will enter the with block after the main process exits.
-        """
-        yield from self._goes_first(self.is_main_process)
+        return wrapper
+
+    def on_local_main_process(func):
+        """
+        Run func on local main process only
+        """
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if self.is_local_main_process or not self.use_distributed:
+                return func(self, *args, **kwargs)
+
+        return wrapper
 
     def _goes_first(self, is_main):
         if not is_main:
@@ -383,6 +390,24 @@ def _goes_first(self, is_main):
         if is_main:
             self.wait_for_everyone()
 
+    @contextmanager
+    def main_process_first(self):
+        """
+        Lets the main process go first inside a with block.
+
+        The other processes will enter the with block after the main process exits.
+        """
+        yield from self._goes_first(self.is_main_process)
+
+    @contextmanager
+    def local_main_process_first(self):
+        """
+        Lets the local main process go inside a with block.
+
+        The other processes will enter the with block after the main process exits.
+        """
+        yield from self._goes_first(self.is_local_main_process)
+
     @contextmanager
     def no_sync(self, model):
         """
@@ -991,6 +1016,7 @@ def init_trackers(self, project_name: str, config: Optional[dict] = None, init_k
         for tracker in self.trackers:
             tracker.store_init_configuration(config)
 
+    @on_main_process
     def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dict] = {}):
         """
         Logs `values` to all stored trackers in `self.trackers`.
@@ -1007,17 +1033,16 @@ def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dic
                 {"wandb": {"tags": ["tag_a", "tag_b"]}}
                 ```
         """
-        if self.is_main_process:
-            for tracker in self.trackers:
-                tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))
+        for tracker in self.trackers:
+            tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))
 
+    @on_main_process
     def end_training(self):
         """
         Runs any special end training behaviors, such as stopping trackers
         """
-        if self.is_main_process:
-            for tracker in self.trackers:
-                tracker.finish()
+        for tracker in self.trackers:
+            tracker.finish()
 
     def save(self, obj, f):
         """
diff --git a/src/accelerate/decorator.py b/src/accelerate/decorator.py
deleted file mode 100644
index 9c1d36e38a0..00000000000
--- a/src/accelerate/decorator.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from functools import wraps
-
-
-def on_main_process(func):
-    """
-    Run func on main process only
-    """
-
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        if self.is_main_process or not self.use_distributed:
-            return func(self, *args, **kwargs)
-
-    return wrapper
-
-
-def on_local_main_process(func):
-    """
-    Run func on local main process only
-    """
-
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        if self.is_local_main_process or not self.use_distributed:
-            return func(self, *args, **kwargs)
-
-    return wrapper
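With `log` and `end_training` guarded by `@on_main_process`, user scripts no longer need to wrap those tracking calls in an explicit `is_main_process` check. A sketch of the resulting call pattern, with illustrative tracker arguments (`init_trackers` itself is still guarded by hand at this point in the series, as the example updates in the next patch show):

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="all")  # tracker selection is illustrative

if accelerator.is_main_process:  # init_trackers is not yet decorated
    accelerator.init_trackers("my_project", config={"lr": 3e-4})

for step in range(10):
    loss = 1.0 / (step + 1)  # stand-in for a real training step
    accelerator.log({"loss": loss}, step=step)  # silently skipped off the main process

accelerator.end_training()  # likewise only finishes trackers on the main process
```
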
From 18725226a14ce8e0521c9a74db601de2b6345bfd Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Thu, 21 Jul 2022 20:45:50 +0800
Subject: [PATCH 3/8] update examples

---
 examples/by_feature/fsdp_with_peak_mem_tracking.py |  7 +++----
 examples/by_feature/tracking.py                    |  7 +++----
 examples/complete_cv_example.py                    | 11 +++++------
 examples/complete_nlp_example.py                   | 11 +++++------
 4 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/examples/by_feature/fsdp_with_peak_mem_tracking.py b/examples/by_feature/fsdp_with_peak_mem_tracking.py
index 8e3b9e1308a..09199154ac1 100644
--- a/examples/by_feature/fsdp_with_peak_mem_tracking.py
+++ b/examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -110,10 +110,9 @@ def training_function(config, args):
     batch_size = int(config["batch_size"])
 
     # We need to initialize the trackers we use, and also store our configuration
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            experiment_config = vars(args)
-            accelerator.init_trackers("fsdp_glue_no_trainer", experiment_config)
+    if args.with_tracking and accelerator.is_main_process:
+        experiment_config = vars(args)
+        accelerator.init_trackers("fsdp_glue_no_trainer", experiment_config)
 
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     datasets = load_dataset("glue", "mrpc")
diff --git a/examples/by_feature/tracking.py b/examples/by_feature/tracking.py
index 78d4af422fb..00c8665ddbe 100644
--- a/examples/by_feature/tracking.py
+++ b/examples/by_feature/tracking.py
@@ -162,10 +162,9 @@ def training_function(config, args):
 
     # New Code #
     # We need to initalize the trackers we use. Overall configurations can also be stored
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            run = os.path.split(__file__)[-1].split(".")[0]
-            accelerator.init_trackers(run, config)
+    if args.with_tracking and accelerator.is_main_process:
+        run = os.path.split(__file__)[-1].split(".")[0]
+        accelerator.init_trackers(run, config)
 
     # Now we train the model
    for epoch in range(num_epochs):
diff --git a/examples/complete_cv_example.py b/examples/complete_cv_example.py
index 8f893b7c4d7..17cc83170fa 100644
--- a/examples/complete_cv_example.py
+++ b/examples/complete_cv_example.py
@@ -103,12 +103,11 @@ def training_function(config, args):
         checkpointing_steps = None
 
     # We need to initialize the trackers we use, and also store our configuration
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            run = os.path.split(__file__)[-1].split(".")[0]
-            if args.logging_dir:
-                run = os.path.join(args.logging_dir, run)
-            accelerator.init_trackers(run, config)
+    if args.with_tracking and accelerator.is_main_process:
+        run = os.path.split(__file__)[-1].split(".")[0]
+        if args.logging_dir:
+            run = os.path.join(args.logging_dir, run)
+        accelerator.init_trackers(run, config)
 
     # Grab all the image filenames
     file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(".jpg")]
diff --git a/examples/complete_nlp_example.py b/examples/complete_nlp_example.py
index 572bc9a2370..dc0cf43ed26 100644
--- a/examples/complete_nlp_example.py
+++ b/examples/complete_nlp_example.py
@@ -75,12 +75,11 @@ def training_function(config, args):
     batch_size = int(config["batch_size"])
 
     # We need to initialize the trackers we use, and also store our configuration
-    if args.with_tracking:
-        if accelerator.is_main_process:
-            run = os.path.split(__file__)[-1].split(".")[0]
-            if args.logging_dir:
-                run = os.path.join(args.logging_dir, run)
-            accelerator.init_trackers(run, config)
+    if args.with_tracking and accelerator.is_main_process:
+        run = os.path.split(__file__)[-1].split(".")[0]
+        if args.logging_dir:
+            run = os.path.join(args.logging_dir, run)
+        accelerator.init_trackers(run, config)
 
     tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
     datasets = load_dataset("glue", "mrpc")

From af539e714dee7aa162736eb77b268da0d08af765 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Thu, 21 Jul 2022 20:55:56 +0800
Subject: [PATCH 4/8] add on_process and on_local_process

---
 src/accelerate/accelerator.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 23787b9aca6..5420e8720a8 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -381,6 +381,36 @@ def wrapper(self, *args, **kwargs):
 
         return wrapper
 
+    def on_process(process_idx):
+        """
+        Run func on certain process only
+        """
+
+        def decorator(func):
+            @wraps(func)
+            def wrapper(self, *args, **kwargs):
+                if self.process_idx == process_idx or not self.use_distributed:
+                    return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
+    def on_process(local_process_idx):
+        """
+        Run func on certain local process only
+        """
+
+        def decorator(func):
+            @wraps(func)
+            def wrapper(self, *args, **kwargs):
+                if self.local_process_idx == local_process_idx or not self.use_distributed:
+                    return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
     def _goes_first(self, is_main):
         if not is_main:
             self.wait_for_everyone()
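Unlike `on_main_process`, these variants take a process index, so each is a decorator factory: `on_process(2)` first builds the guard, and the returned `decorator` then wraps the method. A self-contained sketch of the pattern; the `Worker` class is a hypothetical stand-in for an object carrying the `process_idx` and `use_distributed` attributes the wrapper reads:

```python
from functools import wraps


def on_process(process_idx):
    # Decorator factory: binds the target index, then returns the real decorator.
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Run only on the requested process, or always when not distributed.
            if self.process_idx == process_idx or not self.use_distributed:
                return func(self, *args, **kwargs)

        return wrapper

    return decorator


class Worker:
    use_distributed = True

    def __init__(self, process_idx):
        self.process_idx = process_idx

    @on_process(2)
    def download_data(self):
        print(f"process {self.process_idx} fetches the dataset")


Worker(2).download_data()  # runs
Worker(0).download_data()  # guard fails, returns None
```
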
From 6675543b5800824fcf7d48af448d334647f09c04 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Fri, 22 Jul 2022 00:40:48 +0800
Subject: [PATCH 5/8] fixes wrong name for `on_local_process`

---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 5420e8720a8..0774381da5d 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -396,7 +396,7 @@ def wrapper(self, *args, **kwargs):
 
         return decorator
 
-    def on_process(local_process_idx):
+    def on_local_process(local_process_idx):
         """
         Run func on certain local process only
         """

From 0642a72fac43e25b4b4a34f1f2b3b434d2db81c2 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Fri, 22 Jul 2022 20:49:15 +0800
Subject: [PATCH 6/8] Update src/accelerate/accelerator.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 0774381da5d..e146799515a 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -359,7 +359,7 @@ def mixed_precision(self):
 
     def on_main_process(func):
         """
-        Run func on main process only
+        A decorator that will run the decorated function on the main process only.
         """
 
         @wraps(func)

From ef711c4f4d3532d317c1fdb7bd127f3998fa65c4 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Fri, 22 Jul 2022 20:49:21 +0800
Subject: [PATCH 7/8] Update src/accelerate/accelerator.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index e146799515a..a4804ba8034 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -371,7 +371,7 @@ def wrapper(self, *args, **kwargs):
 
     def on_local_main_process(func):
         """
-        Run func on local main process only
+        A decorator that will run the decorated function on the local main process only.
         """
 
         @wraps(func)

From 4711b43452523bb7b4d753a6a75a4b1b45f1c6cd Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Fri, 22 Jul 2022 20:49:33 +0800
Subject: [PATCH 8/8] Update src/accelerate/accelerator.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index a4804ba8034..5f32a83d19a 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -383,7 +383,7 @@ def wrapper(self, *args, **kwargs):
 
     def on_process(process_idx):
         """
-        Run func on certain process only
+        A decorator that will run the decorated function on a given process index only.
         """
 
         def decorator(func):
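With the series applied, all four decorators live in the `Accelerator` class namespace as plain functions, so a subclass can reach them through the class and guard its own methods. A hedged sketch; the subclass and method are hypothetical, and note that `on_process`/`on_local_process` additionally expect `process_idx`/`local_process_idx` attributes on the instance:

```python
from accelerate import Accelerator


class MyAccelerator(Accelerator):
    # Accessing the decorator through the class yields the plain function,
    # so it can wrap new methods exactly like `log` and `end_training` above.
    @Accelerator.on_main_process
    def save_summary(self, path):
        print(f"main process writes {path}")


accelerator = MyAccelerator()
accelerator.save_summary("summary.json")  # no-op on every process but the main one
```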