Add fp32 support for QLoRA #595

Merged (16 commits) on Apr 2, 2024
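For context: with this change, QLoRA can run with the model's default dtype set to fp32 as well as bf16. The NF4-quantized base weights are dequantized to whatever dtype the activations use, so outputs and gradients stay in that dtype. A minimal usage sketch of what the PR enables, assuming LoRALinear's rank/alpha keyword arguments (the LoRA hyperparameters, dimensions, and rank/alpha values below are illustrative, not taken from this diff):

import torch
from torchtune import utils
from torchtune.modules.peft import LoRALinear

# Build a QLoRA linear under an fp32 default dtype; before this PR the
# quantize_base=True path was only exercised with bf16.
with utils.set_default_dtype(torch.float32):
    qlora_linear = LoRALinear(
        in_dim=512,          # illustrative sizes
        out_dim=512,
        rank=8,              # illustrative LoRA rank
        alpha=16,            # illustrative LoRA scaling
        quantize_base=True,  # base weight stored as NF4; LoRA params stay fp32
    )

x = torch.randn(2, 16, 512)      # fp32 activations
y = qlora_linear(x)
assert y.dtype == torch.float32  # compute and output stay in fp32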
requirements.txt: 2 changes (1 addition, 1 deletion)
@@ -8,4 +8,4 @@ tqdm
omegaconf

# Quantization
torchao-nightly==2024.3.25
torchao-nightly==2024.3.29
tests/recipes/test_lora_finetune_single_device.py: 20 changes (11 additions, 9 deletions)
@@ -43,11 +43,13 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32"):
"log_every_n_steps=1",
]

def _fetch_expected_loss_values(self, run_qlora: bool = False):
if run_qlora:
def _fetch_expected_loss_values(self):
return [10.5074, 10.5614, 10.5205, 10.4918]

def _fetch_qlora_expected_loss_values(self, dtype):
if dtype == "bf16":
return [10.5057, 10.5575, 10.5179, 10.4898]
else:
return [10.5074, 10.5614, 10.5205, 10.4918]
return [10.5059, 10.5571, 10.5181, 10.4897]

@pytest.mark.integration_test
def test_loss(self, tmpdir, monkeypatch):
@@ -82,13 +82,14 @@ def test_loss(self, tmpdir, monkeypatch):
runpy.run_path(TUNE_PATH, run_name="__main__")

loss_values = get_loss_values_from_metric_logger(log_file)
expected_loss_values = self._fetch_expected_loss_values(run_qlora=False)
expected_loss_values = self._fetch_expected_loss_values()
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
)

@pytest.mark.integration_test
def test_loss_qlora(self, tmpdir, monkeypatch):
@pytest.mark.parametrize("dtype", ["fp32", "bf16"])
def test_loss_qlora(self, dtype, tmpdir, monkeypatch):
ckpt = "small_test_ckpt_meta"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
@@ -114,14 +114,13 @@ def test_loss_qlora(self, tmpdir, monkeypatch):
lora_alpha=16,
)

# TODO (rohan-varma): QLoRA only supported with bf16 for now
cmd = cmd + self._get_test_config_overrides(dtype_str="bf16") + model_config
cmd = cmd + self._get_test_config_overrides(dtype_str=dtype) + model_config
monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit):
runpy.run_path(TUNE_PATH, run_name="__main__")

loss_values = get_loss_values_from_metric_logger(log_file)
expected_loss_values = self._fetch_expected_loss_values(run_qlora=True)
expected_loss_values = self._fetch_qlora_expected_loss_values(dtype=dtype)
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)
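The helpers above branch on the dtype string to pick reference losses; an equivalent way to organize the same data (a sketch only, not part of this PR) is a single mapping keyed by dtype, which keeps the parametrized values side by side:

# Hypothetical layout: the reference losses from the test above, keyed by dtype.
QLORA_EXPECTED_LOSSES = {
    "bf16": [10.5057, 10.5575, 10.5179, 10.4898],
    "fp32": [10.5059, 10.5571, 10.5181, 10.4897],
}

def fetch_qlora_expected_loss_values(dtype: str) -> list[float]:
    return QLORA_EXPECTED_LOSSES[dtype]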
tests/torchtune/models/test_lora_llama2.py: 39 changes (21 additions, 18 deletions)
@@ -203,16 +203,17 @@ def test_lora_linear_quantize_base(self):
if isinstance(module, LoRALinear):
assert module._quantize_base

def test_qlora_llama2_parity(self, inputs):
with utils.set_default_dtype(torch.bfloat16):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_qlora_llama2_parity(self, dtype, inputs):
with utils.set_default_dtype(dtype):
model_ref = self.get_lora_llama2(
lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
vocab_size=50,
quantize_base=False,
embed_dim=512,
dtype=torch.bfloat16,
dtype=dtype,
)
qlora = self.get_lora_llama2(
lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
@@ -221,7 +222,7 @@ def test_qlora_llama2_parity(self, inputs):
vocab_size=50,
quantize_base=True,
embed_dim=512,
dtype=torch.bfloat16,
dtype=dtype,
)
qlora_sd = qlora.state_dict()
model_ref.load_state_dict(qlora_sd)
@@ -232,20 +233,21 @@ def test_qlora_llama2_parity(self, inputs):
output = qlora(inputs)
torch.testing.assert_close(ref_output, output)

def test_qlora_llama2_state_dict(self):
with utils.set_default_dtype(torch.bfloat16):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_qlora_llama2_state_dict(self, dtype):
with utils.set_default_dtype(dtype):
model_ref = self.get_lora_llama2(
lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
vocab_size=50,
quantize_base=False,
embed_dim=512,
dtype=torch.bfloat16,
dtype=dtype,
)
bf16_sd = model_ref.state_dict()
for v in bf16_sd.values():
assert v.dtype == torch.bfloat16
high_prec_sd = model_ref.state_dict()
for v in high_prec_sd.values():
assert v.dtype == dtype

# ensure quantized LoRA can load a bf16/fp32 state_dict
qlora = self.get_lora_llama2(
@@ -255,28 +257,29 @@ def test_qlora_llama2_state_dict(self):
vocab_size=50,
quantize_base=True,
embed_dim=512,
dtype=torch.bfloat16,
dtype=dtype,
)
qlora.load_state_dict(bf16_sd)
qlora.load_state_dict(high_prec_sd)
# LoRALinear base weights should be nf4 still
for module in qlora.modules():
if isinstance(module, LoRALinear):
assert isinstance(module.weight, NF4Tensor)
# saved state_dict should have weights in the original (bf16/fp32) dtype.
qlora_sd = qlora.state_dict()
for v in qlora_sd.values():
assert v.dtype == torch.bfloat16
assert v.dtype == dtype

def test_qlora_llama2_merged_state_dict(self):
with utils.set_default_dtype(torch.bfloat16):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_qlora_llama2_merged_state_dict(self, dtype):
with utils.set_default_dtype(dtype):
qlora = self.get_lora_llama2(
lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
vocab_size=50,
quantize_base=True,
embed_dim=512,
dtype=torch.bfloat16,
dtype=dtype,
reset_norm=False, # to ensure norm.scale key exists
)

@@ -286,10 +289,10 @@ def test_qlora_llama2_merged_state_dict(self):
for v in merged_ckpt.values():
# paranoid check for both, as NF4Tensor had an issue where NF4Tensor.dtype would return bf16
assert not isinstance(v, NF4Tensor)
assert v.dtype == torch.bfloat16
assert v.dtype == dtype

# Ensure checkpoint can be loaded into non-LoRA model
with utils.set_default_dtype(torch.bfloat16):
with utils.set_default_dtype(dtype):
llama2 = self.get_ref_llama2(vocab_size=50, embed_dim=512)

llama2.load_state_dict(merged_ckpt)
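The merged-state-dict test checks that after merging LoRA weights into the base model, no NF4 tensors remain and every value is in the original dtype, so a plain (non-LoRA) Llama2 can load the checkpoint. The merge itself follows the standard LoRA update; a self-contained sketch of the arithmetic, independent of torchtune's actual merge utility:

import torch

def merge_lora_weight(base_weight, lora_a, lora_b, alpha, rank):
    # Standard LoRA merge: W' = W + (alpha / rank) * (B @ A).
    # For QLoRA, base_weight must first be dequantized from NF4 into a
    # regular dtype (bf16 or fp32) before the update is added.
    return base_weight + (alpha / rank) * (lora_b @ lora_a)

# Toy example in fp32: B is initialized to zero, so merging is a no-op.
out_dim, in_dim, rank, alpha = 8, 8, 2, 16.0
w = torch.randn(out_dim, in_dim)
a = torch.randn(rank, in_dim)   # LoRA "A" projection (rank x in_dim)
b = torch.zeros(out_dim, rank)  # LoRA "B" projection (out_dim x rank)
assert torch.allclose(merge_lora_weight(w, a, b, alpha, rank), w)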
tests/torchtune/modules/low_precision/test_nf4_linear.py: 55 changes (26 additions, 29 deletions)
@@ -48,57 +48,53 @@ def test_bias_unsupported(self):
with pytest.raises(RuntimeError, match="does not currently support biases"):
_ = FrozenNF4Linear(1, 1, bias=True)

def test_non_bf16_unsupported(self):
with pytest.raises(RuntimeError, match="only supported with bf16"):
_ = FrozenNF4Linear(1, 1, dtype=torch.float32)

def test_parameters(self):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_parameters(self, dtype):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
params = list(nf4_linear.parameters())
assert len(params) == 1
assert isinstance(params[0], NF4Tensor)

def test_state_dict(self):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_state_dict(self, dtype):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
state_dict = nf4_linear.state_dict()
assert len(state_dict) == 1
assert isinstance(state_dict["weight"], NF4Tensor)

def test_frozen_nf4_linear(self):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
assert isinstance(nf4_linear.weight, NF4Tensor)
assert torch.bfloat16 == nf4_linear.weight.get_original_weight().dtype

def test_output_bf16(self):
# Test to ensure W4 A16 produces A16
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_output_dtype(self, dtype):
# Test to ensure W4A16 produces A16 and W4A32 produces A32
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
out = nf4_linear(inp)
assert out.dtype == torch.bfloat16
assert out.dtype == dtype

def test_backward_bf16(self):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_backward_dtype(self, dtype):
# Test to ensure the backward pass gives the activation a gradient in the parametrized dtype and no gradient
# to the linear's weight, as it is frozen.
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
nf4_linear(inp).sum().backward()
assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
assert inp.grad is not None and inp.grad.dtype == dtype
assert nf4_linear.weight.grad is None

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_nf4_reconstruction_vs_bnb(self):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_nf4_reconstruction_vs_bnb(self, dtype):
"""
Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
reconstructing the respective original weights.
"""
dim = 512
nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=torch.bfloat16)
nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)

# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
bnb_reconstruction = bnb_nf4_linear(
torch.eye(dim, dim, dtype=torch.bfloat16, device="cuda")
torch.eye(dim, dim, dtype=dtype, device="cuda")
)
# Ensure nf4_linear and bnb reconstructions are close to each other.
diff = (
@@ -107,18 +103,19 @@ def test_nf4_reconstruction_vs_bnb(self):
assert diff.item() < 1e-2

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_nf4_bnb_linear(self):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_nf4_bnb_linear(self, dtype):
"""
This test ensures that nf4_linear is "no worse" than BNB by checking that its
error relative to a higher-precision (bf16/fp32) linear is no larger than BNB's.
"""
dim = 512
nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=torch.bfloat16)
nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=torch.bfloat16)
bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype)

inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
inp = torch.randn(2, 512, dtype=dtype, device="cuda")

out_nf4 = nf4_linear(inp)
out_bnb = bnb_nf4_linear(inp)
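These tests rely on NF4 quantization being a close but lossy approximation in either dtype. A standalone sketch of the roundtrip, using the same to_nf4/get_original_weight API the tests use and assuming the torchao-nightly pin above (which adds fp32 support to NF4Tensor):

import torch
from torchao.dtypes.nf4tensor import to_nf4

for dtype in (torch.bfloat16, torch.float32):
    weight = torch.randn(512, 512, dtype=dtype)
    nf4_weight = to_nf4(weight)                  # 4-bit NF4 representation
    restored = nf4_weight.get_original_weight()  # dequantize back
    assert restored.dtype == dtype               # roundtrip preserves the dtype
    err = (weight - restored).abs().mean().item()
    # NF4 is lossy: expect a small but nonzero mean absolute error
    # (roughly a few hundredths for unit-normal weights).
    assert 0 < err < 0.5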
tests/torchtune/modules/peft/test_lora.py: 32 changes (19 additions, 13 deletions)
@@ -13,7 +13,7 @@
from torch import nn
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune import utils
from torchtune.modules.low_precision import reparametrize_as_bf16_state_dict_post_hook
from torchtune.modules.low_precision import reparametrize_as_dtype_state_dict_post_hook
from torchtune.modules.peft import LoRALinear
from torchtune.utils.seed import set_seed

@@ -111,8 +111,9 @@ def test_quantize_with_bias_raises(self):
quantize_base=True,
)

def test_qlora_parity(self):
with utils.set_default_dtype(torch.bfloat16):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_qlora_parity(self, dtype):
with utils.set_default_dtype(dtype):
qlora_linear = LoRALinear(
in_dim=512,
out_dim=512,
@@ -130,20 +131,21 @@ def test_quantized_state_dict_bf16(self):
quantize_base=False,
)

# set weight of lora_linear to unquantized bf16 of qlora_linear and check
# set weight of lora_linear to unquantized weight of qlora_linear and check
# parity.
lora_linear.weight.data = qlora_linear.weight.get_original_weight()
lora_linear.weight.data = qlora_linear.weight.to(dtype)

# Ensure forward passes are the same. This is because LoRALinear should use a special
# quantized linear operator that runs compute in bf16 (but only saves the 4 bit quantized tensor)
# quantized linear operator that runs compute in higher prec (but only saves the 4 bit quantized tensor)
# for autograd.
inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=torch.bfloat16)
inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=dtype)
lora_linear_out = lora_linear(inputs)
qlora_linear_out = qlora_linear(inputs)
torch.testing.assert_close(lora_linear_out, qlora_linear_out)

def test_quantized_state_dict_bf16(self):
with utils.set_default_dtype(torch.bfloat16):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_quantized_state_dict(self, dtype):
with utils.set_default_dtype(dtype):
lora_linear = LoRALinear(
in_dim=512,
out_dim=512,
@@ -154,12 +156,16 @@ def test_quantized_state_dict_bf16(self):
)

lora_linear._register_state_dict_hook(
partial(reparametrize_as_bf16_state_dict_post_hook, offload_to_cpu=False)
partial(
reparametrize_as_dtype_state_dict_post_hook,
dtype=dtype,
offload_to_cpu=False,
)
)
sd = lora_linear.state_dict()
# No nf4 tensors, all bf16
# No nf4 tensors, all have type dtype
for v in sd.values():
assert v.dtype == torch.bfloat16
assert v.dtype == dtype
assert not isinstance(v, NF4Tensor)

# Loading back in re-quantizes and creates the same NF4 tensor.
@@ -177,7 +183,7 @@ def test_quantized_state_dict_bf16(self):
to_nf4(
torch.zeros_like(
lora_linear.weight.get_original_weight(),
dtype=torch.bfloat16,
dtype=dtype,
device=lora_linear.weight.device,
)
)
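The renamed hook generalizes the old bf16-only version: it walks a module's state dict and replaces every NF4Tensor with a dense tensor in the requested dtype, optionally offloaded to CPU. A minimal sketch of what such a post hook can look like (an approximation for illustration, not torchtune's exact implementation), registered as in the test above via _register_state_dict_hook(partial(...)):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def reparametrize_as_dtype_state_dict_post_hook(
    module, state_dict, *args, dtype=torch.bfloat16, offload_to_cpu=True, **kwargs
):
    # Replace NF4-quantized entries with dense tensors in `dtype` so the saved
    # checkpoint can be loaded by code that knows nothing about NF4.
    for name, param in state_dict.items():
        if isinstance(param, NF4Tensor):
            state_dict[name] = param.to(dtype)
            if offload_to_cpu:
                state_dict[name] = state_dict[name].cpu()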