Merge pull request #6 from microsoft/dev/mallamanis/amp
Support AMP training in GNNs and other neural models.
Miltos authored Nov 10, 2020
2 parents e1c507f + 3bfe1f8 commit 0ba8308
Showing 6 changed files with 34 additions and 15 deletions.
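For context, the commit threads a single new `enable_amp` flag from the CLI scripts into the trainer. A minimal usage sketch of what this enables (assuming `ModelTrainer` is imported from `ptgnn.baseneuralmodel`, as the bundled implementation scripts do; the model and data names below are placeholders, and the exact `train()` signature should be checked against `trainer.py`):

```python
# Hypothetical usage sketch: turning on mixed-precision training via the new flag.
from ptgnn.baseneuralmodel import ModelTrainer

trainer = ModelTrainer(
    model,                 # an AbstractNeuralModel implementation (placeholder)
    model_path,            # where the trained model checkpoint is written
    max_num_epochs=100,
    minibatch_size=300,
    enable_amp=True,       # new in this commit; the CLIs pass arguments["--amp"] here
)
trainer.train(training_data, validation_data)  # data iterables are placeholders
```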
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -13,7 +13,7 @@ jobs:
max-parallel: 2
matrix:
python-version: [3.6, 3.7, 3.8]
- torch-version: [1.4.0, 1.5.0]
+ torch-version: [1.6.0, 1.7.0]

steps:
- uses: actions/checkout@v1
32 changes: 20 additions & 12 deletions ptgnn/baseneuralmodel/trainer.py
@@ -51,6 +51,7 @@ def __init__(
clip_gradient_norm: Optional[float] = None,
target_validation_metric: Optional[str] = None,
target_validation_metric_higher_is_better: bool = False,
enable_amp: bool = False,
):
"""
:param model: The Component to be built and trained
@@ -84,6 +85,7 @@ def __init__(
self.__train_epoch_end_hooks: List[EndOfEpochHook] = []
self.__validation_epoch_end_hooks: List[EndOfEpochHook] = []
self.__clip_gradient_norm = clip_gradient_norm
self.__enable_amp = enable_amp

self.__target_metric = target_validation_metric
if target_validation_metric is not None:
@@ -182,6 +184,8 @@ def _run_training(
sum_epoch_loss, running_avg_loss, num_minibatches, num_samples = 0.0, 0.0, 0, 0
start_time = time.time()
self.neural_module.train()

scaler = torch.cuda.amp.GradScaler(enabled=self.__enable_amp)
with tqdm(desc="Training", disable=not show_progress_bar, leave=False) as progress_bar:
for step_idx, (mb_data, raw_samples) in enumerate(
self.__model.minibatch_iterator(
@@ -194,20 +198,23 @@
)
):
optimizer.zero_grad()
- mb_loss = self.neural_module(**mb_data)
- mb_loss.backward()
+ with torch.cuda.amp.autocast(enabled=self.__enable_amp):
+ mb_loss = self.neural_module(**mb_data)
+ if torch.isnan(mb_loss):
+ raise Exception("Loss has a NaN value.")

- if self.__clip_gradient_norm is not None:
- torch.nn.utils.clip_grad_norm_(
- self.neural_module.parameters(recurse=True), self.__clip_gradient_norm
- )
+ scaler.scale(mb_loss).backward()

- if torch.isnan(mb_loss):
- raise Exception("Loss has a NaN value.")
+ if self.__clip_gradient_norm is not None:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(
+ self.neural_module.parameters(recurse=True), self.__clip_gradient_norm
+ )

- optimizer.step()
- if scheduler is not None:
- scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
+ scaler.step(optimizer)
+ scaler.update()
+ if scheduler is not None:
+ scheduler.step(epoch_idx=epoch, epoch_step=step_idx)

num_minibatches += 1
num_samples += len(raw_samples)
@@ -258,7 +265,8 @@ def _run_validation(
shuffle_input=False,
parallelize=parallelize,
):
- mb_loss = self.neural_module(**mb_data)
+ with torch.cuda.amp.autocast(enabled=self.__enable_amp):
+ mb_loss = self.neural_module(**mb_data)
num_minibatches += 1
num_samples += len(raw_samples)
sum_epoch_loss += float(mb_loss.cpu())
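Taken together, the trainer changes follow PyTorch's standard `torch.cuda.amp` recipe: run the forward pass under `autocast`, scale the loss before `backward()`, unscale before gradient clipping, and let the `GradScaler` drive the optimizer step (skipping it when inf/NaN gradients are detected). A self-contained sketch of that recipe, using generic model and optimizer names rather than ptgnn's own classes:

```python
import torch

def train_step(model, optimizer, scaler, batch, clip_norm=1.0, enable_amp=True):
    """One optimization step using PyTorch automatic mixed precision (AMP)."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_amp):
        loss = model(**batch)           # forward pass runs in float16/float32 mixed precision
    scaler.scale(loss).backward()       # scale the loss so small fp16 gradients do not underflow
    if clip_norm is not None:
        scaler.unscale_(optimizer)      # unscale first so clipping sees the true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
    scaler.step(optimizer)              # skips the update if inf/NaN gradients were detected
    scaler.update()                     # adapts the loss-scale factor for the next step
    return float(loss.detach())

# The scaler is created once per training run and reused across steps:
# scaler = torch.cuda.amp.GradScaler(enabled=True)
```

With `enable_amp=False` the scaler and autocast context become no-ops, so the same code path serves both full-precision and mixed-precision training, which is exactly how the trainer uses the flag.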
2 changes: 2 additions & 0 deletions ptgnn/implementations/graph2seq/train.py
@@ -5,6 +5,7 @@
Options:
--aml Run this in Azure ML
--amp Enable automatic mixed precision.
--azure-info=<path> Azure authentication information file (JSON). Used to load data from Azure storage.
--max-num-epochs=<epochs> The maximum number of epochs to run training for. [default: 100]
--minibatch-size=<size> The minibatch size. [default: 300]
@@ -109,6 +110,7 @@ def create_mp_layers(num_edges: int):
model_path,
max_num_epochs=int(arguments["--max-num-epochs"]),
minibatch_size=int(arguments["--minibatch-size"]),
enable_amp=arguments["--amp"],
)
if nn is not None:
trainer.neural_module = nn
1 change: 1 addition & 0 deletions ptgnn/implementations/graph2seq/trainandtest.py
@@ -5,6 +5,7 @@
Options:
--aml Run this in Azure ML
--amp Enable automatic mixed precision.
--azure-info=<path> Azure authentication information file (JSON). Used to load data from Azure storage.
--max-num-epochs=<epochs> The maximum number of epochs to run training for. [default: 100]
--minibatch-size=<size> The minibatch size. [default: 300]
2 changes: 2 additions & 0 deletions ptgnn/implementations/typilus/train.py
@@ -5,6 +5,7 @@
Options:
--aml Run this in Azure ML
--amp Enable automatic mixed precision.
--azure-info=<path> Azure authentication information file (JSON). Used to load data from Azure storage.
--max-num-epochs=<epochs> The maximum number of epochs to run training for. [default: 100]
--minibatch-size=<size> The minibatch size. [default: 300]
@@ -167,6 +168,7 @@ def create_optimizer(parameters):
clip_gradient_norm=1,
target_validation_metric="Accuracy",
target_validation_metric_higher_is_better=True,
enable_amp=arguments["--amp"],
)
if nn is not None:
trainer.neural_module = nn
10 changes: 8 additions & 2 deletions ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py
@@ -30,9 +30,15 @@ def _aggregate_messages(
self, messages: torch.Tensor, message_targets: torch.Tensor, num_nodes, aggregation_fn: str
):
"""Utility function to be used by concrete implementors."""
+ # Support AMP
+ msg_dtype = messages.dtype
return scatter(
- messages, index=message_targets, dim=0, dim_size=num_nodes, reduce=aggregation_fn
- )
+ messages.to(torch.float32),
+ index=message_targets,
+ dim=0,
+ dim_size=num_nodes,
+ reduce=aggregation_fn,
+ ).to(msg_dtype)

@property
@abstractmethod
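On this last hunk: the `scatter` aggregation from `torch_scatter` is forced to float32 under AMP, presumably because its custom kernels are not covered by autocast's casting rules and fp16 accumulation over many messages can lose precision or hit dtype mismatches. A rough illustration of the same dtype round-trip (tensor shapes and the `aggregate` helper are made up for the example):

```python
import torch
from torch_scatter import scatter

def aggregate(messages: torch.Tensor, targets: torch.Tensor, num_nodes: int) -> torch.Tensor:
    # Under autocast, messages may arrive as float16; do the reduction in float32.
    msg_dtype = messages.dtype
    out = scatter(messages.to(torch.float32), index=targets, dim=0,
                  dim_size=num_nodes, reduce="sum")
    return out.to(msg_dtype)  # cast back so downstream layers see the original dtype

messages = torch.randn(6, 4).half()          # 6 messages of width 4, in fp16
targets = torch.tensor([0, 0, 1, 2, 2, 2])   # destination node per message
print(aggregate(messages, targets, num_nodes=3).dtype)  # torch.float16
```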
