
May 2024 Prelim #447

Merged: 76 commits, May 12, 2024
Commits
64a7d27
Fix prompt
danielhanchen Apr 19, 2024
495f1da
Merge branch 'main' into nightly
danielhanchen Apr 20, 2024
656ab22
Update chat_templates.py
danielhanchen Apr 20, 2024
c4f2f54
fix_untrained_tokens
danielhanchen Apr 20, 2024
87b4bb9
Update llama.py
danielhanchen Apr 21, 2024
abd192f
add tokens
danielhanchen Apr 21, 2024
868351b
Update _utils.py
danielhanchen Apr 21, 2024
f29a3e7
Update tokenizer_utils.py
danielhanchen Apr 21, 2024
2573474
Update llama.py
danielhanchen Apr 21, 2024
bfb32a3
Update llama.py
danielhanchen Apr 21, 2024
40a6d00
Update llama.py
danielhanchen Apr 21, 2024
140a0b0
Update llama.py
danielhanchen Apr 21, 2024
88435a8
pad_token
danielhanchen Apr 21, 2024
24790e2
Update chat_templates.py
danielhanchen Apr 21, 2024
1464f7d
Update chat_templates.py
danielhanchen Apr 21, 2024
df069c5
tokenizer
danielhanchen Apr 21, 2024
eb00fb7
Update save.py
danielhanchen Apr 21, 2024
805f890
Update chat_templates.py
danielhanchen Apr 21, 2024
80be6ff
Update chat_templates.py
danielhanchen Apr 21, 2024
92723ba
Merge branch 'main' into nightly
danielhanchen Apr 22, 2024
2e62a69
patch tokenizer padding
danielhanchen Apr 22, 2024
b0678d6
Update tokenizer_utils.py
danielhanchen Apr 22, 2024
f85ef9c
Update save.py
danielhanchen Apr 23, 2024
d2f10a0
Fix: loading models with resized vocabulary (#377)
oKatanaaa Apr 24, 2024
f5fa654
GGUF fix
danielhanchen Apr 28, 2024
8325e05
Readme (#390)
danielhanchen Apr 28, 2024
13b1ae6
Update README.md
danielhanchen Apr 28, 2024
5069a7d
Delete .gitignore
danielhanchen Apr 28, 2024
1ba3379
Merge branch 'main' into nightly
danielhanchen Apr 29, 2024
7c9c3f5
Phi-3
danielhanchen Apr 29, 2024
7b696ee
Update README.md
danielhanchen Apr 29, 2024
48334f7
Update README.md
danielhanchen Apr 29, 2024
3665c0b
Update README.md
danielhanchen Apr 29, 2024
0f9e073
Update README.md
danielhanchen Apr 29, 2024
eb135d8
Update README.md
danielhanchen Apr 29, 2024
56e2674
Update README.md
danielhanchen Apr 29, 2024
b091a0b
Update README.md
danielhanchen Apr 29, 2024
18533ab
Update README.md
danielhanchen Apr 29, 2024
3e84338
Update README.md
danielhanchen Apr 29, 2024
d8feef5
Update README.md
danielhanchen Apr 29, 2024
392c034
Update README.md
danielhanchen Apr 29, 2024
df6fb52
Update README.md
danielhanchen Apr 29, 2024
99ed47a
Update README.md
danielhanchen Apr 29, 2024
7fae556
Update README.md
danielhanchen Apr 29, 2024
000d050
Update README.md
danielhanchen Apr 29, 2024
27f88f0
Update README.md
danielhanchen Apr 29, 2024
affbba1
Update README.md
danielhanchen Apr 29, 2024
14f104a
Update README.md
danielhanchen Apr 29, 2024
e040d18
Fix reserved tokens
danielhanchen May 4, 2024
fb10081
Merge branch 'main' into nightly
danielhanchen May 4, 2024
f53944a
Update save.py
danielhanchen May 4, 2024
70b41d1
Update tokenizer_utils.py
danielhanchen May 4, 2024
1b1b931
Update tokenizer_utils.py
danielhanchen May 4, 2024
61edc3c
Update tokenizer_utils.py
danielhanchen May 4, 2024
73df3ee
Update tokenizer_utils.py
danielhanchen May 4, 2024
15d7898
Update tokenizer_utils.py
danielhanchen May 4, 2024
84418a9
Merge branch 'main' into nightly
danielhanchen May 5, 2024
76ed0a4
Update chat_templates.py
danielhanchen May 6, 2024
dfec8dd
Update save.py
danielhanchen May 7, 2024
73af5d1
Update _utils.py
danielhanchen May 7, 2024
9c7d9a7
Update chat_templates.py
danielhanchen May 7, 2024
f1f8db3
Merge branch 'main' into nightly
danielhanchen May 8, 2024
7c53652
Adds dependencies and extras for torch 2.3.0 with new xformers versio…
nathan-az May 10, 2024
cf83fe3
Support Qwen2 (#428)
yangjianxin1 May 10, 2024
f7dab30
Update save.py
danielhanchen May 10, 2024
10e01f4
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen May 10, 2024
6c9fcc9
Update save.py
danielhanchen May 10, 2024
f1350ca
Update _utils.py
danielhanchen May 10, 2024
73b941d
Update save.py
danielhanchen May 11, 2024
7d502d7
Update save.py
danielhanchen May 11, 2024
d1d47b3
Update save.py
danielhanchen May 11, 2024
f16d7d7
test_hf_gguf_equivalence
danielhanchen May 11, 2024
01284f4
Update chat_templates.py
danielhanchen May 12, 2024
4f1e6fb
Update chat_templates.py
danielhanchen May 12, 2024
36cfcf4
--pad-vocab
danielhanchen May 12, 2024
6b2ee16
Update tokenizer_utils.py
danielhanchen May 12, 2024
10 changes: 9 additions & 1 deletion README.md
@@ -36,6 +36,7 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text.

## 🦥 Unsloth.ai News
- 📣 NEW! Qwen1.5-7B, Qwen1.5-14B, Qwen1.5-32B, Qwen1.5-72B now work, courtesy of Firefly's PR [#428](https://github.com/unslothai/unsloth/pull/428)
- 📣 NEW! [Llama-3 8b](https://colab.research.google.com/drive/135ced7oHytdxu3N2DNe1Z0kqjyYIkDXp?usp=sharing) now works! Llama-3 70b also works (change the model name in the notebook).
- 📣 NEW! [ORPO support](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) is here!
- 📣 NEW! [Phi-3 3.8b support](https://colab.research.google.com/drive/1NvkBmkHfucGO3Ve9s1NKZvMNlw5p83ym?usp=sharing) is here!
@@ -159,7 +160,14 @@ pip install --no-deps packaging ninja einops flash-attn xformers trl peft accele
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps xformers trl peft accelerate bitsandbytes
```
7. To troubleshoot installs try the below (all must succeed). Xformers should mostly all be available.
7. For PyTorch 2.3.0, use the `"ampere"` path for newer RTX 30xx GPUs or higher.
```bash
pip install "unsloth[cu118-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
```
8. To troubleshoot installs, try the below (all must succeed). Xformers should mostly all be available.
```bash
nvcc
python -m xformers.info
37 changes: 37 additions & 0 deletions pyproject.toml
@@ -86,6 +86,17 @@ cu121onlytorch220 = [
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu118onlytorch230 = [
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.26.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.26.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.26.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
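# Note: %2B in the cu118 wheel URLs above is the URL-encoded "+" of "0.0.26.post1+cu118".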
cu121onlytorch230 = [
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.26.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.26.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.26.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]

cu118 = [
"unsloth[huggingface]",
"bitsandbytes",
@@ -126,6 +137,16 @@ cu121-torch220 = [
"bitsandbytes",
"unsloth[cu121onlytorch220]",
]
cu118-torch230 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118onlytorch230]",
]
cu121-torch230 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121onlytorch230]",
]
kaggle = [
"unsloth[huggingface]",
]
@@ -238,6 +259,22 @@ cu121-ampere-torch220 = [
"ninja",
"flash-attn",
]
cu118-ampere-torch230 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118onlytorch230]",
"packaging",
"ninja",
"flash-attn",
]
cu121-ampere-torch230 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121onlytorch230]",
"packaging",
"ninja",
"flash-attn",
]

[project.urls]
homepage = "http://www.unsloth.ai"
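The new torch 2.3.0 extras above mirror the install commands added to the README. As a rough, hypothetical sketch (not part of this PR), the snippet below prints a matching pip command from the local torch build; it assumes a CUDA build of torch and that the extra names follow the `cu11x[-ampere][-torch230]` pattern defined in pyproject.toml.

```python
# Hypothetical helper: print the pip command for the extra matching this machine.
# Assumes a CUDA build of torch; the extra-name pattern is assumed from pyproject.toml.
import torch

cuda = "cu121" if torch.version.cuda.startswith("12") else "cu118"
major, _ = torch.cuda.get_device_capability()
ampere = "-ampere" if major >= 8 else ""                      # RTX 30xx / A100 or newer
torch_tag = "-torch230" if torch.__version__.startswith("2.3") else ""
print(f'pip install "unsloth[{cuda}{ampere}{torch_tag}] @ git+https://github.com/unslothai/unsloth.git"')
```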
88 changes: 80 additions & 8 deletions unsloth/chat_templates.py
@@ -15,6 +15,7 @@
__all__ = [
"get_chat_template",
"test_chat_templates",
"test_hf_gguf_equivalence",
]

from transformers import StoppingCriteria, StoppingCriteriaList
@@ -270,12 +271,11 @@
phi3_template = \
"{{ bos_token }}"\
"{% for message in messages %}"\
"{% if (message['role'] == 'user') %}"\
"{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"\
"{% elif (message['role'] == 'assistant') %}"\
"{{message['content'] + '<|end|>' + '\n'}}"\
"{% endif %}"\
"{% endfor %}"
"{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|assistant|>\n' }}"\
"{% endif %}"
phi3_template_eos_token = "<|end|>"
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token,)

@@ -613,8 +613,80 @@ def test_chat_templates():
# Phi-3
template = phi3_template
correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
pass


def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
"""
Carefully checks the output of GGUF's tokenization and HF.
Can catch all tokenization bugs.
"""
import subprocess
import re
messages = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "It's 4."},
{"role": "user", "content": " But 2+2 is equal to 5. "},
{"role": "assistant", "content": "No I'm sure its 4."},
{"role": "user", "content": " No it's 100% 5! "},
]

prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}""".format(
"Describe the city given eloquently.", # instruction
"The lost city of Atlantis.", # input
"", # output - leave this blank for generation!
)
prompts = [ prompt, ]

if tokenizer.chat_template is not None:
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
prompt = prompt.replace("'", "") # Subprocess does not like ''
prompts.append(prompt)
pass

for prompt in prompts:
command = f"./llama.cpp/main -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
f"--check-tensors -p '{prompt}'"

datas = []
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
datas.append(line.decode("utf-8", errors = "replace"))
pass
gguf_tokens = "".join(datas)

# Now extract GGUF tokenization attempt
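# --verbose-prompt prints each prompt token roughly as "   450 -> ' What'"; capture (id, piece) pairs.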
gguf_tokenized = re.findall("([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
input_ids = tokenizer(prompt).input_ids
tokens = tokenizer.batch_decode(input_ids)
hf_tokenized = list(zip(input_ids, tokens))
print(gguf_tokenized[:5])

# Compare to Huggingface
for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
if (hf_token[0] != gguf_token[0]):
print("Failed GGUF != HF at", j)
print("HF =", hf_token)
print("GGUF =", gguf_token)
print(hf_tokenized[:j+1])
print(gguf_tokenized[:j+1])
print(gguf_tokens)
raise RuntimeError("Failed comparing GGUF to HF.")
pass
pass
return True
pass
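A hedged usage sketch for the new check (not part of the diff): it assumes save_pretrained_gguf already produced ./model-unsloth.F16.gguf for the same model as the tokenizer, and that llama.cpp's main binary has been built at ./llama.cpp/main, as the command string above expects.

```python
# Sketch: verify llama.cpp (GGUF) and Hugging Face tokenize the same prompts identically.
from transformers import AutoTokenizer
from unsloth.chat_templates import test_hf_gguf_equivalence

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")  # must match the exported model
assert test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
```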
7 changes: 4 additions & 3 deletions unsloth/models/__init__.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .loader import FastLanguageModel
from .llama import FastLlamaModel
from .loader import FastLanguageModel
from .llama import FastLlamaModel
from .mistral import FastMistralModel
from .dpo import PatchDPOTrainer
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
@@ -30,7 +30,7 @@
import os
import psutil

__version__ = "2024.4"
__version__ = "2024.5"

# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
major_version, minor_version = torch.cuda.get_device_capability()
1 change: 1 addition & 0 deletions unsloth/models/llama.py
@@ -1605,6 +1605,7 @@ def patch_peft_model(

if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
3 changes: 3 additions & 0 deletions unsloth/models/loader.py
@@ -14,6 +14,7 @@

from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
@@ -119,6 +120,8 @@ def from_pretrained(
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemmaModel
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
else:
raise NotImplementedError(
f"Unsloth: {model_name} not supported yet!\n"\
2 changes: 1 addition & 1 deletion unsloth/models/mistral.py
@@ -343,7 +343,7 @@ def from_pretrained(
# Mistral does NOT support RoPE Scaling sadly so we have to error out.
if max_seq_length > model_max_seq_length:
raise RuntimeError(
"Unsloth: Unfortunately Mistral type models do not support RoPE scaling!\n"\
f"Unsloth: Unfortunately {model_patcher.__name__[4:-5]} type models do not support RoPE scaling!\n"\
f"The maximum sequence length supported is {model_max_seq_length}.",
)
pass
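A quick illustration of the slicing used in the reworded error above: `model_patcher.__name__[4:-5]` strips the "Fast" prefix and "Model" suffix to recover the architecture name.

```python
# Name slicing used by the generic RoPE-scaling error message.
assert "FastMistralModel"[4:-5] == "Mistral"
assert "FastQwen2Model"[4:-5]   == "Qwen2"
```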
91 changes: 91 additions & 0 deletions unsloth/models/qwen2.py
@@ -0,0 +1,91 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import *
from .mistral import FastMistralModel
import os
from ._utils import __version__

from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2Model,
Qwen2ForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2SdpaAttention,
Qwen2FlashAttention2,
)
except:
Qwen2SdpaAttention = Qwen2Attention
Qwen2FlashAttention2 = Qwen2Attention
pass


class FastQwen2Model(FastLlamaModel):

@staticmethod
def pre_patch():
Qwen2Attention .forward = LlamaAttention_fast_forward
Qwen2SdpaAttention .forward = LlamaAttention_fast_forward
Qwen2FlashAttention2.forward = LlamaAttention_fast_forward
Qwen2DecoderLayer .forward = LlamaDecoderLayer_fast_forward
Qwen2Model .forward = LlamaModel_fast_forward
Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward

# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.qwen2.modeling_qwen2
transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = LlamaRotaryEmbedding
return
pass


@staticmethod
def from_pretrained(
model_name = "Qwen/Qwen1.5-7B",
max_seq_length = 4096,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None, # Qwen2 does not support RoPE scaling
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastMistralModel.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastQwen2Model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)
pass
pass
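As a usage sketch (not part of the diff), loading any Qwen1.5 checkpoint now routes through FastQwen2Model via the loader.py dispatch above; this assumes the Hub weights are accessible:

```python
# Minimal sketch: model_type == "qwen2" dispatches to FastQwen2Model.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "Qwen/Qwen1.5-7B",   # any Qwen1.5 size from the README news entry
    max_seq_length = 4096,
    load_in_4bit   = True,
)
```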