Torch distributed (#3529)
* Logging and metrics changes for distributed training (#3372)

* Refactor logging setup to support distributed attrs

* `cleanup_logging()` is replaced with stdlib's `logging.shutdown()`
* Remove `TeeLogger` and use standard log handlers
* Remove `replace_cr_with_newline` and instead use a standard `logging.Filter`
* Introduce optional `rank` and `world_size` attributes to support distributed
  workers (a sketch of the filter and these attributes follows this list)
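
A minimal sketch of what such a filter and the optional `rank`/`world_size` attributes could look like; the names `WorkerLogFilter` and `prepare_worker_logging` are illustrative assumptions, not the exact ones in the codebase:

```python
import logging


class WorkerLogFilter(logging.Filter):
    """Illustrative: replaces carriage returns (e.g. from tqdm) with newlines
    and, when a rank is given, prefixes every record with that worker's rank."""

    def __init__(self, rank: int = -1) -> None:
        super().__init__()
        self._rank = rank

    def filter(self, record: logging.LogRecord) -> bool:
        if isinstance(record.msg, str):
            record.msg = record.msg.replace("\r", "\n")
            if self._rank >= 0:
                record.msg = f"Rank {self._rank} | {record.msg}"
        return True  # never drop a record, only rewrite it


def prepare_worker_logging(rank: int = 0, world_size: int = 1) -> None:
    """Attach the filter to a stream handler for a distributed worker."""
    handler = logging.StreamHandler()
    if world_size > 1:
        handler.addFilter(WorkerLogFilter(rank))
    logging.basicConfig(level=logging.INFO, handlers=[handler])
```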

* Support for distributed training in `get_metrics`
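
As a rough illustration of the aggregation this implies (not the actual `get_metrics` signature), per-worker sums can be combined with `torch.distributed.all_reduce` before metrics are reported:

```python
import torch
import torch.distributed as dist


def aggregate_training_loss(loss_sum: float, num_batches: int, cuda_device: int = -1) -> float:
    """Illustrative only: sum the loss and batch counts across all workers,
    then return the globally averaged loss. Assumes `dist.init_process_group`
    has already been called in this worker."""
    device = torch.device("cuda", cuda_device) if cuda_device >= 0 else torch.device("cpu")
    stats = torch.tensor([loss_sum, float(num_batches)], device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_loss, total_batches = stats.tolist()
    return total_loss / max(total_batches, 1.0)
```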

* Remove bad import

* Fix duplicate log messages in stdout

* Remove preemptive `logging.shutdown`

`logging.shutdown` is registered by the logging module to run at interpreter exit,
so it does not need to be called explicitly from `train_model`.

* Fix black formatting issues

* Remove `tee_logger` references in API doc

* Set log level from `ALLENNLP_DEBUG` env
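
Roughly, the idea is the following (illustrative, not the exact code):

```python
import logging
import os

# If ALLENNLP_DEBUG is set in the environment, log at DEBUG; otherwise at INFO.
LEVEL = logging.DEBUG if os.environ.get("ALLENNLP_DEBUG") else logging.INFO
logging.getLogger().setLevel(LEVEL)
```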

* Changes to `train_model` for distributed training support (#3390)

* High level API changes to support distributed training

* Fix flake8 error

* Fix mypy error

* Add docstring and misc fixes

* Fix flake tests

* `Trainer` changes for distributed training (#3414)

Follow-up PR to #3390 and #3372 to bring in distributed training support. The major changes are:

* Workers are spawned using `mp.spawn`, and each worker creates its own `Trainer` instance
* `Trainer.__init__` wraps `self.model` in `DistributedDataParallel`
* Logging and metric aggregation were already handled in the previous PRs
* In the distributed case, the `Vocabulary` is created before spawning the workers and constructing the `Trainer` instances (a sketch of the spawning pattern follows this list)
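
A heavily simplified sketch of the spawning pattern described in this list; the function names, the `tcp://` init method, and the port are illustrative assumptions, and the real `Trainer` wiring is more involved:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel


def _train_worker(process_rank: int, world_size: int, model: torch.nn.Module) -> None:
    # Each spawned worker joins the process group and claims its own GPU.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",  # assumed address/port for this sketch
        world_size=world_size,
        rank=process_rank,
    )
    torch.cuda.set_device(process_rank)
    model = model.cuda(process_rank)
    # Inside `Trainer.__init__`, the model is wrapped so gradients are synchronized.
    ddp_model = DistributedDataParallel(model, device_ids=[process_rank])
    # ... build a Trainer around `ddp_model` here and call `trainer.train()` ...


def train_distributed(model: torch.nn.Module, world_size: int) -> None:
    # Vocabulary creation happens before this point, as noted in the list above.
    mp.spawn(_train_worker, args=(world_size, model), nprocs=world_size)
```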

To run distributed training, the following flag needs to be enabled in the trainer config:

```jsonnet
{
    "trainer": {
        "distributed": true,
        // ...
    }
}
```

TODO:
* Try to reproduce comparable results for existing/selected models and share detailed results
* Check whether other commands like `evaluate`, `predict`, and `fine-tune` work well with the new changes
* Should all callbacks be invoked from every worker in the case of callback-based training?
* Should the current dataset readers also be changed to support distributed training (i.e., to selectively yield data based on worker rank)? See the sketch after this list.
* Write tests - _would be happy to get some suggestions on how to write tests for this_
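
For the dataset-reader question above, one possible sharding scheme (purely illustrative, not the API that was adopted) is for each worker to keep only every `world_size`-th instance:

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard_instances(instances: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Purely illustrative: yield only the instances assigned to this worker,
    so each of the `world_size` workers sees a disjoint slice of the data."""
    for index, instance in enumerate(instances):
        if index % world_size == rank:
            yield instance
```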

* Dist tests (#3515)

* add some tests

* another test, fix incorrect type annotations

* torch mp uses its own context, no need to set default

* lint

* strip out old DP stuff, ensure multiple cuda devices raises errors (#3516)

* strip out old DP stuff, ensure multiple cuda devices raises errors

* lint

* remove unused attribute

* remove _cuda_devices everywhere

* fixes

* move distributed config up to top level

* lint

* clean up

* rename occurrences of batch_group

* remove hack from find_learning_rate

* fix last tests

* black

* use a top level distributed config

* correct error for int

* change `parse_cuda_devices` to raise a clear error and be strongly typed
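
A rough sketch of the kind of strongly typed validation implied by the last few entries; the function name follows the commit message, `ConfigurationError` is the codebase's existing error type, and the body is only an approximation:

```python
from typing import List, Union

from allennlp.common.checks import ConfigurationError


def parse_cuda_devices(cuda_device: Union[int, str, List[int]]) -> int:
    """Approximate sketch: coerce the configured value to a single int and
    raise a clear error for a list of devices, since multi-GPU training now
    goes through the distributed config rather than DataParallel."""
    if isinstance(cuda_device, list):
        raise ConfigurationError(
            "Multiple cuda devices are no longer passed as a list; "
            "use the distributed training configuration instead."
        )
    try:
        return int(cuda_device)
    except (TypeError, ValueError):
        raise ConfigurationError(
            f"Expected cuda_device to be an int, got {cuda_device!r}."
        )
```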

* fix merge
DeNeutoy authored Dec 17, 2019
1 parent 41b5a4d commit ca453c8
Showing 22 changed files with 800 additions and 434 deletions.
12 changes: 5 additions & 7 deletions allennlp/commands/find_learning_rate.py
```diff
@@ -54,7 +54,7 @@
 from allennlp.commands.subcommand import Subcommand
 from allennlp.common.checks import ConfigurationError, check_for_gpu
 from allennlp.common import Params, Tqdm
-from allennlp.common.util import prepare_environment, lazy_groups_of
+from allennlp.common.util import prepare_environment
 from allennlp.data import Vocabulary, DataIterator
 from allennlp.models import Model
 from allennlp.training import Trainer
@@ -223,6 +223,7 @@ def find_learning_rate_model(
     train_data = all_datasets["train"]

     trainer_params = params.pop("trainer")
+
     no_grad_regexes = trainer_params.pop("no_grad", ())
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
@@ -296,10 +297,7 @@ def search_learning_rate(

     trainer.model.train()

-    num_gpus = len(trainer._cuda_devices)
-
-    raw_train_generator = trainer.iterator(trainer.train_data, shuffle=trainer.shuffle)
-    train_generator = lazy_groups_of(raw_train_generator, num_gpus)
+    train_generator = trainer.iterator(trainer.train_data, shuffle=trainer.shuffle)
     train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_batches)

     learning_rates = []
@@ -310,7 +308,7 @@
     else:
         lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

-    for i, batch_group in enumerate(train_generator_tqdm):
+    for i, batch in enumerate(train_generator_tqdm):

         if linear_steps:
             current_lr = start_lr + (lr_update_factor * i)
@@ -321,7 +319,7 @@
             param_group["lr"] = current_lr

         trainer.optimizer.zero_grad()
-        loss = trainer.batch_loss(batch_group, for_training=True)
+        loss = trainer.batch_loss(batch, for_training=True)
         loss.backward()
         loss = loss.detach().cpu().item()
```
2 changes: 1 addition & 1 deletion allennlp/commands/fine_tune.py
```diff
@@ -382,7 +382,7 @@ def fine_tune_model(
             model,
             test_data,
             validation_iterator or iterator,
-            cuda_device=trainer._cuda_devices[0],
+            cuda_device=trainer.cuda_device,
             batch_weight_key=batch_weight_key,
         )
```
