Skip to content

Commit

Permalink
ENH: Add default target layers for gemma2 architecture (#2078)
Browse files Browse the repository at this point in the history
Google's gemma 2 models have a slightly different architecture than
gemma 1 and thus a different model_type attribute. This PR adds default
target layers for gemma 2 that correspond to the default target layers of
gemma 1.

For LayerNorm tuning, gemma 2 has two more LN layers than gemma 1: pre_feedforward_layernorm and post_feedforward_layernorm.
  • Loading branch information
BenjaminBossan authored Sep 23, 2024
1 parent af275d2 commit 5efeba1
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/peft/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"mistral": ["input_layernorm", "post_attention_layernorm", "norm"],
"phi": ["input_layernorm", "final_layernorm"],
"gemma": ["input_layernorm", "post_attention_layernorm", "norm"],
"gemma2": [
"input_layernorm",
"post_attention_layernorm",
"pre_feedforward_layernorm",
"post_feedforward_layernorm",
"norm",
],
"qwen2": ["post_attention_layernorm"],
}

Expand Down Expand Up @@ -107,6 +114,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"stablelm": ["q_proj", "v_proj"],
"phi": ["q_proj", "v_proj", "fc1", "fc2"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand All @@ -133,6 +141,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"falcon": ["query_key_value", "dense_4h_to_h"],
"phi": ["q_proj", "v_proj", "fc2"],
"gemma": ["q_proj", "v_proj", "down_proj"],
"gemma2": ["q_proj", "v_proj", "down_proj"],
"qwen2": ["q_proj", "v_proj", "down_proj"],
}

Expand All @@ -159,6 +168,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"falcon": ["dense_4h_to_h"],
"phi": ["fc2"],
"gemma": ["down_proj"],
"gemma2": ["down_proj"],
"qwen2": ["down_proj"],
}

Expand Down Expand Up @@ -216,6 +226,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"stablelm": ["q_proj", "v_proj"],
"phi": ["q_proj", "v_proj"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand Down Expand Up @@ -250,6 +261,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"stablelm": ["q_proj", "v_proj"],
"phi": ["q_proj", "v_proj", "fc1", "fc2"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand Down

0 comments on commit 5efeba1

Please sign in to comment.