Include bf16 support for TPUs and CPUs, and a better check for if a CUDA device supports BF16 #462
```diff
@@ -21,7 +21,14 @@
 from accelerate.data_loader import prepare_data_loader
 from accelerate.state import AcceleratorState
 from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
-from accelerate.utils import DistributedType, gather, is_torch_version, set_seed, synchronize_rng_states
+from accelerate.utils import (
+    DistributedType,
+    gather,
+    is_bf16_available,
+    is_torch_version,
+    set_seed,
+    synchronize_rng_states,
+)


 def init_state_check():
```
```diff
@@ -245,74 +252,77 @@ def training_check():

     accelerator.print("Training yielded the same results on one CPU or distributes setup with batch split.")

-    # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
-    print("FP16 training check.")
-    AcceleratorState._reset_state()
-    accelerator = Accelerator(mixed_precision="fp16")
-    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
-    model = RegressionModel()
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
-
-    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
-    set_seed(42)
-    generator.manual_seed(42)
-    for _ in range(3):
-        for batch in train_dl:
-            model.zero_grad()
-            output = model(batch["x"])
-            loss = torch.nn.functional.mse_loss(output, batch["y"])
-            accelerator.backward(loss)
-            optimizer.step()
-
-    model = accelerator.unwrap_model(model).cpu()
-    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
-    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."
-
-    # TEST that previous fp16 flag still works
-    print("Legacy FP16 training check.")
-    AcceleratorState._reset_state()
-    accelerator = Accelerator(fp16=True)
-    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
-    model = RegressionModel()
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
-
-    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
-    set_seed(42)
-    generator.manual_seed(42)
-    for _ in range(3):
-        for batch in train_dl:
-            model.zero_grad()
-            output = model(batch["x"])
-            loss = torch.nn.functional.mse_loss(output, batch["y"])
-            accelerator.backward(loss)
-            optimizer.step()
-
-    model = accelerator.unwrap_model(model).cpu()
-    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
-    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."
-
-    # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
-    print("BF16 training check.")
-    AcceleratorState._reset_state()
-    accelerator = Accelerator(mixed_precision="bf16")
-    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
-    model = RegressionModel()
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
-
-    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
-    set_seed(42)
-    generator.manual_seed(42)
-    for _ in range(3):
-        for batch in train_dl:
-            model.zero_grad()
-            output = model(batch["x"])
-            loss = torch.nn.functional.mse_loss(output, batch["y"])
-            accelerator.backward(loss)
-            optimizer.step()
-
-    model = accelerator.unwrap_model(model).cpu()
-    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
-    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."
+    if torch.cuda.is_available():
+        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
+        print("FP16 training check.")
+        AcceleratorState._reset_state()
+        accelerator = Accelerator(mixed_precision="fp16")
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+        model = RegressionModel()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
+        set_seed(42)
+        generator.manual_seed(42)
+        for _ in range(3):
+            for batch in train_dl:
+                model.zero_grad()
+                output = model(batch["x"])
+                loss = torch.nn.functional.mse_loss(output, batch["y"])
+                accelerator.backward(loss)
+                optimizer.step()
+
+        model = accelerator.unwrap_model(model).cpu()
+        assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
+        assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."
+
+        # TEST that previous fp16 flag still works
+        print("Legacy FP16 training check.")
```
Comment on lines +279 to +280:

> Not sure we need to keep this. We have done a couple of releases since we deprecated it, so it's okay if we stop testing it, IMO.

> I'd feel more comfortable dropping the test once we've entirely removed the legacy param (whenever that may be).
```diff
+        AcceleratorState._reset_state()
+        accelerator = Accelerator(fp16=True)
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+        model = RegressionModel()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
+        set_seed(42)
+        generator.manual_seed(42)
+        for _ in range(3):
+            for batch in train_dl:
+                model.zero_grad()
+                output = model(batch["x"])
+                loss = torch.nn.functional.mse_loss(output, batch["y"])
+                accelerator.backward(loss)
+                optimizer.step()
+
+        model = accelerator.unwrap_model(model).cpu()
+        assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
+        assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."
+
+    # BF16 support is only for CPU + TPU, and some GPU
+    if is_bf16_available():
+        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
+        print("BF16 training check.")
+        AcceleratorState._reset_state()
+        accelerator = Accelerator(mixed_precision="bf16")
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+        model = RegressionModel()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
+        set_seed(42)
+        generator.manual_seed(42)
+        for _ in range(3):
+            for batch in train_dl:
+                model.zero_grad()
+                output = model(batch["x"])
+                loss = torch.nn.functional.mse_loss(output, batch["y"])
+                accelerator.backward(loss)
+                optimizer.step()
+
+        model = accelerator.unwrap_model(model).cpu()
+        assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
+        assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."


 def main():
```
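For context on the new `is_bf16_available()` gate, here is a minimal sketch of what such a check can look like, covering the three backends called out in the inline comment (CPU, TPU via `torch_xla`, and bf16-capable GPUs). The real helper lives in `accelerate.utils` and may be organized differently; the function name below and its structure are illustrative assumptions, not the shipped implementation.

```python
# Minimal sketch of a bf16 availability check (illustrative only; the real
# helper in accelerate.utils may differ).
import importlib.util

import torch


def bf16_available_sketch() -> bool:
    """Best-effort check that bf16 mixed precision can run on this setup."""
    # TPUs: torch_xla computes in bfloat16 natively, so bf16 is fine there.
    if importlib.util.find_spec("torch_xla") is not None:
        return True
    # GPUs: only some devices (Ampere and newer) support bf16; PyTorch
    # exposes that through torch.cuda.is_bf16_supported().
    if torch.cuda.is_available():
        return torch.cuda.is_bf16_supported()
    # CPUs: CPU autocast defaults to bfloat16, so plain CPU runs are fine.
    return True
```

With a gate like this, the BF16 training check only runs where bf16 autocast is actually usable, while the FP16 checks stay behind `torch.cuda.is_available()` as in the diff above.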
Comment on the `is_bf16_available()` check:

> Need to be extra sure that this always exists for the PyTorch versions on which `is_bf16_available()` is called.

> Fixed by adding a torch check for >= 1.10.
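The fix described in the reply can be sketched as a version guard in front of the CUDA probe: `torch.cuda.is_bf16_supported()` was only added in PyTorch 1.10, so older versions should simply report bf16 as unavailable on CUDA. The snippet below is an assumption about how such a guard could look; it reuses `is_torch_version` from `accelerate.utils`, which is already imported in the diff.

```python
# Sketch of the ">= 1.10" guard; illustrative, not the exact shipped code.
import torch

from accelerate.utils import is_torch_version


def cuda_bf16_supported_sketch() -> bool:
    if not torch.cuda.is_available():
        return False
    # torch.cuda.is_bf16_supported() does not exist before PyTorch 1.10,
    # so guard the call instead of letting it raise AttributeError.
    if not is_torch_version(">=", "1.10"):
        return False
    return torch.cuda.is_bf16_supported()
```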