From 8fd223f853017d22b8c9b71696d9b706377c2302 Mon Sep 17 00:00:00 2001 From: Dengchun Li Date: Fri, 16 Aug 2024 18:03:48 +0800 Subject: [PATCH] [feature] support dynamic MixLoRA (#94) --- README.md | 23 ++- mlora/model.py | 2 +- mlora/modules/__init__.py | 4 + mlora/modules/config.py | 19 +- mlora/modules/lora_moes.py | 9 +- mlora/modules/mix_lora.py | 273 +++++++++++++++++++++++++--- mlora/tokenizer.py | 7 +- pyproject.toml | 2 +- templates/mixlora_dynamic.json | 37 ++++ templates/mixlora_dynamic_glm.json | 34 ++++ templates/mixlora_dynamic_phi.json | 36 ++++ templates/mixlora_dynamic_phi3.json | 34 ++++ 12 files changed, 433 insertions(+), 47 deletions(-) create mode 100644 templates/mixlora_dynamic.json create mode 100644 templates/mixlora_dynamic_glm.json create mode 100644 templates/mixlora_dynamic_phi.json create mode 100644 templates/mixlora_dynamic_phi3.json diff --git a/README.md b/README.md index 7437d944..79f69214 100644 --- a/README.md +++ b/README.md @@ -51,21 +51,24 @@ You can use the `MLORA_BACKEND_TYPE` environment variable to force m-LoRA to use ## Supported PEFT Methods -| | PEFT Methods | Arguments* | -|---------|----------------------------------------------------------|-----------------------------------------------------| -| ✓ | [QLoRA](https://arxiv.org/abs/2402.12354) | See *Quantize Methods* | -| ✓ | [LoRA+](https://arxiv.org/abs/2402.12354) | `"loraplus_lr_ratio": 20.0` | -| ✓ | [DoRA](https://arxiv.org/abs/2402.09353) | `"use_dora": true` | -| ✓ | [rsLoRA](https://arxiv.org/abs/2312.03732) | `"use_rslora": true` | -| ✓ | [MoLA](https://arxiv.org/abs/2402.08562) | `"routing_strategy": "mola", "num_experts": 8` | -| ✓ | [LoRAMoE](https://arxiv.org/abs/2312.09979) | `"routing_strategy": "loramoe", "num_experts": 8` | -| ✓ | [MixLoRA](https://arxiv.org/abs/2404.15159) | See [MixLoRA](https://github.com/TUDB-Labs/MixLoRA) | +| | PEFT Methods | Arguments* | +|---------|----------------------------------------------------------|-----------------------------------------------------------| +| ✓ | [QLoRA](https://arxiv.org/abs/2402.12354) | See *Quantize Methods* | +| ✓ | [LoRA+](https://arxiv.org/abs/2402.12354) | `"loraplus_lr_ratio": 20.0` | +| ✓ | [DoRA](https://arxiv.org/abs/2402.09353) | `"use_dora": true` | +| ✓ | [rsLoRA](https://arxiv.org/abs/2312.03732) | `"use_rslora": true` | +| ✓ | [MoLA](https://arxiv.org/abs/2402.08562) | `"routing_strategy": "mola", "num_experts": 8` | +| ✓ | [LoRAMoE](https://arxiv.org/abs/2312.09979) | `"routing_strategy": "loramoe", "num_experts": 8` | +| ✓ | [MixLoRA](https://arxiv.org/abs/2404.15159) | `"routing_strategy": "mixlora", "num_experts": 8` | +| ✓ | MixLoRA-Dynamic | `"routing_strategy": "mixlora-dynamic", "num_experts": 8` | +| ✓ | MixLoRA-Switch | `"routing_strategy": "mixlora-switch", "num_experts": 8` | *: Arguments of configuration file ### Notice of PEFT supports 1. m-LoRA supports specific optimized operators for these PEFT methods, which can effectively improve the computing performance during training, evaluation and inference. However, these operators may cause a certain degree of accuracy loss (less than 5%). You can disable the optimized operators by defining the `MLORA_EVALUATE_MODE` environment variable in advance. -2. Auxiliary Loss is not currently supported for Mo-LoRA (Mixture of LoRAs) methods other than MixLoRA +2. Auxiliary Loss is not currently supported for Mo-LoRA (Mixture of LoRAs) methods other than MixLoRA. +3. You can check detailed arguments of MixLoRA in [TUDB-Labs/MixLoRA](https://github.com/TUDB-Labs/MixLoRA). ## Supported Attention Methods diff --git a/mlora/model.py b/mlora/model.py index 2f5c907a..02585c0f 100644 --- a/mlora/model.py +++ b/mlora/model.py @@ -388,7 +388,7 @@ def _prepare_inputs( ) # prepare mask - if input_args.batch_masks_ is not None and 1 in input_args.batch_masks_: + if input_args.batch_masks_ is not None: # 2d mask is passed through the layers if isinstance(input_args.batch_masks_, torch.Tensor): attention_mask = input_args.batch_masks_.to( diff --git a/mlora/modules/__init__.py b/mlora/modules/__init__.py index 4ac6ab1e..297f16ac 100644 --- a/mlora/modules/__init__.py +++ b/mlora/modules/__init__.py @@ -52,6 +52,8 @@ # MixLoRA MoEs from .lora_moes import ( + DynamicRouterLoss, + DynamicSparseMoe, LoraMoe, MixtralRouterLoss, MixtralSparseMoe, @@ -85,6 +87,8 @@ "Linear", "MixtralRouterLoss", "MixtralSparseMoe", + "DynamicRouterLoss", + "DynamicSparseMoe", "SwitchRouterLoss", "SwitchSparseMoe", "LoraMoe", diff --git a/mlora/modules/config.py b/mlora/modules/config.py index 4b32b851..931e3ca1 100644 --- a/mlora/modules/config.py +++ b/mlora/modules/config.py @@ -202,7 +202,7 @@ def export(self) -> Dict[str, any]: return config -available_routing_strategies = ["mixlora", "mixlora-switch"] +available_routing_strategies = ["mixlora", "mixlora-dynamic", "mixlora-switch"] @dataclass @@ -219,6 +219,9 @@ class MixLoraConfig(LoraConfig): act_fn_: Optional[Union[str, torch.nn.Module]] = None # mixtral config top_k_: int = None + # dynamic config + top_p_: float = None + temperature_: float = None # switch transformers config router_z_loss_coef_: float = None expert_capacity_: int = None @@ -248,6 +251,11 @@ def check(self) -> "MixLoraConfig": ) if self.routing_strategy_ == "mixlora": assert isinstance(self.top_k_, int) and self.top_k_ > 0 + elif self.routing_strategy_ == "mixlora-dynamic": + assert ( + isinstance(self.top_p_, float) and self.top_p_ > 0 and self.top_p_ <= 1 + ) + assert isinstance(self.temperature_, float) and self.temperature_ >= 0 elif self.routing_strategy_ == "mixlora-switch": assert ( isinstance(self.router_z_loss_coef_, float) @@ -280,6 +288,11 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig": lora_config.router_init_range_ = config.get("router_init_range", 0.02) lora_config.jitter_noise_ = config.get("jitter_noise", 0.0) lora_config.top_k_ = config.get("top_k", 2) + elif lora_config.routing_strategy_ == "mixlora-dynamic": + lora_config.router_init_range_ = config.get("router_init_range", 0.02) + lora_config.jitter_noise_ = config.get("jitter_noise", 0.0) + lora_config.top_p_ = config.get("top_p", 0.8) + lora_config.temperature_ = config.get("temperature", 0.0) elif lora_config.routing_strategy_ == "mixlora-switch": lora_config.router_init_range_ = config.get("router_init_range", 1.0) lora_config.jitter_noise_ = config.get("jitter_noise", 0.01) @@ -308,6 +321,9 @@ def export(self) -> Dict[str, any]: config["act_fn"] = self.act_fn_ if self.routing_strategy_ == "mixlora": config["top_k"] = self.top_k_ + elif self.routing_strategy_ == "mixlora-dynamic": + config["top_p"] = self.top_p_ + config["temperature"] = self.temperature_ elif self.routing_strategy_ == "mixlora-switch": config["expert_capacity"] = self.expert_capacity_ config["sparse_step"] = self.sparse_step_ @@ -408,6 +424,7 @@ def expert_config(self, expert_idx: int) -> LoraConfig: routing_strategy_dict = { "mixlora": MixLoraConfig, + "mixlora-dynamic": MixLoraConfig, "mixlora-switch": MixLoraConfig, "loramoe": LoraMoeConfig, "mola": MolaConfig, diff --git a/mlora/modules/lora_moes.py b/mlora/modules/lora_moes.py index 1985231c..8d950021 100644 --- a/mlora/modules/lora_moes.py +++ b/mlora/modules/lora_moes.py @@ -8,6 +8,8 @@ from .config import LoraMoeConfig, MixLoraConfig, MolaConfig from .lora_linear import Linear from .mix_lora import ( + DynamicRouterLoss, + DynamicSparseMoe, MixtralRouterLoss, MixtralSparseMoe, SwitchRouterLoss, @@ -149,7 +151,11 @@ def forward( return residual + final_hidden_states -router_loss_dict = {"mixlora": MixtralRouterLoss, "mixlora-switch": SwitchRouterLoss} +router_loss_dict = { + "mixlora": MixtralRouterLoss, + "mixlora-dynamic": DynamicRouterLoss, + "mixlora-switch": SwitchRouterLoss, +} def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module: @@ -163,6 +169,7 @@ def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module: moe_layer_dict = { "mixlora": MixtralSparseMoe, + "mixlora-dynamic": DynamicSparseMoe, "mixlora-switch": SwitchSparseMoe, "loramoe": LoraMoe, "mola": MolaSparseMoe, diff --git a/mlora/modules/mix_lora.py b/mlora/modules/mix_lora.py index 0af464eb..e5d8009c 100644 --- a/mlora/modules/mix_lora.py +++ b/mlora/modules/mix_lora.py @@ -22,16 +22,41 @@ def _slice_tensor( return last_value +def _mixlora_compatible_forward( + ffn_layer: LLMFeedForward, + moe_name: str, + act_fn: torch.nn.Module, + expert_mask: torch.Tensor, + hidden_states: torch.Tensor, + input_dtype: torch.device, +): + final_expert_states = [] + for expert_idx in range(expert_mask.shape[0]): + _, top_x = torch.where(expert_mask[expert_idx]) + lora_name = f"moe.{moe_name}.experts.{expert_idx}" + lora_data = _slice_tensor(hidden_states, top_x, input_dtype) + final_expert_states.append( + ffn_layer._lora_forward(lora_name, act_fn, lora_data) + ) + + return final_expert_states + + +def _unpack_router_logits(gate_logits: List[torch.Tensor]): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + return concatenated_gate_logits + + def _mixtral_load_balancing_loss_func( gate_logits: List[torch.Tensor], num_experts: int, top_k: int, attention_mask: Optional[torch.Tensor] = None, ) -> float: - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat( - [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 - ) + concatenated_gate_logits = _unpack_router_logits(gate_logits) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) @@ -47,9 +72,7 @@ def _mixtral_load_balancing_loss_func( router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // ( - batch_size * sequence_length - ) + num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( @@ -58,7 +81,7 @@ def _mixtral_load_balancing_loss_func( (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) ) .reshape(-1, top_k, num_experts) - .to(compute_device) + .to(routing_weights.device) ) # Compute the percentage of tokens routed to each experts @@ -71,7 +94,7 @@ def _mixtral_load_balancing_loss_func( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) - .to(compute_device) + .to(routing_weights.device) ) # Compute the average probability of routing to these experts @@ -96,26 +119,6 @@ def forward(self, gate_logits, attention_mask) -> torch.Tensor: ) -def _mixtral_compatible_forward( - ffn_layer: LLMFeedForward, - moe_name: str, - act_fn: torch.nn.Module, - expert_mask: torch.Tensor, - hidden_states: torch.Tensor, - input_dtype: torch.device, -): - final_expert_states = [] - for expert_idx in range(expert_mask.shape[0]): - _, top_x = torch.where(expert_mask[expert_idx]) - lora_name = f"moe.{moe_name}.experts.{expert_idx}" - lora_data = _slice_tensor(hidden_states, top_x, input_dtype) - final_expert_states.append( - ffn_layer._lora_forward(lora_name, act_fn, lora_data) - ) - - return final_expert_states - - class MixtralSparseMoe(LLMMoeBlock): def __init__( self, @@ -227,7 +230,217 @@ def forward( self.adapter_name_, self.act_, expert_mask, hidden_states, input_dtype ) else: - expert_states = _mixtral_compatible_forward( + expert_states = _mixlora_compatible_forward( + ffn_layer, + self.adapter_name_, + self.act_, + expert_mask, + hidden_states, + input_dtype, + ) + + # Unpack + for expert_idx in range(self.experts_): + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_hidden_states = ( + expert_states[expert_idx] * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(self.dtype_) + ) + + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ).to(input_dtype) + + return final_hidden_states, router_logits + + +def _dynamic_top_p(router_logits: torch.Tensor, top_p: float, temperature: float = 0.0): + if temperature > 0.0: + router_logits = router_logits / temperature + sorted_logits, sorted_indices = torch.sort(router_logits, dim=-1, descending=True) + cumulative_probs = sorted_logits.cumsum(dim=-1) + expert_mask = cumulative_probs > top_p + threshold_indices = expert_mask.long().argmax(dim=-1) + threshold_mask = torch.nn.functional.one_hot( + threshold_indices, num_classes=sorted_indices.size(-1) + ).bool() + expert_mask = expert_mask & ~threshold_mask + sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0) + sorted_indices = sorted_indices.masked_fill(expert_mask, -1) + return sorted_logits, sorted_indices + + +def _dynamic_load_balancing_loss_func( + routing_weights: torch.Tensor, + num_experts: int, + top_p: float, + temperature: float, +) -> float: + _, selected_experts = _dynamic_top_p(routing_weights, top_p, temperature) + + expert_mask = torch.empty( + (num_experts, num_experts, routing_weights.size(0)), + dtype=routing_weights.dtype, + device=routing_weights.device, + ) + + for expert_idx in range(num_experts): + expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1) + + expert_mask = expert_mask.permute(2, 1, 0) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class DynamicRouterLoss(torch.nn.Module): + def __init__(self, config: MixLoraConfig) -> None: + super().__init__() + self.aux_loss_coef = config.router_aux_loss_coef_ + self.experts = config.num_experts_ + self.top_p = config.top_p_ + self.temperature = config.temperature_ + + def forward(self, gate_logits, attention_mask) -> torch.Tensor: + concatenated_gate_logits = _unpack_router_logits(gate_logits) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + return self.aux_loss_coef * _dynamic_load_balancing_loss_func( + routing_weights, + self.experts, + self.top_p, + self.temperature, + ) + + +class DynamicSparseMoe(LLMMoeBlock): + def __init__( + self, + in_features: int, + device: torch.device, + config: MixLoraConfig, + gate: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + + self.adapter_name_: str = config.adapter_name + self.dtype_: torch.dtype = torch.float32 + self.gate_ = torch.nn.Linear( + in_features, + config.num_experts_, + bias=False, + device=device, + dtype=self.dtype_, + ) + self.act_ = ( + ACT2FN[config.act_fn_] + if isinstance(config.act_fn_, str) + else config.act_fn_ + ) + self.experts_: int = config.num_experts_ + self.top_p_: float = config.top_p_ + self.temperature_: float = config.temperature_ + self.jitter_noise_: float = config.jitter_noise_ + self.router_profile_: bool = False + self.profiler_: List[int] = None + + if gate is None: + torch.nn.init.normal_( + self.gate_.weight, + mean=0.0, + std=config.router_init_range_, + ) + else: + with torch.no_grad(): + self.gate_.weight.copy_(gate) + + def state_dict(self) -> Dict[str, torch.nn.Module]: + return {"gate": self.gate_.weight} + + def _profiling( + self, batch_size: int, sequence_length: int, selected_experts: torch.Tensor + ) -> None: + if not self.router_profile_: + return + + router_statistic_ = list(0 for _ in range(self.experts_)) + for selected in selected_experts.tolist(): + for idx in selected: + router_statistic_[idx] += 1 + + if self.profiler_ is None: + self.profiler_ = list(0 for _ in range(self.experts_)) + for idx in range(self.experts_): + self.profiler_[idx] = ( + router_statistic_[idx] / batch_size + ) / sequence_length + else: + for idx in range(self.experts_): + pressure = (router_statistic_[idx] / batch_size) / sequence_length + self.profiler_[idx] = (self.profiler_[idx] + pressure) / 2 + + def forward( + self, + hidden_states: torch.Tensor, + ffn_layer: LLMFeedForward, + input_args: LLMModelInput, + ) -> Tuple: + batch_size, sequence_length, hidden_dim = hidden_states.shape + + if not input_args.inference_mode_ and self.jitter_noise_ > 0: + # Multiply the token inputs by the uniform distribution - adding some noise + hidden_states *= torch.empty_like(hidden_states).uniform_( + 1.0 - self.jitter_noise_, 1.0 + self.jitter_noise_ + ) + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate_(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=self.dtype_) + routing_weights, selected_experts = _dynamic_top_p( + routing_weights, self.top_p_, self.temperature_ + ) + + self._profiling(batch_size, sequence_length, selected_experts) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=self.dtype_, + device=hidden_states.device, + ) + + expert_mask = torch.empty( + (self.experts_, self.experts_, batch_size * sequence_length), + dtype=self.dtype_, + device=hidden_states.device, + ) + + for expert_idx in range(self.experts_): + expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1) + + # Perform the computation on each expert + if input_args.efficient_operator_ and hasattr(ffn_layer, "_mixlora_forward"): + expert_states = ffn_layer._mixlora_forward( + self.adapter_name_, self.act_, expert_mask, hidden_states, input_dtype + ) + else: + expert_states = _mixlora_compatible_forward( ffn_layer, self.adapter_name_, self.act_, diff --git a/mlora/tokenizer.py b/mlora/tokenizer.py index 5a1f6585..749d1ef0 100644 --- a/mlora/tokenizer.py +++ b/mlora/tokenizer.py @@ -48,9 +48,10 @@ def encode( def decode(self, data: Tokens) -> str: return self.tokenizer.decode(data) - # get the mask from tokens + # Get the mask from tokens + # https://huggingface.co/docs/transformers/glossary#attention-mask # example: tokens is [2, 3, pad, pad, 4, 5] - # output is [0, 0, 1, 1, 0, 0] + # output is [1, 1, 0, 0, 1, 1] def mask_from(self, tokens: Tokens) -> Masks: mask_tokens = [self.pad_id_] - return [int(tok in mask_tokens) for tok in tokens] + return [int(tok not in mask_tokens) for tok in tokens] diff --git a/pyproject.toml b/pyproject.toml index 7d9df04d..caa5cf97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mlora" -version = "0.5.1" +version = "0.5.2" description = "An Efficient Factory to Build Multiple LoRA Adapters" readme = "README.md" requires-python = ">=3.8" diff --git a/templates/mixlora_dynamic.json b/templates/mixlora_dynamic.json new file mode 100644 index 00000000..fc402b31 --- /dev/null +++ b/templates/mixlora_dynamic.json @@ -0,0 +1,37 @@ +{ + "cutoff_len": 512, + "save_step": 1000, + "train_lora_candidate_num": 2, + "train_lora_simultaneously_num": 2, + "train_strategy": "optim", + "lora": [ + { + "name": "mixlora", + "task_name": "", + "optim": "adamw", + "scheduler_type": "constant", + "warmup_steps": 0, + "lr": 2e-4, + "batch_size": 16, + "micro_batch_size": 8, + "evaluate_batch_size": 16, + "num_epochs": 2, + "r": 16, + "lora_alpha": 32, + "lora_dropout": 0.05, + "target_modules": { + "q_proj": true, + "k_proj": true, + "v_proj": true, + "o_proj": true, + "gate_proj": true, + "down_proj": true, + "up_proj": true + }, + "routing_strategy": "mixlora-dynamic", + "num_experts": 8, + "top_p": 0.8, + "group_by_length": false + } + ] +} \ No newline at end of file diff --git a/templates/mixlora_dynamic_glm.json b/templates/mixlora_dynamic_glm.json new file mode 100644 index 00000000..88ea4549 --- /dev/null +++ b/templates/mixlora_dynamic_glm.json @@ -0,0 +1,34 @@ +{ + "cutoff_len": 512, + "save_step": 1000, + "train_lora_candidate_num": 2, + "train_lora_simultaneously_num": 2, + "train_strategy": "optim", + "lora": [ + { + "name": "mixlora", + "task_name": "", + "optim": "adamw", + "scheduler_type": "constant", + "warmup_steps": 0, + "lr": 2e-4, + "batch_size": 16, + "micro_batch_size": 8, + "evaluate_batch_size": 16, + "num_epochs": 2, + "r": 14, + "lora_alpha": 28, + "lora_dropout": 0.05, + "target_modules": { + "qkv_proj": true, + "dense": true, + "dense_h_to_4h": true, + "dense_4h_to_h": true + }, + "routing_strategy": "mixlora-dynamic", + "num_experts": 8, + "top_p": 0.8, + "group_by_length": false + } + ] +} \ No newline at end of file diff --git a/templates/mixlora_dynamic_phi.json b/templates/mixlora_dynamic_phi.json new file mode 100644 index 00000000..0d9791f2 --- /dev/null +++ b/templates/mixlora_dynamic_phi.json @@ -0,0 +1,36 @@ +{ + "cutoff_len": 512, + "save_step": 1000, + "train_lora_candidate_num": 2, + "train_lora_simultaneously_num": 2, + "train_strategy": "optim", + "lora": [ + { + "name": "mixlora", + "task_name": "", + "optim": "adamw", + "scheduler_type": "constant", + "warmup_steps": 0, + "lr": 2e-4, + "batch_size": 16, + "micro_batch_size": 8, + "evaluate_batch_size": 16, + "num_epochs": 2, + "r": 16, + "lora_alpha": 32, + "lora_dropout": 0.05, + "target_modules": { + "q_proj": true, + "k_proj": true, + "v_proj": true, + "dense": true, + "fc1": true, + "fc2": true + }, + "routing_strategy": "mixlora-dynamic", + "num_experts": 8, + "top_p": 0.8, + "group_by_length": false + } + ] +} \ No newline at end of file diff --git a/templates/mixlora_dynamic_phi3.json b/templates/mixlora_dynamic_phi3.json new file mode 100644 index 00000000..5b1edce6 --- /dev/null +++ b/templates/mixlora_dynamic_phi3.json @@ -0,0 +1,34 @@ +{ + "cutoff_len": 512, + "save_step": 1000, + "train_lora_candidate_num": 2, + "train_lora_simultaneously_num": 2, + "train_strategy": "optim", + "lora": [ + { + "name": "mixlora", + "task_name": "", + "optim": "adamw", + "scheduler_type": "constant", + "warmup_steps": 0, + "lr": 2e-4, + "batch_size": 16, + "micro_batch_size": 8, + "evaluate_batch_size": 16, + "num_epochs": 2, + "r": 16, + "lora_alpha": 32, + "lora_dropout": 0.05, + "target_modules": { + "qkv_proj": true, + "o_proj": true, + "gate_up_proj": true, + "down_proj": true + }, + "routing_strategy": "mixlora-dynamic", + "num_experts": 8, + "top_p": 0.8, + "group_by_length": false + } + ] +} \ No newline at end of file