Adding Qwen2.5 (#1834)
Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com>
ysjprojects and Andrei-Aksionov authored Nov 27, 2024
1 parent 22528bf commit ff8b1b6
Showing 13 changed files with 591 additions and 5 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -134,6 +134,8 @@ Every model is written from scratch to maximize performance and remove layers of
| Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui et al. 2024](https://arxiv.org/abs/2409.12186) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
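For reference, the newly listed Qwen2.5 checkpoints can be exercised through LitGPT's high-level Python API. A minimal sketch, assuming the installed litgpt version exposes `LLM.load` and `LLM.generate` and that the chosen Hugging Face repo id is accessible:

```python
# Minimal sketch (assumption: the installed litgpt exposes the high-level LLM API).
# LLM.load fetches/loads the checkpoint by its Hugging Face repo id.
from litgpt import LLM

llm = LLM.load("Qwen/Qwen2.5-0.5B-Instruct")
print(llm.generate("What is the capital of France?", max_new_tokens=32))
```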
2 changes: 1 addition & 1 deletion litgpt/adapter_v2.py
@@ -163,7 +163,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
- self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
+ self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
# output projection
# if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
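The `bias=config.bias or config.attn_bias` change lets a model disable bias everywhere except the fused QKV projection, which matches the Qwen2.5 configs added below (`bias=False`, `attn_bias=True`). A minimal sketch of the resulting layer layout with plain `nn.Linear`, using illustrative sizes taken from the 0.5B config:

```python
# Sketch of the bias override: with bias=False and attn_bias=True (Qwen2.5 style),
# only the fused QKV projection gets a bias term; the output projection stays bias-free.
import torch.nn as nn

bias, attn_bias = False, True                                # Qwen2.5-style flags
n_embd, n_head, n_query_groups, head_size = 896, 14, 2, 64   # 0.5B layout
qkv_shape = (n_head + 2 * n_query_groups) * head_size

attn = nn.Linear(n_embd, qkv_shape, bias=bias or attn_bias)  # biased QKV projection
proj = nn.Linear(head_size * n_head, n_embd, bias=bias)      # unbiased output projection

assert attn.bias is not None and proj.bias is None
```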
292 changes: 292 additions & 0 deletions litgpt/config.py
@@ -31,6 +31,7 @@ class Config:
parallel_residual: bool = True
bias: bool = True
lm_head_bias: bool = False
attn_bias: bool = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
@@ -1704,4 +1705,295 @@ def norm_class(self) -> Type:

configs.extend(llama_2_function_calling)

##########
# Qwen2.5
##########
qwen_2_5 = [
# https://huggingface.co/Qwen/Qwen2.5-0.5B/blob/main/config.json
dict(
name="Qwen2.5-0.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-0.5B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=24,
n_head=14,
n_embd=896,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=4864,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-1.5B/blob/main/config.json
dict(
name="Qwen2.5-1.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-1.5B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=12,
n_embd=1536,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8960,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/config.json
dict(
name="Qwen2.5-3B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-3B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=16,
n_embd=2048,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json
dict(
name="Qwen2.5-7B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-7B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=28,
n_head=28,
n_embd=3584,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=18944,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-14B/blob/main/config.json
dict(
name="Qwen2.5-14B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-14B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=48,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
norm_eps=1e-5,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-32B/blob/main/config.json
dict(
name="Qwen2.5-32B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-32B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=64,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=27648,
norm_eps=1e-5,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-72B/blob/main/config.json
dict(
name="Qwen2.5-72B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-72B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=29568,
norm_eps=1e-5,
rope_base=1000000
),
]

qwen_2_5_coder = [
# https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B/blob/main/config.json
dict(
name="Qwen2.5-Coder-0.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-0.5B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=24,
n_head=14,
n_embd=896,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=4864,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B/blob/main/config.json
dict(
name="Qwen2.5-Coder-1.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-1.5B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=12,
n_embd=1536,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8960,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-3B/blob/main/config.json
dict(
name="Qwen2.5-Coder-3B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-3B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=16,
n_embd=2048,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-7B/blob/main/config.json
dict(
name="Qwen2.5-Coder-7B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-7B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=28,
n_head=28,
n_embd=3584,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=18944,
norm_eps=1e-6,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-14B/blob/main/config.json
dict(
name="Qwen2.5-Coder-14B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-14B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=48,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
norm_eps=1e-5,
rope_base=1000000
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-32B/blob/main/config.json
dict(
name="Qwen2.5-Coder-32B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-32B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=64,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=27648,
norm_eps=1e-5,
rope_base=1000000
),
]

qwen_2_5.extend(qwen_2_5_coder)

for c in qwen_2_5:
for kind in ("", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)

name_to_config = {config["name"]: config for config in configs}
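Since every base entry's `name` (and `hf_config["name"]`) contains a `{}` placeholder, the loop above registers both the base and `-Instruct` variants in `name_to_config`. A quick usage sketch, assuming `Config.from_name` in the installed litgpt version resolves names through this mapping:

```python
# Sketch: the expansion loop yields e.g. "Qwen2.5-7B" and "Qwen2.5-7B-Instruct".
from litgpt import Config

config = Config.from_name("Qwen2.5-7B-Instruct")
print(config.n_layer, config.n_head, config.n_query_groups)  # 28 28 4, per the dict above
print(config.bias, config.attn_bias)                         # False True
```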
2 changes: 1 addition & 1 deletion litgpt/lora.py
@@ -609,7 +609,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
enable_lora=(config.lora_query, config.lora_key, config.lora_value),
- bias=config.bias,
+ bias=config.bias or config.attn_bias,
# for MQA/GQA support
head_size=config.head_size,
n_head=config.n_head,
2 changes: 1 addition & 1 deletion litgpt/model.py
@@ -244,7 +244,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
super().__init__()
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
- self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.bias or config.attn_bias)
# output projection
# if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
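Combined with the config change, building an attention block from a Qwen2.5 config yields a biased QKV projection and an unbiased output projection. A small check, assuming `CausalSelfAttention` in `litgpt/model.py` keeps the `attn`/`proj` attribute names and constructor signature shown in this diff:

```python
# Sketch: verify the bias split on a Qwen2.5 config.
from litgpt import Config
from litgpt.model import CausalSelfAttention

config = Config.from_name("Qwen2.5-0.5B")
attn_block = CausalSelfAttention(config, block_idx=0)

assert attn_block.attn.bias is not None  # QKV projection: biased via attn_bias=True
assert attn_block.proj.bias is None      # output projection: follows bias=False
```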
9 changes: 9 additions & 0 deletions litgpt/prompts.py
@@ -279,6 +279,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
class OLMo(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n"


class Qwen2_5(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


# Maps prompt style names to PromptStyle classes
@@ -304,6 +310,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"gemma": Gemma,
"llama3": Llama3,
"olmo": OLMo,
"qwen2.5": Qwen2_5,
}


@@ -342,6 +349,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Gemma()
if re.search(r"OLMo.*-hf", model_name):
return OLMo()
if re.search(r"Qwen2\.5-.*", model_name):
return Qwen2_5()
return Default()


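With the new prompt style and regex in place, Qwen2.5 model names resolve to the ChatML-style template defined above. For illustration:

```python
# Sketch: resolve the prompt style for a Qwen2.5 checkpoint and render a prompt.
from litgpt.prompts import Qwen2_5, model_name_to_prompt_style

style = model_name_to_prompt_style("Qwen2.5-7B-Instruct")
assert isinstance(style, Qwen2_5)

print(style.apply("What is the capital of France?"))
# <|im_start|>system
# You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
# <|im_start|>user
# What is the capital of France?<|im_end|>
# <|im_start|>assistant
```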