Include bf16 support for TPUs and CPUs, and a better check for if a CUDA device supports BF16 #462

Merged
merged 9 commits on Jun 22, 2022
Changes from 4 commits
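
For context, a minimal usage sketch (not part of this diff) of what the PR enables: requesting bf16 mixed precision on CPU or TPU as well as on CUDA devices that support it.

import torch
from accelerate import Accelerator

# Raises a ValueError if bf16 mixed precision is not available on this setup.
accelerator = Accelerator(mixed_precision="bf16")

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(8, 4, device=accelerator.device)
with accelerator.autocast():
    loss = model(x).pow(2).mean()
accelerator.backward(loss)
optimizer.step()
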
28 changes: 18 additions & 10 deletions src/accelerate/accelerator.py
@@ -46,6 +46,7 @@
extract_model_from_parallel,
gather,
get_pretty_name,
is_bf16_available,
is_deepspeed_available,
is_torch_version,
is_tpu_available,
@@ -242,20 +243,24 @@ def __init__(
# Mixed precision attributes
self.scaler = None
self.native_amp = False
err = "{mode} mixed precision requires {requirement}"
if self.state.mixed_precision == "fp16":
self.native_amp = is_torch_version(">=", "1.6")
if not self.native_amp:
raise ValueError("fp16 mixed precision requires PyTorch >= 1.6")

raise ValueError(err.format(mode="fp16", requirement="PyTorch >= 1.6"))
if not torch.cuda.is_available():
raise ValueError(err.format(mode="fp16", requirement="a GPU"))
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
self.scaler = torch.cuda.amp.GradScaler(**kwargs)
elif self.state.mixed_precision == "bf16":
self.native_amp = is_torch_version(">=", "1.10")
self.native_amp = is_bf16_available(True)
if mixed_precision == "bf16" and not self.native_amp:
raise ValueError("bf16 mixed precision requires PyTorch >= 1.10")
raise ValueError(err.format(mode="bf16", requirement="PyTorch >= 1.10 and a supported device."))

kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
self.scaler = torch.cuda.amp.GradScaler(**kwargs)
# Only on the GPU do we care about scaling the gradients
if torch.cuda.is_available():
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
self.scaler = torch.cuda.amp.GradScaler(**kwargs)

# Internal references to the training objects
self._optimizers = []
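
A note on the hunk above: torch.cuda.amp.GradScaler only applies when training on CUDA, which is why its construction is now guarded. A minimal sketch of that pattern in isolation:

import torch

scaler = None
if torch.cuda.is_available():
    # GradScaler lives in torch.cuda.amp and is only meaningful on CUDA devices;
    # on CPU and TPU the scaler is simply left as None.
    scaler = torch.cuda.amp.GradScaler()
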
@@ -528,8 +533,9 @@ def prepare_model(self, model):
if self.native_amp:
if self.mixed_precision == "fp16" and is_torch_version(">=", "1.10"):
model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)
elif self.mixed_precision == "bf16":
model.forward = torch.cuda.amp.autocast(dtype=torch.bfloat16)(model.forward)
elif self.mixed_precision == "bf16" and self.distributed_type != DistributedType.TPU:
device_type = "cuda" if torch.cuda.is_available() else "cpu"
model.forward = torch.autocast(device_type=device_type, dtype=torch.bfloat16)(model.forward)
else:
model.forward = torch.cuda.amp.autocast()(model.forward)
model.forward = convert_outputs_to_fp32(model.forward)
@@ -1054,8 +1060,10 @@ def autocast(self):
if self.native_amp:
if self.mixed_precision == "fp16" and is_torch_version(">=", "1.10"):
autocast_context = torch.cuda.amp.autocast(dtype=torch.float16)
elif self.mixed_precision == "bf16":
autocast_context = torch.cuda.amp.autocast(dtype=torch.bfloat16)
elif self.mixed_precision == "bf16" and is_bf16_available():
if self.distributed_type in [DistributedType.NO, DistributedType.MULTI_CPU, DistributedType.MULTI_GPU]:
device_type = "cpu" if not torch.cuda.is_available() else "cuda"
autocast_context = torch.autocast(dtype=torch.bfloat16, device_type=device_type)

Collaborator: Need to be extra sure that this (torch.autocast) always exists for every PyTorch version for which is_bf16_available() returns True.

Collaborator (author): Fixed by adding a torch check for >= 1.10.

else:
autocast_context = torch.cuda.amp.autocast()

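For reference, a minimal sketch (not taken from this diff) of the device-aware bf16 autocast pattern that prepare_model and autocast above now use, assuming PyTorch >= 1.10 where torch.autocast accepts a device_type argument:

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    x = torch.randn(16, 16, device=device_type)
    y = x @ x
print(y.dtype)  # torch.bfloat16 on hardware that supports it
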
18 changes: 9 additions & 9 deletions src/accelerate/launchers.py
@@ -50,6 +50,13 @@ def notebook_launcher(function, args=(), num_processes=None, use_fp16=False, mix
else:
in_colab_or_kaggle = False

try:
mixed_precision = PrecisionType(mixed_precision.lower())
except ValueError:
raise ValueError(
f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
)

if in_colab_or_kaggle:
if os.environ.get("TPU_NAME", None) is not None:
# TPU launch
@@ -65,14 +72,14 @@ def notebook_launcher(function, args=(), num_processes=None, use_fp16=False, mix
num_processes = 8

launcher = PrepareForLaunch(function, distributed_type="TPU")
print(f"Launching a training on {num_processes} TPU cores.")
print(f"Launching a training on {num_processes} TPU cores")
xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")
else:
# No need for a distributed launch otherwise as it's either CPU or one GPU.
if torch.cuda.is_available():
print("Launching training on one GPU.")
else:
print("Launching training on CPU.")
print("Launching training on one CPU.")
function(*args)

else:
@@ -105,13 +112,6 @@ def notebook_launcher(function, args=(), num_processes=None, use_fp16=False, mix
"function."
)

try:
mixed_precision = PrecisionType(mixed_precision.lower())
except ValueError:
raise ValueError(
f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
)

if use_fp16:
warnings.warn('use_fp16=True is deprecated. Use mixed_precision="fp16" instead.', DeprecationWarning)
mixed_precision = "fp16"
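
The launcher change above moves the mixed_precision validation ahead of the Colab/Kaggle branch so it also runs for TPU notebook launches. A simplified stand-in for accelerate's PrecisionType enum (the validate_precision helper is hypothetical), sketching the same validate-early pattern:

from enum import Enum

class PrecisionType(str, Enum):
    # Simplified stand-in for accelerate.utils.PrecisionType
    NO = "no"
    FP16 = "fp16"
    BF16 = "bf16"

    @classmethod
    def list(cls):
        return [member.value for member in cls]

def validate_precision(mixed_precision: str) -> PrecisionType:
    try:
        return PrecisionType(mixed_precision.lower())
    except ValueError:
        raise ValueError(
            f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}."
        )

print(validate_precision("BF16"))  # PrecisionType.BF16
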
18 changes: 8 additions & 10 deletions src/accelerate/state.py
@@ -92,9 +92,11 @@ def __init__(
self.process_index = xm.get_ordinal()
self.local_process_index = xm.get_local_ordinal()
self.device = xm.xla_device()
self.mixed_precision = (
parse_choice_from_env("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
)
if mixed_precision == "bf16":
os.environ["XLA_USE_BF16"] = 1
else:
os.environ["XLA_USE_BF16"] = 0
self.mixed_precision = mixed_precision
elif os.environ.get("USE_DEEPSPEED", "false") == "true" and not cpu:
assert (
is_deepspeed_available()
@@ -120,9 +122,7 @@ def __init__(
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
self.device = torch.device("cuda", self.local_process_index)
torch.cuda.set_device(self.device)
self.mixed_precision = (
parse_choice_from_env("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
)
self.mixed_precision = mixed_precision
if os.environ.get("USE_FSDP", "false") == "true":
self.distributed_type = DistributedType.FSDP
if self.mixed_precision != "no":
@@ -166,15 +166,13 @@ def __init__(
self.process_index = torch.distributed.get_rank()
self.local_process_index = local_rank
self.device = torch.device("cpu")
self.mixed_precision = "no"
self.mixed_precision = mixed_precision
else:
self.distributed_type = DistributedType.NO
self.num_processes = 1
self.process_index = self.local_process_index = 0
self.device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
self.mixed_precision = (
parse_choice_from_env("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
)
self.mixed_precision = mixed_precision
self.initialized = True

def __repr__(self):
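
One detail on the TPU branch above: os.environ values must be strings, so the XLA flag would typically be set with an explicit conversion. A minimal sketch under that assumption (the configure_tpu_bf16 helper is hypothetical):

import os

def configure_tpu_bf16(mixed_precision: str) -> None:
    # XLA reads XLA_USE_BF16 to decide whether tensors run in bfloat16 on TPU.
    # os.environ only accepts str values, hence the explicit conversion.
    os.environ["XLA_USE_BF16"] = str(1) if mixed_precision == "bf16" else str(0)

configure_tpu_bf16("bf16")
print(os.environ["XLA_USE_BF16"])  # "1"
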
148 changes: 79 additions & 69 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -21,7 +21,14 @@
from accelerate.data_loader import prepare_data_loader
from accelerate.state import AcceleratorState
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
from accelerate.utils import DistributedType, gather, is_torch_version, set_seed, synchronize_rng_states
from accelerate.utils import (
DistributedType,
gather,
is_bf16_available,
is_torch_version,
set_seed,
synchronize_rng_states,
)


def init_state_check():
@@ -245,74 +252,77 @@ def training_check():

accelerator.print("Training yielded the same results on one CPU or distributes setup with batch split.")

# Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
print("FP16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="fp16")
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()

model = accelerator.unwrap_model(model).cpu()
assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

# TEST that previous fp16 flag still works
print("Legacy FP16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(fp16=True)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()

model = accelerator.unwrap_model(model).cpu()
assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

# Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
print("BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16")
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()

model = accelerator.unwrap_model(model).cpu()
assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."
if torch.cuda.is_available():
# Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
print("FP16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="fp16")
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()

model = accelerator.unwrap_model(model).cpu()
assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

# TEST that previous fp16 flag still works
print("Legacy FP16 training check.")

Comment on lines +279 to +280

Collaborator: Not sure we need to keep this. We have done a couple of releases since we deprecated it, so it's okay if we stop testing it, IMO.

Collaborator (author): I'd feel more comfortable dropping the test once we've removed the legacy param entirely (whenever that may be).

AcceleratorState._reset_state()
accelerator = Accelerator(fp16=True)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()

model = accelerator.unwrap_model(model).cpu()
assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

# BF16 support is only for CPU + TPU, and some GPU
if is_bf16_available():
# Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
print("BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16")
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
accelerator.backward(loss)
optimizer.step()

model = accelerator.unwrap_model(model).cpu()
assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."


def main():
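
The checks above construct Accelerator several times in one process; they rely on AcceleratorState being a singleton that has to be cleared between configurations. A minimal sketch of that reset-then-reconfigure pattern, gated the same way as the new bf16 check:

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import is_bf16_available

accelerator = Accelerator()  # default full-precision run
# ... first check runs here ...

AcceleratorState._reset_state()  # clear the cached singleton before reconfiguring
if is_bf16_available():
    accelerator = Accelerator(mixed_precision="bf16")
    # ... bf16 check runs here ...
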
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -20,6 +20,7 @@
)
from .imports import (
is_apex_available,
is_bf16_available,
is_boto3_available,
is_ccl_available,
is_comet_ml_available,
13 changes: 13 additions & 0 deletions src/accelerate/utils/imports.py
@@ -15,6 +15,10 @@
import importlib
import sys

import torch

from .versions import is_torch_version


# The package importlib_metadata is in a different place, depending on the Python version.
if sys.version_info < (3, 8):
@@ -68,6 +72,15 @@ def is_deepspeed_available():
return False


def is_bf16_available(ignore_tpu=True):
"Checks if bf16 is supported, optionally ignoring the TPU"
if torch.cuda.is_available():
return torch.cuda.is_bf16_supported()
elif is_tpu_available():
return ignore_tpu
return is_torch_version(">=", "1.10")


def is_transformers_available():
return importlib.util.find_spec("transformers") is not None

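A short usage sketch of the new helper; as written above it defers to torch.cuda.is_bf16_supported() on CUDA, returns the ignore_tpu argument on TPU, and otherwise requires PyTorch >= 1.10:

from accelerate.utils import is_bf16_available

if is_bf16_available():
    precision = "bf16"
else:
    # e.g. a pre-Ampere GPU, or a CPU build of PyTorch older than 1.10
    precision = "no"
print(f"Using mixed_precision={precision}")
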
12 changes: 6 additions & 6 deletions tests/test_hooks.py
@@ -77,20 +77,20 @@ def test_pre_forward_hook_is_executed(self):
test_hook = PreForwardHook()
add_hook_to_module(test_model, test_hook)
output1 = test_model(x)
self.assertTrue(torch.allclose(output1, expected))
self.assertTrue(torch.allclose(output1, expected, atol=1e-5))

# Attaching a hook to a model when it already has one replaces, does not chain
test_hook = PreForwardHook()
add_hook_to_module(test_model, test_hook)
output1 = test_model(x)
self.assertTrue(torch.allclose(output1, expected))
self.assertTrue(torch.allclose(output1, expected, atol=1e-5))

# You need to use the sequential hook to chain two or more hooks
test_hook = SequentialHook(PreForwardHook(), PreForwardHook())
add_hook_to_module(test_model, test_hook)

output2 = test_model(x)
assert torch.allclose(output2, expected2)
assert torch.allclose(output2, expected2, atol=1e-5)

def test_post_forward_hook_is_executed(self):
test_model = ModelForTest()
@@ -100,20 +100,20 @@ def test_post_forward_hook_is_executed(self):
test_hook = PostForwardHook()
add_hook_to_module(test_model, test_hook)
output1 = test_model(x)
self.assertTrue(torch.allclose(output1, output + 1))
self.assertTrue(torch.allclose(output1, output + 1, atol=1e-5))

# Attaching a hook to a model when it already has one replaces, does not chain
test_hook = PostForwardHook()
add_hook_to_module(test_model, test_hook)
output1 = test_model(x)
self.assertTrue(torch.allclose(output1, output + 1))
self.assertTrue(torch.allclose(output1, output + 1, atol=1e-5))

# You need to use the sequential hook to chain two or more hooks
test_hook = SequentialHook(PostForwardHook(), PostForwardHook())
add_hook_to_module(test_model, test_hook)

output2 = test_model(x)
assert torch.allclose(output2, output + 2)
assert torch.allclose(output2, output + 2, atol=1e-5)

def test_no_grad_in_hook(self):
test_model = ModelForTest()
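
The loosened tolerance above absorbs small floating-point drift in the hook outputs; a quick illustration of how the extra atol changes torch.allclose:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1.5e-5  # tiny numerical drift, e.g. from running part of a model in lower precision

print(torch.allclose(a, b))             # False: default atol=1e-8 / rtol=1e-5 is too strict near 1.0
print(torch.allclose(a, b, atol=1e-5))  # True: the extra absolute tolerance absorbs the drift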