This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

[feature] support variety of Mixture of LoRA Experts PEFT methods (#92)
mikecovlee authored Aug 9, 2024
1 parent 6cd8dc2 commit 5e7372f
Showing 47 changed files with 1,160 additions and 456 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -12,7 +12,7 @@ m-LoRA (short for Multi-LoRA) is an open-source LLMOps framework developed by th
 
 - Support for multiple PEFT algorithms and various pre-trained models.
 
-- Exclusive Mo-LoRA (Mixture of LoRAs) optimization for [MixLoRA](https://github.com/TUDB-Labs/MixLoRA).
+- Mo-LoRA (Mixture of LoRAs) optimization, mainly for [MixLoRA](https://github.com/TUDB-Labs/MixLoRA).
 
 You can try m-LoRA with [Google Colab](https://githubtocolab.com/mikecovlee/mLoRA/blob/main/misc/finetune-demo.ipynb) before local installation.
 
@@ -54,13 +54,19 @@ You can use the `MLORA_BACKEND_TYPE` environment variable to force m-LoRA to use
 |   | PEFT Methods                                | Arguments*                                           |
 |---|---------------------------------------------|------------------------------------------------------|
 | ✓ | [QLoRA](https://arxiv.org/abs/2305.14314)   | See *Quantize Methods*                               |
-| ✓ | [LoRA+](https://arxiv.org/abs/2402.12354)   | `loraplus_lr_ratio: 20.0`                            |
-| ✓ | [DoRA](https://arxiv.org/abs/2402.09353)    | `use_dora: true`                                     |
-| ✓ | [rsLoRA](https://arxiv.org/abs/2312.03732)  | `use_rslora: true`                                   |
+| ✓ | [LoRA+](https://arxiv.org/abs/2402.12354)   | `"loraplus_lr_ratio": 20.0`                          |
+| ✓ | [DoRA](https://arxiv.org/abs/2402.09353)    | `"use_dora": true`                                   |
+| ✓ | [rsLoRA](https://arxiv.org/abs/2312.03732)  | `"use_rslora": true`                                 |
+| ✓ | [MoLA](https://arxiv.org/abs/2402.08562)    | `"routing_strategy": "mola", "num_experts": 8`       |
+| ✓ | [LoRAMoE](https://arxiv.org/abs/2312.09979) | `"routing_strategy": "loramoe", "num_experts": 8`    |
+| ✓ | [MixLoRA](https://arxiv.org/abs/2404.15159) | See [MixLoRA](https://github.com/TUDB-Labs/MixLoRA)  |
 
 *: Arguments of configuration file
 
+### Notes on PEFT support
+1. m-LoRA provides optimized operators for these PEFT methods, which can significantly improve computing performance during training, evaluation, and inference. However, these operators may cause a small accuracy loss (less than 5%). You can disable the optimized operators by defining the `MLORA_EVALUATE_MODE` environment variable before launching m-LoRA.
+2. Auxiliary loss is currently not supported for Mo-LoRA (Mixture of LoRAs) methods other than MixLoRA.
 
 ## Supported Attention Methods
 
 |   | Attention Methods | Name | Arguments* |
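For orientation, here is a minimal sketch of how the arguments in the table above might be combined in an m-LoRA adapter configuration, written as a Python dict that is dumped to JSON. Only the keys listed in the table (`routing_strategy`, `num_experts`, `use_dora`, `use_rslora`, `loraplus_lr_ratio`) come from this change; every other field name and the top-level layout are assumptions and may not match the project's actual config schema.

```python
import json

# Hypothetical adapter entries. Only the keys documented in the README table
# ("routing_strategy", "num_experts", "use_dora", "use_rslora",
# "loraplus_lr_ratio") come from this commit; the remaining fields and the
# top-level "lora" key are assumed purely for illustration.
adapters = [
    {
        "name": "mola_8experts",         # assumed field
        "r": 8,                          # assumed field
        "alpha": 16,                     # assumed field
        "routing_strategy": "mola",      # MoLA (from the table)
        "num_experts": 8,                # MoLA (from the table)
    },
    {
        "name": "dora_rslora_loraplus",  # assumed field
        "use_dora": True,                # DoRA (from the table)
        "use_rslora": True,              # rsLoRA (from the table)
        "loraplus_lr_ratio": 20.0,       # LoRA+ (from the table)
    },
]

with open("adapter_config.json", "w") as f:
    json.dump({"lora": adapters}, f, indent=2)
```

Note that Python booleans serialize to lowercase `true`/`false`, matching the quoted argument forms shown in the table.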
5 changes: 5 additions & 0 deletions mlora.py
@@ -262,6 +262,11 @@ def inference(
 
     mlora_backend.empty_cache()
 
+    if os.getenv("MLORA_EVALUATE_MODE") is None:
+        logging.info("Using efficient operators.")
+    else:
+        logging.info("Using deterministic operators.")
+
     if args.inference:
         inference(
             model=model,
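The check added above only tests whether `MLORA_EVALUATE_MODE` is defined, not what it is set to, so exporting the variable with any value switches a run from the optimized operators to the deterministic ones. A small illustrative sketch; the `mlora.py` command-line arguments are deliberately omitted rather than guessed:

```python
import os
import subprocess

# mlora.py only evaluates `os.getenv("MLORA_EVALUATE_MODE") is None`, so the
# value itself is irrelevant: defining the variable at all makes the script
# log "Using deterministic operators." and skip the optimized kernels.
env = dict(os.environ, MLORA_EVALUATE_MODE="1")

# Shell equivalent: MLORA_EVALUATE_MODE=1 python mlora.py ...
subprocess.run(["python", "mlora.py"], env=env)
```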
18 changes: 9 additions & 9 deletions mlora/__init__.py
@@ -1,21 +1,21 @@
 from .backends import backend
-from .common import (
+from .dispatcher import Dispatcher, TrainTask
+from .evaluator import EvaluateConfig, evaluate
+from .generator import GenerateConfig, generate
+from .model import LLMModel
+from .modules import (
     AdapterConfig,
-    Cache,
     LLMBatchConfig,
+    LLMCache,
     LLMForCausalLM,
     LLMModelConfig,
     LLMModelInput,
     LLMModelOutput,
     LoraConfig,
-    MixConfig,
+    MixLoraConfig,
     cache_factory,
     lora_config_factory,
 )
-from .dispatcher import Dispatcher, TrainTask
-from .evaluator import EvaluateConfig, evaluate
-from .generator import GenerateConfig, generate
-from .model import LLMModel
 from .prompter import Prompter
 from .tokenizer import Tokenizer
 from .trainer import TrainConfig, train
@@ -29,7 +29,7 @@
 setup_logging()
 
 __all__ = [
-    "Cache",
+    "LLMCache",
     "cache_factory",
     "LLMModelConfig",
     "LLMModelOutput",
@@ -38,7 +38,7 @@
     "LLMModelInput",
     "AdapterConfig",
     "LoraConfig",
-    "MixConfig",
+    "MixLoraConfig",
     "lora_config_factory",
     "TrainTask",
     "Dispatcher",
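Because `mlora.common` becomes `mlora.modules` and `Cache`/`MixConfig` are renamed to `LLMCache`/`MixLoraConfig`, downstream code that imports the old names needs a mechanical update. A sketch of the before and after, assuming the package-level re-exports listed above remain the public surface:

```python
# Before this commit:
#   from mlora import Cache, MixConfig
#   from mlora.common import LLMBatchConfig, LLMModelInput

# After this commit, the same symbols are imported as:
from mlora import LLMCache, MixLoraConfig
from mlora.modules import LLMBatchConfig, LLMModelInput
```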
2 changes: 1 addition & 1 deletion mlora/dispatcher.py
@@ -6,7 +6,7 @@
 
 import datasets
 
-from .common import InputData, LLMBatchConfig, LLMModelInput, Masks, Tokens
+from .modules import InputData, LLMBatchConfig, LLMModelInput, Masks, Tokens
 from .tokenizer import Tokenizer
 
 
8 changes: 4 additions & 4 deletions mlora/evaluator.py
@@ -6,8 +6,8 @@
 
 import torch
 
-from .common import InputData, LLMBatchConfig, LLMModelInput, MixConfig, Prompt
 from .model import LLMModel
+from .modules import InputData, LLMBatchConfig, LLMModelInput, MixLoraConfig, Prompt
 from .tasks import BasicMetric, BasicTask, CommonSenseTask, task_dict
 from .tokenizer import Tokenizer
 
@@ -93,7 +93,7 @@ def reset_parameters(self):
 def _prepare_tasks(model, tokenizer, configs):
     for config in configs:
         config.prepare(tokenizer, model.device_)
-        if not isinstance(model.adapter_configs_[config.adapter_name], MixConfig):
+        if not isinstance(model.adapter_configs_[config.adapter_name], MixLoraConfig):
             continue
         for layer in model.model_.layers_:
             if config.adapter_name in layer.mlp_.moes_:
@@ -172,7 +172,7 @@ def _compute_metrcis(model, current_configs, sequence_lengths, batch_labels, out
 
         if config.router_profile:
             adapter_config = model.adapter_configs_[config.adapter_name]
-            if isinstance(adapter_config, MixConfig):
+            if isinstance(adapter_config, MixLoraConfig):
                 router_statistic_ = list(0 for _ in range(adapter_config.num_experts_))
                 for layer in model.model_.layers_:
                     if config.adapter_name not in layer.mlp_.moes_:
@@ -225,7 +225,7 @@ def _compute_result(model, configs, save_file):
         result["metrics"] = compute_results
         if config.router_profile:
             adapter_config = model.adapter_configs_[config.adapter_name]
-            if isinstance(adapter_config, MixConfig):
+            if isinstance(adapter_config, MixLoraConfig):
                 router_statistic_ = list(0 for _ in range(adapter_config.num_experts_))
                 for layer in model.model_.layers_:
                     if config.adapter_name not in layer.mlp_.moes_:
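The evaluator now guards its router-profiling code with `isinstance(adapter_config, MixLoraConfig)`, since only MoE-style adapters expose `num_experts_`. A sketch of the same guard for downstream code; the helper name is illustrative and not part of the codebase:

```python
from mlora import MixLoraConfig


def init_router_statistics(adapter_config):
    # Only MoE-style adapters (MixLoraConfig) carry an expert count worth
    # profiling; plain LoRA adapters return None and should be skipped.
    if isinstance(adapter_config, MixLoraConfig):
        return [0] * adapter_config.num_experts_
    return None
```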
2 changes: 1 addition & 1 deletion mlora/generator.py
@@ -6,8 +6,8 @@
 import torch
 
 from mlora.backends import backend
-from mlora.common import LLMBatchConfig, LLMModelInput, Tokens, cache_factory
 from mlora.model import LLMModel
+from mlora.modules import LLMBatchConfig, LLMModelInput, Tokens, cache_factory
 from mlora.prompter import Prompter
 from mlora.tokenizer import Tokenizer
 
(Diffs for the remaining 41 changed files are not shown.)
