
QLora recipe failing on AMD MI250x #600

Closed
chauhang opened this issue Mar 27, 2024 · 3 comments

@chauhang
Contributor

The QLora recipe is failing on AMD MI250 with RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

Steps to reproduce:

  1. Install torchtune
  2. Download the model
  3. Run QLora recipe using:
    tune lora_finetune_single_device --config recipes/configs/llama2/7B_qlora_single_device.yaml epochs=2 max_steps_per_epoch=4
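
For reference, the mismatch can be reproduced outside the recipe with a bare matmul between a float32 and a bf16 CUDA tensor (a minimal sketch, not torchtune code):

    # Minimal sketch: multiplying an fp32 tensor by a bf16 weight should raise the
    # same "expected mat1 and mat2 to have the same dtype" RuntimeError seen below.
    import torch

    grad_output = torch.randn(2, 4, dtype=torch.float32, device="cuda")
    weight = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")
    grad_output @ weight  # RuntimeError: float != c10::BFloat16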

Error (full trace with details):

2024-03-27:01:14:38,937 INFO     [_parse.py:52] Running recipe_main with parameters {'model': {'_component_': 'torchtune.models.llama2.qlora_llama2_7b', 'lora_attn_modules': ['q_proj', 'v_proj', 'k_proj', 'output_proj'], 'apply_lora_to_mlp': True, 'apply_lora_to_output': False, 'lora_rank': 8, 'lora_alpha': 16, 'quantize_base': True}, 'checkpointer': {'_component_': 'torchtune.utils.FullModelMetaCheckpointer', 'checkpoint_dir': '/tmp/llama2', 'checkpoint_files': ['consolidated.00.pth'], 'adapter_checkpoint': None, 'recipe_checkpoint': None, 'output_dir': '/tmp/llama2/', 'model_type': 'LLAMA2'}, 'resume_from_checkpoint': False, 'tokenizer': {'_component_': 'torchtune.models.llama2.llama2_tokenizer', 'path': '/tmp/llama2/tokenizer.model'}, 'dataset': {'_component_': 'torchtune.datasets.alpaca_dataset', 'train_on_input': True, 'use_clean': True}, 'seed': None, 'shuffle': True, 'batch_size': 2, 'optimizer': {'_component_': 'torch.optim.AdamW', 'weight_decay': 0.01, 'lr': 0.0003}, 'lr_scheduler': {'_component_': 'torchtune.modules.get_cosine_schedule_with_warmup', 'num_warmup_steps': 100}, 'loss': {'_component_': 'torch.nn.CrossEntropyLoss'}, 'epochs': 2, 'max_steps_per_epoch': 4, 'gradient_accumulation_steps': 1, 'output_dir': '/tmp/qlora_finetune_output/', 'metric_logger': {'_component_': 'torchtune.utils.metric_logging.DiskLogger', 'log_dir': '${output_dir}'}, 'log_every_n_steps': 1, 'device': 'cuda', 'dtype': 'bf16', 'enable_activation_checkpointing': True}
2024-03-27:01:14:38,938 INFO     [precision.py:120] BF16 not supported on this hardware. Setting dtype to float32
2024-03-27:01:14:38,941 DEBUG    [seed.py:59] Setting manual seed to local seed 1942403930. Local seed is seed + rank = 1942403930 + 0
Writing logs to /tmp/qlora_finetune_output/log_1711527278.txt
2024-03-27:01:15:16,227 INFO     [lora_finetune_single_device.py:254] Model is initialized with precision torch.float32.

2024-03-27:01:15:16,228 INFO     [lora_finetune_single_device.py:255] 
    Memory Stats after model init::
    GPU peak memory allocation: 7.92 GB
    GPU peak memory reserved: 9.18 GB
    GPU peak memory active: 7.92 GB
    
2024-03-27:01:15:16,259 INFO     [lora_finetune_single_device.py:164] Tokenizer is initialized from file.
2024-03-27:01:15:16,263 INFO     [lora_finetune_single_device.py:269] Optimizer and loss are initialized.
2024-03-27:01:15:16,263 INFO     [lora_finetune_single_device.py:174] Loss is initialized.
2024-03-27:01:15:24,215 INFO     [lora_finetune_single_device.py:321] Dataset and Sampler are initialized.
2024-03-27:01:15:24,216 INFO     [lora_finetune_single_device.py:285] Learning rate scheduler is initialized.
  0%|                                                                                                           | 0/25880 [00:00<?, ?it/s]/home/gchauhan/meta/torchtune/torchtune/modules/attention.py:208: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:515.)
  output = nn.functional.scaled_dot_product_attention(
1|1|Loss: 1.5469839572906494:   0%|                                                                             | 0/25880 [00:05<?, ?it/s][W327 01:15:30.099686146 Module.cpp:168] symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1...

1|1|Loss: 1.5469839572906494:   0%|                                                                             | 0/25880 [00:27<?, ?it/s]
Traceback (most recent call last):
  File "/home/gchauhan/my_envs/llm-amd/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/gchauhan/meta/torchtune/torchtune/_cli/tune.py", line 133, in main
    runpy.run_path(str(cmd), run_name="__main__")
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/home/gchauhan/meta/torchtune/recipes/lora_finetune_single_device.py", line 459, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/home/gchauhan/meta/torchtune/torchtune/config/_parse.py", line 54, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/home/gchauhan/meta/torchtune/recipes/lora_finetune_single_device.py", line 454, in recipe_main
    recipe.train()
  File "/home/gchauhan/meta/torchtune/recipes/lora_finetune_single_device.py", line 424, in train
    loss.backward()
  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/gchauhan/my_envs/llm-amd/lib/python3.11/site-packages/torchao/dtypes/nf4tensor.py", line 574, in backward
    return grad_output @ weight.get_original_weight(), None
           ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
Exception raised from addmm_out_cuda_impl at ../aten/src/ATen/native/hip/Blas.cpp:216 (most recent call first):
C++ CapturedTraceback:
#4 c10::Error::Error(c10::SourceLocation, std::string) from ??:0
#5 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) from ??:0
#6 at::native::(anonymous namespace)::addmm_out_cuda_impl(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&, at::native::(anonymous namespace)::Activation) from Blas.cpp:0
#7 at::native::structured_mm_out_cuda::impl(at::Tensor const&, at::Tensor const&, at::Tensor const&) from ??:0
#8 at::(anonymous namespace)::wrapper_CUDA_mm(at::Tensor const&, at::Tensor const&) from RegisterCUDA.cpp:0
#9 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::wrapper_CUDA_mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) from RegisterCUDA.cpp:0
#10 at::_ops::mm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) from ??:0
#11 torch::autograd::VariableType::(anonymous namespace)::mm(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) from VariableType_3.cpp:0
#12 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) from VariableType_3.cpp:0
#13 at::_ops::mm::call(at::Tensor const&, at::Tensor const&) from ??:0
#14 at::native::_matmul_impl(at::Tensor&, at::Tensor const&, at::Tensor const&) from LinearAlgebra.cpp:0
#15 at::native::matmul(at::Tensor const&, at::Tensor const&) from ??:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__matmul>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) from RegisterCompositeImplicitAutograd.cpp:0
#17 at::_ops::matmul::call(at::Tensor const&, at::Tensor const&) from ??:0
#18 torch::autograd::THPVariable_matmul(_object*, _object*, _object*) from python_variable_methods.cpp:0
#19 _object* torch::autograd::TypeError_to_NotImplemented_<&torch::autograd::THPVariable_matmul>(_object*, _object*, _object*) from python_variable_methods.cpp:0
#20 method_vectorcall_VARARGS_KEYWORDS from /usr/local/src/conda/python-3.11.5/Objects/descrobject.c:364
#21 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.11.5/Include/internal/pycore_call.h:92
#22 vectorcall_unbound from /usr/local/src/conda/python-3.11.5/Objects/typeobject.c:1641
#23 slot_nb_matrix_multiply from /usr/local/src/conda/python-3.11.5/Objects/typeobject.c:7422
#24 binary_op1 from /usr/local/src/conda/python-3.11.5/Objects/abstract.c:893
#25 binary_op from /usr/local/src/conda/python-3.11.5/Objects/abstract.c:932
#26 _PyEval_EvalFrameDefault from /usr/local/src/conda/python-3.11.5/Python/ceval.c:5553
#27 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.5/Include/internal/pycore_ceval.h:73
#28 do_call_core from /usr/local/src/conda/python-3.11.5/Python/ceval.c:7357
#29 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.5/Include/internal/pycore_ceval.h:73
#30 method_vectorcall from /usr/local/src/conda/python-3.11.5/Objects/classobject.c:89
#31 torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) from ??:0
#32 torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) from :0
#33 torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) from ??:0
#34 torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) from ??:0
#35 torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) from ??:0
#36 torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) from ??:0
#37 execute_native_thread_routine from thread48.o:0
#38 start_thread from ??:0
#39 __clone3 from :0

Environment

PyTorch version: 2.4.0.dev20240326+rocm6.0
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.0.32830-d62f6a171

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.34

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0_fbk1_zion_755_ga25447393a1d-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI250X / MI250 (gfx90a:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.0.32830
MIOpen runtime version: 3.0.0
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          256
On-line CPU(s) list:             0-255
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 7713 64-Core Processor
CPU family:                      25
Model:                           1
Thread(s) per core:              2
Core(s) per socket:              64
Socket(s):                       2
Stepping:                        1
Frequency boost:                 enabled
CPU(s) scaling MHz:              100%
CPU max MHz:                     2000.0000
CPU min MHz:                     1500.0000
BogoMIPS:                        3992.39
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                  AMD-V
L1d cache:                       4 MiB (128 instances)
L1i cache:                       4 MiB (128 instances)
L2 cache:                        64 MiB (128 instances)
L3 cache:                        512 MiB (16 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-63,128-191
NUMA node1 CPU(s):               64-127,192-255

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] pytorch-triton-rocm==3.0.0+0a22a91d04
[pip3] torch==2.4.0.dev20240326+rocm6.0
[pip3] torchao-nightly==2024.3.25
[pip3] torchtune==0.0.1
[conda] numpy                     1.26.0                   pypi_0    pypi
[conda] pytorch-triton-rocm       3.0.0+0a22a91d04          pypi_0    pypi
[conda] torch                     2.4.0.dev20240326+rocm6.0          pypi_0    pypi
[conda] torchao-nightly           2024.3.25                pypi_0    pypi
[conda] torchtune                 0.0.1                    pypi_0    pypi
@rohan-varma
Member

Thanks for the detailed report! This seems to be because we fall back to fp32 instead of bf16, and then the NF4 tensor's get_original_weight returns a bf16 tensor, causing the dtype mismatch (a toy sketch follows the list below).

  1. Add fp32 support for QLoRA #595 should add fp32 support to QLoRA
  2. If bf16 is indeed supported on this box, we should revise our check and actually enable this to run in bf16.
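
To make the failure mode concrete, here is a toy sketch (not the actual torchao NF4Tensor code) of a custom autograd Function whose backward matmuls the incoming gradient against a weight that is always materialized in bf16; running the forward in fp32 then fails on backward in the same way:

    import torch

    class ToyQuantLinear(torch.autograd.Function):
        # Stand-in for the NF4 linear: the "dequantized" weight is always bf16,
        # regardless of the compute dtype the recipe fell back to.
        @staticmethod
        def forward(ctx, x, weight_bf16):
            ctx.save_for_backward(weight_bf16)
            return x @ weight_bf16.to(x.dtype).t()

        @staticmethod
        def backward(ctx, grad_output):
            (weight_bf16,) = ctx.saved_tensors
            # Mirrors nf4tensor.py's backward: fp32 grad_output @ bf16 weight
            return grad_output @ weight_bf16, None

    x = torch.randn(2, 8, device="cuda", dtype=torch.float32, requires_grad=True)
    w = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
    ToyQuantLinear.apply(x, w).sum().backward()  # same RuntimeError as above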

rohan-varma self-assigned this Mar 27, 2024
@supernovae
Contributor

supernovae commented Apr 18, 2024

@rohan-varma Is the fallback because verify_bf16_support() checks the CUDA package version via

    and packaging.version.parse(torch.version.cuda).release >= (11, 0)

which won't be true for ROCm?

My 7900 XTX supports bf16 even though the check fails:

    >>> import torch
    >>> torch.cuda.is_bf16_supported()
    True

I'll copy this to a new issue, as I don't think we should check for the CUDA version unless we can also do an or-check for rocm > x version too.
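
For reference, a possible shape for a ROCm-aware check (a rough sketch under the assumption that the current gate only looks at torch.version.cuda; the function name is illustrative, not actual torchtune code):

    import torch
    from packaging import version

    def bf16_supported() -> bool:
        # Sketch only: accept a new-enough CUDA build or any ROCm/HIP build,
        # then defer to PyTorch's own capability probe.
        if not torch.cuda.is_available():
            return False
        cuda_ok = (
            torch.version.cuda is not None
            and version.parse(torch.version.cuda).release >= (11, 0)
        )
        rocm_ok = torch.version.hip is not None
        return (cuda_ok or rocm_ok) and torch.cuda.is_bf16_supported()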

@rohan-varma
Member

@supernovae Sorry for the late follow-up, but this is awesome! Looks like you've verified support in #803. Closing this out now, but feel free to reopen if issues persist.
