Patch lora kernels post model load #2345

Draft
wants to merge 5 commits into base: main

Changes from all commits
155 changes: 68 additions & 87 deletions src/axolotl/monkeypatch/lora_kernels.py
@@ -4,13 +4,12 @@
 import inspect
 import logging
 import types
-from typing import Type

 import torch
 from accelerate.logging import get_logger
 from peft import PeftModelForCausalLM
 from torch import nn
-from transformers import AutoConfig
+from transformers.modeling_utils import PreTrainedModel

 from axolotl.kernels.lora import (
     apply_lora_mlp_geglu,
@@ -96,108 +95,90 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
     return attn_output


-def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
-    """
-    Get the appropriate attention class by inspecting the model config.
-    Uses dynamic import to support any model architecture that follows
-    the standard transformers naming convention.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        The appropriate attention class for the model.
-
-    Raises:
-        ValueError: If `base_model` not specified or attention class cannot be imported
-        ImportError: If the model module or attention class doesn't exist
-    """
-    if "base_model" not in cfg:
-        raise ValueError("base_model must be specified in config")
-
-    # Get model config without loading the model
-    model_config = AutoConfig.from_pretrained(cfg["base_model"])
-    model_type = model_config.model_type
-
-    # Special case for model_type = "qwen2"
-    if model_type == "qwen2":
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
-
-        return Qwen2Attention
-
-    try:
-        # Dynamically import the module and attention class
-        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
-        module = __import__(
-            module_path, fromlist=[f"{model_type.capitalize()}Attention"]
-        )
-        attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
-
-        return attention_cls
-    except (ImportError, AttributeError) as e:
-        raise ValueError(
-            f"Could not import attention class for model_type: {model_type}. "
-            f"Error: {str(e)}"
-        ) from e
-
-
-# pylint: disable=protected-access
-def patch_self_attn_lora(cfg: DictDefault):
-    """
-    Given an `axolotl` config, this method patches the inferred attention class forward
-    pass with optimized LoRA implementations.
-
-    It modifies the attention class to use optimized QKV and output projections. The
-    original implementation is preserved and can be restored if needed.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Raises:
-        AssertionError: If the required code blocks are not found in the attention
-            implementation.
-    """
-    attention_cls = get_attention_cls_from_config(cfg)
-
-    # Check if already patched
-    if hasattr(attention_cls, "_original_forward"):
-        LOG.info(f"{attention_cls.__name__} already patched")
-        return
-
-    self_attn_forward = inspect.getsource(attention_cls.forward)
-    attention_cls._original_forward = self_attn_forward
-    self_attn_forward, _ = detab_code(self_attn_forward)
-
-    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
-    assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
-
-    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
-    self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
-    self_attn_forward = self_attn_forward.replace(
-        "def forward(",
-        "def axolotl_attn_forward(",
-        1,
-    )
-
-    # Load necessary imports
-    module_name = attention_cls.__module__
-    module = importlib.import_module(module_name)
-
-    items_to_import = []
-    for item in dir(module):
-        if item in self_attn_forward:
-            items_to_import.append(item)
-
-    exec(  # pylint: disable=exec-used # nosec B102
-        f"from {module_name} import ({', '.join(items_to_import)})",
-        globals(),
-    )
-    exec(self_attn_forward, globals())  # pylint: disable=exec-used # nosec B102
-
-    LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
-    attention_cls.forward = (
-        axolotl_attn_forward  # pylint: disable=undefined-variable # noqa: F821
-    )
+# pylint: disable=protected-access
+def patch_self_attn_lora(model: PreTrainedModel):
+    """
+    Patches the attention classes in a transformer model with optimized LoRA implementations.
+
+    It modifies the attention class to use optimized QKV and output projections. The
+    original implementation is preserved and can be restored if needed.
+
+    Args:
+        model: A HuggingFace transformers model.
+
+    Raises:
+        AssertionError: If the required code blocks are not found in the attention
+            implementation.
+    """
+    # Find all attention modules in the model
+    attention_modules = [
+        module
+        for module in model.modules()
+        if "attention" in module.__class__.__name__.lower()
+        and hasattr(module, "forward")
+    ]
+
+    if not attention_modules:
+        LOG.warning("No attention modules found in model")
+        return
+
+    attention_classes = {type(module) for module in attention_modules}
+    LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
+
+    for attention_cls in attention_classes:
+        # Skip if already patched
+        if hasattr(attention_cls, "_original_forward"):
+            LOG.info(f"{attention_cls.__name__} already patched")
+            continue
+
+        # Get and store original forward implementation
+        self_attn_forward = inspect.getsource(attention_cls.forward)
+        attention_cls._original_forward = self_attn_forward
+
+        # Remove indentation
+        self_attn_forward, _ = detab_code(self_attn_forward)
+
+        # Verify required code blocks exist
+        assert (
+            ORIGINAL_QKV_CODE in self_attn_forward
+        ), f"Original QKV code not found in {attention_cls.__name__}"
+        assert (
+            ORIGINAL_O_CODE in self_attn_forward
+        ), f"Original O code not found in {attention_cls.__name__}"
+
+        # Replace code blocks
+        self_attn_forward = self_attn_forward.replace(
+            ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
+        )
+        self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
+        self_attn_forward = self_attn_forward.replace(
+            "def forward(",
+            "def axolotl_attn_forward(",
+            1,
+        )
+
+        # Import necessary symbols from the attention module
+        module_name = attention_cls.__module__
+        module = importlib.import_module(module_name)
+
+        items_to_import = []
+        for item in dir(module):
+            if item in self_attn_forward:
+                items_to_import.append(item)
+
+        if items_to_import:
+            exec(  # pylint: disable=exec-used # nosec B102
+                f"from {module_name} import ({', '.join(items_to_import)})",
+                globals(),
+            )
+
+        # Execute the new implementation
+        exec(self_attn_forward, globals())  # pylint: disable=exec-used # nosec B102
+
+        LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
+        attention_cls.forward = (
+            axolotl_attn_forward  # pylint: disable=undefined-variable # noqa: F821
+        )


 def apply_lora_kernel_patches(
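As a reading aid for the hunk above: the patch now operates on an already-loaded model instead of a config, discovering attention classes from `model.modules()` rather than guessing them from `model_type`. A minimal usage sketch, assuming a LLaMA-style checkpoint whose attention `forward` still contains the QKV/O code blocks the patch searches for (the model name is borrowed from the test config later in this PR; nothing here is part of the diff itself):

# Illustrative only: drive the new model-based entry point post-load.
# Assumes a LLaMA-style architecture and that the attention forward contains
# the ORIGINAL_QKV_CODE / ORIGINAL_O_CODE blocks the patch asserts on.
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

# The patch inspects model.modules() directly, so it can only run after load.
patch_self_attn_lora(model)

# Each patched attention class keeps its original source and swaps in the
# exec'd forward, which the patch names axolotl_attn_forward.
attn_cls = type(model.model.layers[0].self_attn)
print(attn_cls.forward.__name__, hasattr(attn_cls, "_original_forward"))

Because the class is discovered from the instantiated model, no per-architecture import logic (and no qwen2 special case) is needed.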
16 changes: 8 additions & 8 deletions src/axolotl/utils/models.py
@@ -439,11 +439,6 @@ def apply_patches(self) -> None:

         patch_mistral_cross_entropy()

-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
-
-            patch_self_attn_lora(self.cfg)
-
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
             if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1028,6 +1023,12 @@ def apply_unsloth_lora_patch(self) -> None:
         integrate_rope_embeddings()

     def apply_lora_patch(self) -> None:
+        """Applies patching relevant to LoRA Triton kernels if enabled."""
+        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
+            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
+
+            patch_self_attn_lora(self.model)
+
         if (
             self.cfg.lora_mlp_kernel
             or self.cfg.lora_qkv_kernel
@@ -1181,6 +1182,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
         if self.cfg.adapter is not None:
             log_gpu_memory_usage(LOG, "after adapters", self.model.device)

+        # TODO: Deprecate this.
         self.apply_unsloth_lora_patch()
         self.apply_lora_patch()

@@ -1201,9 +1203,7 @@ def load_model(
     reference_model: bool = False,
     **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
-    """
-    Load a model for a given configuration and tokenizer.
-    """
+    """Load a model for a given configuration and tokenizer."""
     loader = ModelLoader(
         cfg,
         tokenizer,
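The new code path in `apply_lora_patch` is gated on just two config flags; a minimal sketch of the relevant settings (surrounding keys omitted, so this is illustrative rather than a complete axolotl config):

from axolotl.utils.dict import DictDefault

# Only the flags that trigger the attention patch are shown; a real config
# also needs base_model, datasets, adapter settings, etc.
cfg = DictDefault(
    {
        "lora_qkv_kernel": True,  # patch the QKV projections
        "lora_o_kernel": True,    # patch the output projection
        "lora_mlp_kernel": True,  # MLP kernels are handled further down in apply_lora_patch
    }
)

# apply_lora_patch() runs inside ModelLoader.load_model() after adapters are
# attached, so patch_self_attn_lora() sees the fully built, PEFT-wrapped model.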
78 changes: 39 additions & 39 deletions tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -9,16 +9,14 @@
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention

+from axolotl.cli.utils import load_model_and_tokenizer
 from axolotl.kernels.lora import (
     apply_lora_mlp_geglu,
     apply_lora_mlp_swiglu,
     apply_lora_o,
     apply_lora_qkv,
 )
-from axolotl.monkeypatch.lora_kernels import (
-    apply_lora_kernel_patches,
-    patch_self_attn_lora,
-)
+from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
 from axolotl.utils.dict import DictDefault

 MODEL_CONFIGS = [
@@ -65,15 +63,45 @@ def small_llama_model():
     return LlamaForCausalLM(LlamaConfig(**config))


-def test_attention_patching_integration():
-    """Test attention patching in integration context."""
-    cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
+# pylint: disable=duplicate-code
+@pytest.fixture
+def minimal_cfg():
+    "Config of real HuggingFace Hub model"
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
+            "learning_rate": 0.000001,
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                }
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "adapter": "lora",
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.0,
+            "lora_target_linear": True,
+            "sequence_len": 1024,
+            "lora_mlp_kernel": True,
+            "lora_qkv_kernel": True,
+            "lora_o_kernel": True,
+        }
+    )
+
+    return cfg
+
+
+def test_attention_patching_integration(minimal_cfg):
+    """Test attention patching in integration context."""

     # Store the original implementation
     original_forward = getattr(LlamaAttention, "forward")

-    # Apply patch
-    patch_self_attn_lora(cfg)
+    # Load model
+    _, _ = load_model_and_tokenizer(cfg=minimal_cfg)

     # Get the new forward method
     patched_forward = LlamaAttention.forward
@@ -376,38 +404,10 @@ def test_model_architecture(model_config):


 # pylint: disable=duplicate-code
-def test_kernel_training_integration():
+def test_kernel_training_integration(minimal_cfg):
     """Test model loading with kernel patches enabled."""
-    from axolotl.cli.utils import load_model_and_tokenizer
-
-    # Create minimal config
-    cfg = DictDefault(
-        {
-            "base_model": "HuggingFaceTB/SmolLM2-135M",
-            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
-            "learning_rate": 0.000001,
-            "datasets": [
-                {
-                    "path": "mhenrichsen/alpaca_2k_test",
-                    "type": "alpaca",
-                }
-            ],
-            "micro_batch_size": 1,
-            "gradient_accumulation_steps": 1,
-            "adapter": "lora",
-            "lora_r": 8,
-            "lora_alpha": 16,
-            "lora_dropout": 0.0,
-            "lora_target_linear": True,
-            "sequence_len": 1024,
-            "lora_mlp_kernel": True,
-            "lora_qkv_kernel": True,
-            "lora_o_kernel": True,
-        }
-    )
-
     # Load model
-    model, _ = load_model_and_tokenizer(cfg=cfg)
+    model, _ = load_model_and_tokenizer(cfg=minimal_cfg)

     # Verify correct activation function
     layer = model.model.model.layers[0]
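The visible part of the updated integration test stops at capturing the patched forward after the full model load; a hedged sketch of the kind of check it builds toward, assuming the SmolLM2 checkpoint resolves to `LlamaAttention` (the helper name is hypothetical and the real assertions are outside the diff):

def _assert_llama_attention_patched(original_forward):
    """Hypothetical helper: verify LlamaAttention.forward was swapped by the patch."""
    from transformers.models.llama.modeling_llama import LlamaAttention

    patched_forward = LlamaAttention.forward

    # The monkeypatch binds a freshly exec'd function named axolotl_attn_forward,
    # so the class attribute should no longer be the original forward.
    assert patched_forward is not original_forward
    assert patched_forward.__name__ == "axolotl_attn_forward"

    # The original source is stashed on the class for potential restoration.
    assert hasattr(LlamaAttention, "_original_forward")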