📢 vanilla-llama
vanilla-llama is a plain PyTorch implementation of LLaMA with minimal differences from Facebook's original implementation. You can run vanilla-llama on 1, 2, 4, 8 or 100 GPUs.
Couldn't be easier to use 🔥
Comes with an inference server 🔋
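from inference import LLaMAInference
# Path to the converted weights (see the conversion step below)
llama_path = "<CONVERTED-WEIGHTS-PATH>"
llama = LLaMAInference(llama_path, "65B")
print(llama.generate(["My name is Federico"]))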
- Easy to use and fine-tune 🔥
- Uses 🤗 accelerate to distribute the model on all available GPUs
- Comes with batteries included 🔋
- Nice one-line loading and generation 😎
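🤗 accelerate takes care of placing the model on the available GPUs. As a rough mental model only, the hypothetical helper below splits a stack of decoder layers into contiguous, near-equal chunks, one chunk per GPU. This is a simplified sketch, not accelerate's actual algorithm (which plans placement from per-device memory budgets), and `split_layers` is an illustrative name that is not part of vanilla-llama.

```python
from typing import Dict


def split_layers(num_layers: int, num_gpus: int) -> Dict[int, int]:
    """Map each decoder-layer index to a GPU index in contiguous, near-equal chunks."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    base, extra = divmod(num_layers, num_gpus)
    device_map: Dict[int, int] = {}
    layer = 0
    for gpu in range(num_gpus):
        # The first `extra` GPUs each take one additional layer.
        count = base + (1 if gpu < extra else 0)
        for _ in range(count):
            device_map[layer] = gpu
            layer += 1
    return device_map


# Hypothetical usage: 80 decoder layers (LLaMA-65B) across 8 GPUs.
print(split_layers(80, 8)[79])  # -> 7
```

With 80 layers on 8 GPUs each device holds 10 consecutive layers, so activations only cross a GPU boundary 7 times per forward pass.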
Stop generation on specific tokens (13 is the newline token)
llama.generate(["Chat:\nHuman: Hi, I am a human\nAI:"], stop_ids=[13])
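One plausible way `stop_ids` can work is a decoding loop that checks each newly sampled token against the stop set and exits early. The sketch below assumes a hypothetical `next_token` callable standing in for the model's forward pass plus sampling, and it drops the stop token from the output; whether vanilla-llama keeps or drops it is an assumption here, not something stated above.

```python
from typing import Callable, List, Sequence


def generate_until_stop(
    next_token: Callable[[List[int]], int],
    prompt_ids: Sequence[int],
    stop_ids: Sequence[int],
    max_new_tokens: int = 64,
) -> List[int]:
    """Append tokens produced by `next_token` until a stop id appears or the budget runs out."""
    context = list(prompt_ids)
    generated: List[int] = []
    stop = set(stop_ids)
    for _ in range(max_new_tokens):
        tok = next_token(context)  # hypothetical stand-in for model forward + sampling
        if tok in stop:
            break  # assumption: the stop token itself is not returned
        context.append(tok)
        generated.append(tok)
    return generated


# Toy "model" that emits a fixed sequence; 13 plays the role of the newline stop token.
scripted = iter([5, 6, 13, 7])
print(generate_until_stop(lambda ctx: next(scripted), [1, 2], stop_ids=[13]))  # -> [5, 6]
```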
Stop generation on specific strings
llama.generate(["Question: is the sky blue?\nAnswer:"], stop_words=["Question"])
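Stop words differ from stop ids because a word such as `Question` may be split into several tokens, so a natural approach is to decode the generated text and search it for the stop strings. The helper below is a hypothetical sketch of that check (the name `truncate_at_stop_words` is ours, not vanilla-llama's): it cuts the text at the earliest match among all stop words and reports whether one was found.

```python
from typing import Optional, Sequence, Tuple


def truncate_at_stop_words(text: str, stop_words: Sequence[str]) -> Tuple[str, bool]:
    """Cut `text` at the earliest occurrence of any stop word.

    Returns the (possibly truncated) text and whether a stop word was found.
    """
    cut: Optional[int] = None
    for word in stop_words:
        idx = text.find(word)
        if idx != -1 and (cut is None or idx < cut):
            cut = idx
    if cut is None:
        return text, False
    return text[:cut], True


text, stopped = truncate_at_stop_words(" Yes.\nQuestion: is grass green?", ["Question"])
print(repr(text), stopped)  # -> ' Yes.\n' True
```

In a generation loop this check would run on the decoded continuation after each new token, ending generation as soon as `stopped` is `True`.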
Batch generation
llama.generate(["My name is Federico", "My name is Zuck"])
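Prompts in a batch usually tokenize to different lengths. Facebook's original implementation handles this by right-padding every prompt into one rectangular tensor and remembering which positions hold real prompt tokens; the pure-Python sketch below mimics that layout with lists. `PAD_ID` and `pad_batch` are hypothetical names, and vanilla-llama's exact handling may differ.

```python
from typing import List, Sequence, Tuple

PAD_ID = -1  # hypothetical id marking empty slots


def pad_batch(
    prompts: Sequence[Sequence[int]], max_gen_len: int
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Right-pad tokenized prompts into one rectangular grid plus a prompt-token mask."""
    total_len = max(len(p) for p in prompts) + max_gen_len
    grid = [list(p) + [PAD_ID] * (total_len - len(p)) for p in prompts]
    # True where the slot holds a real prompt token that decoding must not overwrite.
    mask = [[i < len(p) for i in range(total_len)] for p in prompts]
    return grid, mask


grid, mask = pad_batch([[1, 2, 3], [4]], max_gen_len=2)
print(grid)  # -> [[1, 2, 3, -1, -1], [4, -1, -1, -1, -1]]
```

Decoding can then start at the shortest prompt length and, at each position, keep the existing prompt token wherever the mask is `True`, so longer prompts are not overwritten while shorter ones are already being extended.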
Repetition penalty
llama.generate(["This is a list of awesome things:\n"], repetition_penalty=(1.0 / 0.85))
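`repetition_penalty` discourages tokens that have already appeared. Assuming the common CTRL-style rule (the one used by Hugging Face's `RepetitionPenaltyLogitsProcessor`), each previously seen token's logit is multiplied by the penalty if negative and divided by it otherwise, so a penalty above 1 never raises that token's score. The helper below is a hypothetical sketch on plain lists; we have not confirmed that vanilla-llama uses exactly this formula.

```python
from typing import List, Sequence


def apply_repetition_penalty(
    logits: Sequence[float], previous_ids: Sequence[int], penalty: float
) -> List[float]:
    """CTRL-style repetition penalty applied to one vector of next-token logits."""
    out = list(logits)
    for tok in set(previous_ids):  # each seen token is penalized once
        if out[tok] < 0:
            out[tok] *= penalty  # negative logits become more negative
        else:
            out[tok] /= penalty  # non-negative logits shrink toward zero
    return out


print(apply_repetition_penalty([2.0, -1.0, 0.5], previous_ids=[0, 1, 0], penalty=2.0))
# -> [1.0, -2.0, 0.5]
```

With `repetition_penalty=(1.0 / 0.85)` as in the example, a seen token with logit 1.0 drops to about 0.85.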
Install server requirements
pip install -r server_requirements.txt
Run the server
python server.py --llama-path <CONVERTED-WEIGHTS-PATH> --model <MODEL>
Test it!
curl -X GET http://localhost:3000/generate -H "Content-Type: application/json" -d '{"prompt": "REST servers are very useful because"}'
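The same request can be sent from Python using only the standard library. This sketch mirrors the curl command (a GET request carrying a JSON body to port 3000); the helper names are hypothetical, and since the response format is not described here, `query_server` simply returns the raw body text.

```python
import json
import urllib.request


def build_generate_request(prompt: str, host: str = "http://localhost:3000") -> urllib.request.Request:
    """Build a request equivalent to the curl command: GET /generate with a JSON body."""
    body = json.dumps({"prompt": prompt}).encode("utf-8")
    return urllib.request.Request(
        f"{host}/generate",
        data=body,
        headers={"Content-Type": "application/json"},
        method="GET",
    )


def query_server(prompt: str) -> str:
    """Send the request to a running server and return the raw response body."""
    with urllib.request.urlopen(build_generate_request(prompt)) as resp:
        return resp.read().decode("utf-8")


# Requires the server to be running:
# print(query_server("REST servers are very useful because"))
```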
Clone this repository
git clone https://github.com/galatolofederico/vanilla-llama.git
cd vanilla-llama
Install the requirements
python3 -m venv env
. ./env/bin/activate
pip install -r requirements.txt
To convert the original LLaMA weights to a plain PyTorch state dict, run:
python convert.py --llama-path <ORIGINAL-LLAMA-WEIGHTS> --model <MODEL> --output-path <CONVERTED-WEIGHTS-PATH>
Run the provided example
python example.py --llama-path <CONVERTED-WEIGHTS-PATH> --model <MODEL>