
Torch distributed #3529

Merged · 9 commits · Dec 17, 2019
12 changes: 5 additions & 7 deletions allennlp/commands/find_learning_rate.py
@@ -54,7 +54,7 @@
 from allennlp.commands.subcommand import Subcommand
 from allennlp.common.checks import ConfigurationError, check_for_gpu
 from allennlp.common import Params, Tqdm
-from allennlp.common.util import prepare_environment, lazy_groups_of
+from allennlp.common.util import prepare_environment
 from allennlp.data import Vocabulary, DataIterator
 from allennlp.models import Model
 from allennlp.training import Trainer
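Note: the dropped import mirrors the hunks below, where batches stop being grouped per GPU at these call sites. For context, a minimal sketch of what the lazy_groups_of helper in allennlp.common.util does (paraphrased from its documented behavior, not the exact source):

    from itertools import islice

    def lazy_groups_of(iterable, group_size):
        # Lazily yield successive lists of up to `group_size` items.
        iterator = iter(iterable)
        while True:
            group = list(islice(iterator, group_size))
            if not group:
                return
            yield group

Yielding n batches per step was how the old trainer fed one batch to each of n GPUs in a single training step.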
@@ -223,6 +223,7 @@ def find_learning_rate_model(
     train_data = all_datasets["train"]
 
     trainer_params = params.pop("trainer")
+
     no_grad_regexes = trainer_params.pop("no_grad", ())
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
@@ -296,10 +297,7 @@ def search_learning_rate(
 
     trainer.model.train()
 
-    num_gpus = len(trainer._cuda_devices)
-
-    raw_train_generator = trainer.iterator(trainer.train_data, shuffle=trainer.shuffle)
-    train_generator = lazy_groups_of(raw_train_generator, num_gpus)
+    train_generator = trainer.iterator(trainer.train_data, shuffle=trainer.shuffle)
     train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_batches)
 
     learning_rates = []
@@ -310,7 +308,7 @@
     else:
         lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)
 
-    for i, batch_group in enumerate(train_generator_tqdm):
+    for i, batch in enumerate(train_generator_tqdm):
 
         if linear_steps:
             current_lr = start_lr + (lr_update_factor * i)
@@ -321,7 +319,7 @@
             param_group["lr"] = current_lr
 
         trainer.optimizer.zero_grad()
-        loss = trainer.batch_loss(batch_group, for_training=True)
+        loss = trainer.batch_loss(batch, for_training=True)
         loss.backward()
         loss = loss.detach().cpu().item()
 
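Taken together, these hunks change the step contract of search_learning_rate: the loop now pulls one batch per step straight from the iterator, and batch_loss receives a single batch (a dict of tensors in AllenNLP's iterators) rather than a list of per-GPU batches; device placement becomes the trainer's responsibility. A hedged before/after sketch, using the names from this file:

    # Before: one list of batches per step, one entry per GPU.
    # for i, batch_group in enumerate(train_generator_tqdm):
    #     loss = trainer.batch_loss(batch_group, for_training=True)

    # After: one batch per step; the trainer handles devices internally.
    for i, batch in enumerate(train_generator_tqdm):
        loss = trainer.batch_loss(batch, for_training=True)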
2 changes: 1 addition & 1 deletion allennlp/commands/fine_tune.py
@@ -382,7 +382,7 @@ def fine_tune_model(
             model,
             test_data,
             validation_iterator or iterator,
-            cuda_device=trainer._cuda_devices[0],
+            cuda_device=trainer.cuda_device,
             batch_weight_key=batch_weight_key,
         )
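Here the evaluate call reads the public trainer.cuda_device attribute instead of indexing the private _cuda_devices list. This reflects the new model where a trainer owns exactly one device and multi-GPU training presumably runs one torch.distributed process per device. A short sketch of the contract change (illustrative only, not the full Trainer API):

    # Before: a private list of device ids; call sites took the first entry.
    device = trainer._cuda_devices[0]

    # After: a single public device id per trainer process (-1 means CPU).
    device = trainer.cuda_device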
