diff --git a/requirements.txt b/requirements.txt
index 91e253c532..ffcecfe799 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,4 @@ tqdm
 omegaconf

 # Quantization
-torchao-nightly==2024.3.25
+torchao-nightly==2024.3.29
diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py
index 7fc590ac17..c81945abe8 100644
--- a/tests/recipes/test_lora_finetune_single_device.py
+++ b/tests/recipes/test_lora_finetune_single_device.py
@@ -43,11 +43,13 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32"):
             "log_every_n_steps=1",
         ]

-    def _fetch_expected_loss_values(self, run_qlora: bool = False):
-        if run_qlora:
+    def _fetch_expected_loss_values(self):
+        return [10.5074, 10.5614, 10.5205, 10.4918]
+
+    def _fetch_qlora_expected_loss_values(self, dtype):
+        if dtype == "bf16":
             return [10.5057, 10.5575, 10.5179, 10.4898]
-        else:
-            return [10.5074, 10.5614, 10.5205, 10.4918]
+        return [10.5059, 10.5571, 10.5181, 10.4897]

     @pytest.mark.integration_test
     def test_loss(self, tmpdir, monkeypatch):
@@ -82,13 +84,14 @@ def test_loss(self, tmpdir, monkeypatch):
             runpy.run_path(TUNE_PATH, run_name="__main__")

         loss_values = get_loss_values_from_metric_logger(log_file)
-        expected_loss_values = self._fetch_expected_loss_values(run_qlora=False)
+        expected_loss_values = self._fetch_expected_loss_values()
         torch.testing.assert_close(
             loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
         )

     @pytest.mark.integration_test
-    def test_loss_qlora(self, tmpdir, monkeypatch):
+    @pytest.mark.parametrize("dtype", ["fp32", "bf16"])
+    def test_loss_qlora(self, dtype, tmpdir, monkeypatch):
         ckpt = "small_test_ckpt_meta"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
         ckpt_dir = ckpt_path.parent
@@ -114,14 +117,13 @@ def test_loss_qlora(self, tmpdir, monkeypatch):
             lora_alpha=16,
         )

-        # TODO (rohan-varma): QLoRA only supported with bf16 for now
-        cmd = cmd + self._get_test_config_overrides(dtype_str="bf16") + model_config
+        cmd = cmd + self._get_test_config_overrides(dtype_str=dtype) + model_config
         monkeypatch.setattr(sys, "argv", cmd)
         with pytest.raises(SystemExit):
             runpy.run_path(TUNE_PATH, run_name="__main__")

         loss_values = get_loss_values_from_metric_logger(log_file)
-        expected_loss_values = self._fetch_expected_loss_values(run_qlora=True)
+        expected_loss_values = self._fetch_qlora_expected_loss_values(dtype=dtype)
         torch.testing.assert_close(
             loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
         )
diff --git a/tests/torchtune/models/test_lora_llama2.py b/tests/torchtune/models/test_lora_llama2.py
index 9509455f0e..3452b7a2e5 100644
--- a/tests/torchtune/models/test_lora_llama2.py
+++ b/tests/torchtune/models/test_lora_llama2.py
@@ -203,8 +203,9 @@ def test_lora_linear_quantize_base(self):
             if isinstance(module, LoRALinear):
                 assert module._quantize_base

-    def test_qlora_llama2_parity(self, inputs):
-        with utils.set_default_dtype(torch.bfloat16):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_qlora_llama2_parity(self, dtype, inputs):
+        with utils.set_default_dtype(dtype):
             model_ref = self.get_lora_llama2(
                 lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
                 apply_lora_to_mlp=True,
@@ -212,7 +213,7 @@ def test_qlora_llama2_parity(self, inputs):
                 vocab_size=50,
                 quantize_base=False,
                 embed_dim=512,
-                dtype=torch.bfloat16,
+                dtype=dtype,
             )
             qlora = self.get_lora_llama2(
                 lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
@@ -221,7 +222,7 @@ def test_qlora_llama2_parity(self, inputs):
                 vocab_size=50,
                 quantize_base=True,
                 embed_dim=512,
-                dtype=torch.bfloat16,
+                dtype=dtype,
             )
             qlora_sd = qlora.state_dict()
             model_ref.load_state_dict(qlora_sd)
@@ -232,8 +233,9 @@ def test_qlora_llama2_parity(self, inputs):
             output = qlora(inputs)
             torch.testing.assert_close(ref_output, output)

-    def test_qlora_llama2_state_dict(self):
-        with utils.set_default_dtype(torch.bfloat16):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_qlora_llama2_state_dict(self, dtype):
+        with utils.set_default_dtype(dtype):
             model_ref = self.get_lora_llama2(
                 lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
                 apply_lora_to_mlp=True,
@@ -241,11 +243,11 @@ def test_qlora_llama2_state_dict(self):
                 vocab_size=50,
                 quantize_base=False,
                 embed_dim=512,
-                dtype=torch.bfloat16,
+                dtype=dtype,
             )
-            bf16_sd = model_ref.state_dict()
-            for v in bf16_sd.values():
-                assert v.dtype == torch.bfloat16
+            high_prec_sd = model_ref.state_dict()
+            for v in high_prec_sd.values():
+                assert v.dtype == dtype

             # ensure quantized LoRA can load a bf16 state_dict
             qlora = self.get_lora_llama2(
@@ -255,9 +257,9 @@ def test_qlora_llama2_state_dict(self):
                 vocab_size=50,
                 quantize_base=True,
                 embed_dim=512,
-                dtype=torch.bfloat16,
+                dtype=dtype,
             )
-            qlora.load_state_dict(bf16_sd)
+            qlora.load_state_dict(high_prec_sd)
             # LoRALinear base weights should be nf4 still
             for module in qlora.modules():
                 if isinstance(module, LoRALinear):
@@ -265,10 +267,11 @@ def test_qlora_llama2_state_dict(self):
             # saved state_dict should have bf16 weights.
             qlora_sd = qlora.state_dict()
             for v in qlora_sd.values():
-                assert v.dtype == torch.bfloat16
+                assert v.dtype == dtype

-    def test_qlora_llama2_merged_state_dict(self):
-        with utils.set_default_dtype(torch.bfloat16):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_qlora_llama2_merged_state_dict(self, dtype):
+        with utils.set_default_dtype(dtype):
             qlora = self.get_lora_llama2(
                 lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
                 apply_lora_to_mlp=True,
@@ -276,7 +279,7 @@ def test_qlora_llama2_merged_state_dict(self):
                 vocab_size=50,
                 quantize_base=True,
                 embed_dim=512,
-                dtype=torch.bfloat16,
+                dtype=dtype,
                 reset_norm=False,  # to ensure norm.scale key exists
             )

@@ -286,10 +289,10 @@ def test_qlora_llama2_merged_state_dict(self):
         for v in merged_ckpt.values():
             # paranoid check for both, as NF4Tensor had issue where NF4Tensor.dtype would return bf16
             assert not isinstance(v, NF4Tensor)
-            assert v.dtype == torch.bfloat16
+            assert v.dtype == dtype

         # Ensure checkpoint can be loaded into non-LoRA model
-        with utils.set_default_dtype(torch.bfloat16):
+        with utils.set_default_dtype(dtype):
             llama2 = self.get_ref_llama2(vocab_size=50, embed_dim=512)

         llama2.load_state_dict(merged_ckpt)
diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py
index 1f35092216..668defa915 100644
--- a/tests/torchtune/modules/low_precision/test_nf4_linear.py
+++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py
@@ -48,57 +48,53 @@ def test_bias_unsupported(self):
         with pytest.raises(RuntimeError, match="does not currently support biases"):
             _ = FrozenNF4Linear(1, 1, bias=True)

-    def test_non_bf16_unsupported(self):
-        with pytest.raises(RuntimeError, match="only supported with bf16"):
-            _ = FrozenNF4Linear(1, 1, dtype=torch.float32)
-
-    def test_parameters(self):
-        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_parameters(self, dtype):
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
         params = list(nf4_linear.parameters())
         assert len(params) == 1
         assert isinstance(params[0], NF4Tensor)

-    def test_state_dict(self):
-        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_state_dict(self, dtype):
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
         state_dict = nf4_linear.state_dict()
         assert len(state_dict) == 1
         assert isinstance(state_dict["weight"], NF4Tensor)

-    def test_frozen_nf4_linear(self):
-        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
-        assert isinstance(nf4_linear.weight, NF4Tensor)
-        assert torch.bfloat16 == nf4_linear.weight.get_original_weight().dtype
-
-    def test_output_bf16(self):
-        # Test to ensure W4 A16 produces A16
-        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
-        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_output_dtype(self, dtype):
+        # Test to ensure W4 A16 produces A16 / W4A32 produces A32
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
+        inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
         out = nf4_linear(inp)
-        assert out.dtype == torch.bfloat16
+        assert out.dtype == dtype

-    def test_backward_bf16(self):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_backward_dtype(self, dtype):
         # Test to ensure backward pass gives activation a bf16 gradient and no gradient
         # to the linear's weight, as it is frozen.
-        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
-        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
+        inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
         nf4_linear(inp).sum().backward()
-        assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
+        assert inp.grad is not None and inp.grad.dtype == dtype
         assert nf4_linear.weight.grad is None

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_nf4_reconstruction_vs_bnb(self):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_nf4_reconstruction_vs_bnb(self, dtype):
         """
         Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
         reconstructing the respective original weights.
         """
         dim = 512
-        nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=torch.bfloat16)
+        nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
         orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
         bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
         # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
         bnb_reconstruction = bnb_nf4_linear(
-            torch.eye(dim, dim, dtype=torch.bfloat16, device="cuda")
+            torch.eye(dim, dim, dtype=dtype, device="cuda")
         )
         # Ensure nf4_linear and bnb reconstructions are close to each other.
         diff = (
@@ -107,18 +103,19 @@ def test_nf4_reconstruction_vs_bnb(self):
         assert diff.item() < 1e-2

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_nf4_bnb_linear(self):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_nf4_bnb_linear(self, dtype):
         """
         This test ensures that nf4_linear is "no worse" than BNB by ensuring the
         error compared to a bf16 linear is not more than BNB's implementation.
         """
         dim = 512
-        nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=torch.bfloat16)
+        nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
         orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
         bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
-        bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=torch.bfloat16)
+        bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype)

-        inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
+        inp = torch.randn(2, 512, dtype=dtype, device="cuda")

         out_nf4 = nf4_linear(inp)
         out_bnb = bnb_nf4_linear(inp)
diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py
index e87fc072bd..80c253f04c 100644
--- a/tests/torchtune/modules/peft/test_lora.py
+++ b/tests/torchtune/modules/peft/test_lora.py
@@ -13,7 +13,7 @@
 from torch import nn
 from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
 from torchtune import utils
-from torchtune.modules.low_precision import reparametrize_as_bf16_state_dict_post_hook
+from torchtune.modules.low_precision import reparametrize_as_dtype_state_dict_post_hook
 from torchtune.modules.peft import LoRALinear
 from torchtune.utils.seed import set_seed

@@ -111,8 +111,9 @@ def test_quantize_with_bias_raises(self):
                 quantize_base=True,
             )

-    def test_qlora_parity(self):
-        with utils.set_default_dtype(torch.bfloat16):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_qlora_parity(self, dtype):
+        with utils.set_default_dtype(dtype):
             qlora_linear = LoRALinear(
                 in_dim=512,
                 out_dim=512,
@@ -130,20 +131,21 @@ def test_qlora_parity(self):
                 quantize_base=False,
             )

-        # set weight of lora_linear to unquantized bf16 of qlora_linear and check
+        # set weight of lora_linear to unquantized weight of qlora_linear and check
         # parity.
-        lora_linear.weight.data = qlora_linear.weight.get_original_weight()
+        lora_linear.weight.data = qlora_linear.weight.to(dtype)

         # Ensure forward passes are the same. This is because LoRALinear should use a special
-        # quantized linear operator that runs compute in bf16 (but only saves the 4 bit quantized tensor)
+        # quantized linear operator that runs compute in higher prec (but only saves the 4 bit quantized tensor)
         # for autograd.
-        inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=torch.bfloat16)
+        inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=dtype)
         lora_linear_out = lora_linear(inputs)
         qlora_linear_out = qlora_linear(inputs)
         torch.testing.assert_close(lora_linear_out, qlora_linear_out)

-    def test_quantized_state_dict_bf16(self):
-        with utils.set_default_dtype(torch.bfloat16):
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_quantized_state_dict(self, dtype):
+        with utils.set_default_dtype(dtype):
             lora_linear = LoRALinear(
                 in_dim=512,
                 out_dim=512,
@@ -154,12 +156,16 @@ def test_quantized_state_dict_bf16(self):
             )

         lora_linear._register_state_dict_hook(
-            partial(reparametrize_as_bf16_state_dict_post_hook, offload_to_cpu=False)
+            partial(
+                reparametrize_as_dtype_state_dict_post_hook,
+                dtype=dtype,
+                offload_to_cpu=False,
+            )
         )
         sd = lora_linear.state_dict()
-        # No nf4 tensors, all bf16
+        # No nf4 tensors, all have type dtype
         for v in sd.values():
-            assert v.dtype == torch.bfloat16
+            assert v.dtype == dtype
             assert not isinstance(v, NF4Tensor)

         # Load back in results in re-quant and creates the same nf4 tensor.
@@ -177,7 +183,7 @@ def test_quantized_state_dict_bf16(self):
             to_nf4(
                 torch.zeros_like(
                     lora_linear.weight.get_original_weight(),
-                    dtype=torch.bfloat16,
+                    dtype=dtype,
                     device=lora_linear.weight.device,
                 )
             )
diff --git a/torchtune/models/llama2/_component_builders.py b/torchtune/models/llama2/_component_builders.py
index 476fc45bb7..9abd305c6e 100644
--- a/torchtune/models/llama2/_component_builders.py
+++ b/torchtune/models/llama2/_component_builders.py
@@ -21,7 +21,7 @@
     TransformerDecoderLayer,
 )

-from torchtune.modules.low_precision import reparametrize_as_bf16_state_dict_post_hook
+from torchtune.modules.low_precision import reparametrize_as_dtype_state_dict_post_hook
 from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear


@@ -40,6 +40,7 @@

 # ------------------ Vanilla Llama2 ------------------

+
 def llama2(
     vocab_size: int,
     num_layers: int,
@@ -96,7 +97,9 @@ def llama2(
         max_seq_len=max_seq_len,
         attn_dropout=attn_dropout,
     )
-    hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
+    hidden_dim = (
+        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
+    )
     mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
     layer = TransformerDecoderLayer(
         attn=self_attn,
@@ -117,6 +120,7 @@ def llama2(
         output=output_proj,
     )

+
 def llama2_mlp(dim: int, hidden_dim: int) -> FeedForward:
     """
     Build the MLP layer associated with the Llama model.
@@ -127,7 +131,6 @@ def llama2_mlp(dim: int, hidden_dim: int) -> FeedForward:
     return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)


-
 # ------------------ LoRA Llama2 ------------------


@@ -206,7 +209,9 @@ def lora_llama2(
         quantize_base=quantize_base,
     )

-    hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
+    hidden_dim = (
+        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
+    )
     if apply_lora_to_mlp:
         mlp = lora_llama2_mlp(
             dim=embed_dim,
@@ -245,10 +250,16 @@ def lora_llama2(
     )

     if quantize_base:
-        # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
+        # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
         # so as to not increase peak memory
         model._register_state_dict_hook(
-            partial(reparametrize_as_bf16_state_dict_post_hook, offload_to_cpu=True)
+            partial(
+                reparametrize_as_dtype_state_dict_post_hook,
+                # TODO this is clowny, figure out a better way to get what precision the rest
+                # of the model is in
+                dtype=tok_embeddings.weight.dtype,
+                offload_to_cpu=True,
+            )
         )

     return model
diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py
index 62ddb8e55d..59fd4d1a8e 100644
--- a/torchtune/models/llama2/_model_builders.py
+++ b/torchtune/models/llama2/_model_builders.py
@@ -182,5 +182,7 @@ def lora_llama2_13b(
         lora_rank=lora_rank,
         lora_alpha=lora_alpha,
         lora_dropout=0.05,
-        quantize_base=False,
+        quantize_base=quantize_base,
     )
+
+qlora_llama2_13b = partial(lora_llama2_13b, quantize_base=True)
diff --git a/torchtune/modules/low_precision/__init__.py b/torchtune/modules/low_precision/__init__.py
index b68bdc6611..2847a98ccd 100644
--- a/torchtune/modules/low_precision/__init__.py
+++ b/torchtune/modules/low_precision/__init__.py
@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from ._state_dict_hooks import reparametrize_as_bf16_state_dict_post_hook
+from ._state_dict_hooks import reparametrize_as_dtype_state_dict_post_hook
 from .nf4_linear import FrozenNF4Linear


-__all__ = ["FrozenNF4Linear", "reparametrize_as_bf16_state_dict_post_hook"]
+__all__ = ["FrozenNF4Linear", "reparametrize_as_dtype_state_dict_post_hook"]
diff --git a/torchtune/modules/low_precision/_state_dict_hooks.py b/torchtune/modules/low_precision/_state_dict_hooks.py
index c79985b3ce..9133451608 100644
--- a/torchtune/modules/low_precision/_state_dict_hooks.py
+++ b/torchtune/modules/low_precision/_state_dict_hooks.py
@@ -6,24 +6,27 @@
 from typing import Any, Dict, Tuple

+import torch
+
 import torch.nn as nn

 from torchao.dtypes.nf4tensor import NF4Tensor


-def reparametrize_as_bf16_state_dict_post_hook(
+def reparametrize_as_dtype_state_dict_post_hook(
     model: nn.Module,
     state_dict: Dict[str, Any],
     *args: Tuple[Any, ...],
+    dtype: torch.dtype = torch.bfloat16,
     offload_to_cpu: bool = True,
     **kwargs: Dict[Any, Any],
 ):
     """
     A state_dict hook that replaces nf4 tensors with their restored
-    bf16 weight and optionally offloads the restored weight to CPU.
+    higher-precision weight and optionally offloads the restored weight to CPU.
     This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e.

     >>> m = MyModule()
-    >>> m._register_state_dict_hook(reparametrize_as_bf16_state_dict_post_hook)
+    >>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook)

     If the hook is registered per the above process, this hook will be called _after_ the module's
     ``state_dict`` method is called. The hook will replace all ``NF4Tensor`` instances by unquantizing
@@ -33,11 +36,12 @@ def reparametrize_as_bf16_state_dict_post_hook(
         model (nn.Module): the model to take ``state_dict()`` on
         state_dict (Dict[str, Any]): the state dict to modify
         *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
-        offload_to_cpu (bool): whether to offload the restored weight to CPU
+        dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
+        offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
         **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
     """
     for k, v in state_dict.items():
         if isinstance(v, NF4Tensor):
-            state_dict[k] = v.get_original_weight()
+            state_dict[k] = v.to(dtype)
             if offload_to_cpu:
                 state_dict[k] = state_dict[k].cpu()
diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py
index 32abdead9c..8c297cef5b 100644
--- a/torchtune/modules/low_precision/nf4_linear.py
+++ b/torchtune/modules/low_precision/nf4_linear.py
@@ -20,8 +20,6 @@ class FrozenNF4Linear(nn.Linear):
     and is meant to be used as the base Linear layer for modeling
     use cases such as QLoRA where base model parameters are frozen.
     NOTE: biases are currently not supported.
-    NOTE: This class always creates the underlying full precision weight as bf16 dtypte. Note that
-    this will override the default PyTorch dtype that is set via `torch.set_default_dtype`.

     Args:
         in_dim (int): input dimension
@@ -32,7 +30,6 @@ class FrozenNF4Linear(nn.Linear):

     Raises:
         RuntimeError: if ``bias`` is set to ``True``
-        RuntimeError: if ``dtype`` is not set to ``torch.bfloat16``
     """

     def __init__(
@@ -41,27 +38,16 @@ def __init__(
         if "bias" in kwargs and kwargs.pop("bias"):
             raise RuntimeError("FrozenNF4Linear does not currently support biases!")

-        if "dtype" in kwargs:
-            kwargs_dtype = kwargs.pop("dtype")
-            if kwargs_dtype != torch.bfloat16:
-                raise RuntimeError(
-                    "FrozenNF4Linear is only supported with bf16 parameter currently."
-                )
-        super().__init__(
-            in_dim, out_dim, device=device, dtype=torch.bfloat16, bias=False, **kwargs
-        )
+        super().__init__(in_dim, out_dim, device=device, bias=False, **kwargs)
         self.weight.requires_grad_(False)
         self.nf4_weight = to_nf4(self.weight.data)
         # re-register self.weight as the nf4 weight, so that the nf4 weight
         # shows up as expected in .parameters, state_dict, etc.
         self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)

-        # TODO: likely need to handle state_dict save & load via hooks to properly manage
-        # types.
-
     def forward(self, input: Tensor) -> Tensor:
         """
-        Runs linear operation with input tensor as given by `input`. Computation happens in bf16
+        Runs linear operation with input tensor as given by `input`. Computation happens in higher precision,
         though only the nf4 weight is saved for backward for gradient computation to ensure additional
         memory is not used.
         Args:
diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py
index 35ba9307a2..2345341654 100644
--- a/torchtune/modules/peft/lora.py
+++ b/torchtune/modules/peft/lora.py
@@ -9,6 +9,7 @@

 import torch.nn.functional as F
 from torch import nn, Tensor
+
 from torchao.dtypes.nf4tensor import linear_nf4
 from torchtune.modules.low_precision import (  # noqa: F401
     _register_nf4_dispatch_ops,