torch.cuda.amp.autocast() -> torch.amp.autocast("cuda") (#1921)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
qgallouedec authored Aug 12, 2024
1 parent 150a931 commit dbea3da
Showing 5 changed files with 37 additions and 33 deletions.
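Background: `torch.cuda.amp.autocast` is deprecated in recent PyTorch releases in favor of the device-agnostic `torch.amp.autocast(device_type)`, and the old entry point emits a FutureWarning. The commit also changes the idiom: instead of storing the autocast callable and invoking it at the `with` statement, the context-manager instance is built up front and entered directly, which is why `nullcontext` gains call parentheses in the new branches. A minimal before/after sketch (illustrative only, not part of the diff; `use_autocast` stands in for `self._peft_has_been_casted_to_bf16`):

from contextlib import nullcontext

import torch
import torch.amp as amp

use_autocast = True  # stand-in for `self._peft_has_been_casted_to_bf16`

# Before: the deprecated entry point was stored as a callable and invoked
# at the `with` statement, so both branches had to be callables.
ctx = torch.cuda.amp.autocast if use_autocast else nullcontext
with ctx():
    ...  # forward pass under (possible) autocast

# After: the device-agnostic API is instantiated once, so both branches
# are context-manager instances and the call parentheses move accordingly.
ctx = amp.autocast("cuda") if use_autocast else nullcontext()
with ctx:
    ...  # forward pass under (possible) autocast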
14 changes: 7 additions & 7 deletions trl/trainer/bco_trainer.py
@@ -25,6 +25,7 @@

import numpy as np
import torch
+import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
@@ -1214,9 +1215,9 @@ def compute_loss(
    "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
    "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
-compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with compute_loss_context_manager():
+with compute_loss_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs)

# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
@@ -1243,9 +1244,8 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
-generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
-
-with generate_context_manager():
+generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+with generate_context_manager:
    policy_output = model.generate(
        input_ids=batch["prompt_input_ids"],
        attention_mask=batch["prompt_attention_mask"],
@@ -1302,8 +1302,8 @@ def prediction_step(
else:
    ignore_keys = []

-prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
-with torch.no_grad(), prediction_context_manager():
+prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+with torch.no_grad(), prediction_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs)

# force log the metrics
13 changes: 7 additions & 6 deletions trl/trainer/cpo_trainer.py
@@ -23,6 +23,7 @@

import numpy as np
import torch
+import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
@@ -800,9 +801,9 @@ def compute_loss(
    "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)

-compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with compute_loss_context_manager():
+with compute_loss_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

# force log the metrics
@@ -817,9 +818,9 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
-generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
+generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with generate_context_manager():
+with generate_context_manager:
    policy_output = model.generate(
        input_ids=batch["prompt_input_ids"],
        attention_mask=batch["prompt_attention_mask"],
@@ -851,9 +852,9 @@ def prediction_step(
else:
    ignore_keys = []

-prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with torch.no_grad(), prediction_context_manager():
+with torch.no_grad(), prediction_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

# force log the metrics
17 changes: 9 additions & 8 deletions trl/trainer/dpo_trainer.py
@@ -23,6 +23,7 @@

import numpy as np
import torch
+import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
@@ -955,10 +956,10 @@ def null_ref_context(self):

def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
    """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
-    compte_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+    compte_ref_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    # compute reference logps
-    with torch.no_grad(), compte_ref_context_manager():
+    with torch.no_grad(), compte_ref_context_manager:
        if self.ref_model is None:
            with self.null_ref_context():
                (
@@ -1416,9 +1417,9 @@ def compute_loss(
    "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)

-compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with compute_loss_context_manager():
+with compute_loss_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
@@ -1435,9 +1436,9 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
-generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
+generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with generate_context_manager():
+with generate_context_manager:
    policy_output = model.generate(
        input_ids=batch["prompt_input_ids"],
        attention_mask=batch["prompt_attention_mask"],
@@ -1494,9 +1495,9 @@ def prediction_step(
else:
    ignore_keys = []

-prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with torch.no_grad(), prediction_context_manager():
+with torch.no_grad(), prediction_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

# force log the metrics
13 changes: 7 additions & 6 deletions trl/trainer/kto_trainer.py
@@ -24,6 +24,7 @@

import numpy as np
import torch
+import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
@@ -1170,9 +1171,9 @@ def compute_loss(
    "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
    "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
-compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with compute_loss_context_manager():
+with compute_loss_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs)

# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
@@ -1199,9 +1200,9 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
-generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
+generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with generate_context_manager():
+with generate_context_manager:
    policy_output = model.generate(
        input_ids=batch["prompt_input_ids"],
        attention_mask=batch["prompt_attention_mask"],
@@ -1258,8 +1259,8 @@ def prediction_step(
else:
    ignore_keys = []

-prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
-with torch.no_grad(), prediction_context_manager():
+prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+with torch.no_grad(), prediction_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs)

# force log the metrics
13 changes: 7 additions & 6 deletions trl/trainer/orpo_trainer.py
@@ -25,6 +25,7 @@

import numpy as np
import torch
+import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
@@ -802,9 +803,9 @@ def compute_loss(
    "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)

-compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with compute_loss_context_manager():
+with compute_loss_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

# force log the metrics
@@ -819,9 +820,9 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
-generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
+generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with generate_context_manager():
+with generate_context_manager:
    policy_output = model.generate(
        input_ids=batch["prompt_input_ids"],
        attention_mask=batch["prompt_attention_mask"],
@@ -853,9 +854,9 @@ def prediction_step(
else:
    ignore_keys = []

-prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

-with torch.no_grad(), prediction_context_manager():
+with torch.no_grad(), prediction_context_manager:
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

# force log the metrics
