Does the Hugging Face version require flash attention to be installed? #70
Comments
You can use our open-source Docker image, which comes with flash_attn installed.
flash_attn does not support V100 GPUs. I manually disabled flash attention and the model now runs, but so far I cannot reproduce the output of the Megatron version.
Model: Yuan 2.0 2B hf
Output:
We ran the following test and the output was normal.

Input: Write a Python function that takes a string as an argument and returns the reversed string. Example: string_reverse('hello') olleh. The code is as follows:

Output: <sep>
```python
def string_reverse(string):
    return string[::-1]
```
<eod>
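The generated function is easy to check on its own; `string[::-1]` is Python's extended-slicing idiom for reversing a sequence:

```python
def string_reverse(string):
    # Extended slicing with a step of -1 walks the string backwards.
    return string[::-1]

print(string_reverse('hello'))  # → olleh
```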
Model loading code:

```python
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

print("Creating tokenizer...")
tokenizer = LlamaTokenizer.from_pretrained(yuan_path)
tokenizer.add_tokens(['<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>', '<commit_before>', '<commit_msg>', '<commit_after>', '<jupyter_start>', '<jupyter_text>', '<jupyter_code>', '<jupyter_output>', '<empty_output>'], special_tokens=True)
print("Creating model...")
model = AutoModelForCausalLM.from_pretrained(yuan_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to('cuda:1')
```

Inference code:

````python
question = """编写一个 Python 函数,它接受一个字符串作为参数,并返回该字符串的反转版本。
示例:
>>> string_reverse('hello')
olleh
代码如下:<sep>
```python
"""
inputs = tokenizer(question, return_tensors="pt")["input_ids"].to("cuda:1")
outputs = model.generate(inputs, do_sample=False, max_length=200)
print(tokenizer.decode(outputs[0]))
````

Output: the model did not implement the string_reverse function; it only wrote some test cases.
Please try the following input:

Also, please make sure to use greedy decoding when generating code; you can set temperature=1, top_k=1.
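As a sanity check on that advice, a small self-contained sketch (plain NumPy, no model involved) shows why top_k=1 is equivalent to greedy decoding: restricting sampling to the single highest-logit token collapses the distribution to the argmax.

```python
import numpy as np

def sample_top_k(logits, k, rng):
    """Sample a token id from the top-k renormalized softmax distribution."""
    top = np.argsort(logits)[-k:]                  # indices of the k largest logits
    probs = np.exp(logits[top] - logits[top].max())  # stable softmax over the top-k
    probs /= probs.sum()
    return int(top[rng.choice(len(top), p=probs)])

logits = np.array([0.1, 2.5, -1.0, 0.7])
rng = np.random.default_rng(0)

# With k=1 only the argmax token survives, so sampling is deterministic.
assert all(sample_top_k(logits, 1, rng) == int(np.argmax(logits)) for _ in range(10))
```

In `transformers`, passing `do_sample=False` to `model.generate()` selects greedy search directly, without going through the sampling path at all.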
@Armod-I has this issue been resolved?
How did you manually disable flash_attn? I want to run this model on CPU, and I tried following the README at https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/main/README.md, but I still get the following error:

```
ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn. Run `pip install flash_attn`
```
Other models can run without flash attention.
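One possible workaround (an assumption on my part, not an officially documented approach for Yuan2): the `trust_remote_code` loader in `transformers` verifies that every package imported by the remote modeling file is importable, which is what raises this ImportError. Registering a stub module under the name `flash_attn` before calling `from_pretrained` can get past that check. A minimal sketch of the stub trick:

```python
import sys
import types

# Hypothetical workaround: register an empty stub module so that
# `import flash_attn` succeeds even though the real package is not
# installed (e.g. on CPU or a V100 GPU).
stub = types.ModuleType("flash_attn")
sys.modules["flash_attn"] = stub

import flash_attn  # resolves to the stub, no ImportError
print(flash_attn is stub)  # → True
```

Note the stub only silences the import check: if the model actually executes its flash-attention code path it will still fail, so flash attention must also be switched off in the model's own configuration or modeling code.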