Support DeepSeek Coder Model #278
The reason this happens is that their repository doesn't have a special tokens map file. I added it here: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-base/discussions/1. You should be able to clone that branch and then load the tokenizer from {:local, "/path/to/local/model"}, and the generation will work. The special tokens (padding, eos, bos, etc.) are defined in that special tokens map file.
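A minimal sketch of that flow, with placeholder paths (the git commands are just one way to check out the PR branch):

# Assuming the repo has been cloned and the PR branch checked out, e.g.:
#   git clone https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-base /path/to/local/model
#   cd /path/to/local/model && git fetch origin refs/pr/1 && git checkout FETCH_HEAD
{:ok, tokenizer} = Bumblebee.load_tokenizer({:local, "/path/to/local/model"})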
@seanmor5 apparently they want to migrate off of these files.
I will revisit the loading logic later. @jonastemplestein, meanwhile you can load the tokenizer from the PR commit directly.
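Roughly like this; the revision hash here is the one that appears in the Livebook snippet further down the thread, so treat it as illustrative:

{:ok, tokenizer} =
  Bumblebee.load_tokenizer(
    {:hf, "deepseek-ai/deepseek-coder-1.3b-base",
     revision: "7abe797c62ede1b47c4d00a17ed006be9659d657"}
  )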
Thank you so much, you two! <3
Hey folks, I've just had a chance to play with this and keep getting garbage output from the model compared to what I get using the HF transformers library. For example, running this in Livebook:
# setup cell
Mix.install(
[
{:kino_bumblebee, "~> 0.4.0"},
{:exla, ">= 0.0.0"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
# subsequent code cell
repo = {:hf, "deepseek-ai/deepseek-coder-1.3b-base", revision: "7abe797c62ede1b47c4d00a17ed006be9659d657"}
{:ok, model_info} = Bumblebee.load_model(repo, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
compile: [batch_size: 1, sequence_length: 1028],
stream: false,
defn_options: [compiler: EXLA, lazy_transfers: :never]
)
# Should be supervised
Kino.start_child({Nx.Serving, name: Deepseek, serving: serving})
prompt = "<|fim▁begin|>def quick_sort(arr):\n <|fim▁hole|> \n <|fim▁end|> "
Nx.Serving.batched_run(Deepseek, prompt)

results in this output:
In comparison, running the same model in HF transformers in Python gives me something completely different and sensible:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-1.3b-base", trust_remote_code=True).cuda()
input_text = "<|fim▁begin|>def quick_sort(arr):\n <|fim▁hole|> \n <|fim▁end|> "
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=328)
print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_text):])

Output:
Any ideas how to debug this further? 🤔 Any help is much appreciated. I'm on an M2 Pro in case that is relevant, but I also saw the same behaviour when I briefly ran this on an A100 with CUDA.
Also, a slightly related question: do you think I can use the fine-tuning process described here to fine-tune this model directly from Livebook? Can you think of any reasons why that might not work? I think it'd be a very cool demo to train an Elixir code completion model for Livebook, in Livebook :)
@jonastemplestein I tracked down the output difference and it turns out this checkpoint uses scaling for rotary embedding, which we don't have yet. I inlined the scaling just to check and the output looks the same, except there are a bunch of newlines at the end (which makes the generation run for way too long). I will send a PR once I figure it out and get everything in sync :)
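For anyone curious, a rough sketch of the idea behind linear rotary-embedding scaling; this is not Bumblebee's implementation, just an Nx illustration where position indices are divided by a scaling factor before the sin/cos tables are built:

defmodule RotaryScalingSketch do
  # Illustrative only: builds cos/sin angle tables with linear position scaling.
  def angle_tables(sequence_length, head_dim, scaling_factor \\ 4.0, base \\ 10_000) do
    # One inverse frequency per pair of rotary dimensions
    inv_freq =
      Nx.divide(1.0, Nx.pow(base, Nx.divide(Nx.multiply(Nx.iota({div(head_dim, 2)}), 2), head_dim)))

    # Linear scaling: stretch the position ids by the scaling factor
    positions = Nx.divide(Nx.iota({sequence_length}), scaling_factor)

    # Outer product gives the angle for every (position, frequency) pair
    angles = Nx.outer(positions, inv_freq)
    {Nx.cos(angles), Nx.sin(angles)}
  end
end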
Amazing, thank you so much! Regarding the excessive newlines, I've found that one way to reduce that in Python land is to set a higher "repeat penalty" for token generation. Is there such a parameter in Bumblebee? Along similar lines, are there any plans to support a temperature parameter? Or can I somehow simulate the way temperature works using the sampling strategy?
@jonastemplestein with #285 it should match the Python implementation. You also need to load the tokenizer from this revision until the upstream PR is merged:

{:ok, tokenizer} =
  Bumblebee.load_tokenizer(
    {:hf, "deepseek-ai/deepseek-coder-1.3b-base",
     revision: "e94f2b11bc28abbd67ecadfaad058c30b24a589f"}
  )

We don't have a repetition penalty. We don't have temperature either, but supporting that is trivial; I will open up an issue and implement it soon.
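In the meantime, sampling can be turned on through the generation config; a sketch, assuming the option names from the GenerationConfig docs of that Bumblebee version (double-check them against your installed release):

# Cap generation length and switch from greedy decoding to multinomial
# sampling, which usually tames degenerate repetition somewhat.
generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 128,
    strategy: %{type: :multinomial_sampling}
  )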
Amazing, thank you so much! If you'd like, you can see the model in action in this Livebook instance, which has rudimentary code completion: https://livebookjonas.fly.dev/ (password: elixircopilot). The Livebook instance powers down after some inactivity and then takes 30 seconds or so to come back up. If it has just booted, you can open the starred notebook called Livebook Copilot Playground to play around. Two questions:
Thanks again for your help with this, super amazing <3
Done. Note that the tokenizer should be the same, so you can load the tokenizer as above and the model from any other checkpoint :)
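Concretely, something like this, where the instruct checkpoint is just an example of "any other checkpoint":

# Tokenizer pinned to the fixed revision, model weights from a different
# DeepSeek Coder checkpoint (the repo name here is only an example).
{:ok, tokenizer} =
  Bumblebee.load_tokenizer(
    {:hf, "deepseek-ai/deepseek-coder-1.3b-base",
     revision: "e94f2b11bc28abbd67ecadfaad058c30b24a589f"}
  )

{:ok, model_info} =
  Bumblebee.load_model({:hf, "deepseek-ai/deepseek-coder-1.3b-instruct"},
    backend: {EXLA.Backend, client: :host}
  )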
That's how text completion models usually work: we pass the input sequence and they generate tokens one by one, which we effectively keep appending to the input sequence for subsequent inferences. There are also encoder-decoder models, where the input first goes through an encoder and is treated as the context for generation, separate from the generation sequence (e.g. BART). That said, we should have an option to strip the tokens, tracked by #247 :)
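Until #247 lands, one workaround is to strip the prompt on the caller side; a sketch, assuming the result shape of the text generation serving (which may differ across Bumblebee versions):

# The generated text currently includes the prompt, so drop it manually.
%{results: [%{text: text}]} = Nx.Serving.batched_run(Deepseek, prompt)
completion = String.replace_prefix(text, prompt, "")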
Hey folks, I'm trying to use the deepseek-coder-1.3b-base model with bumblebee. I was delighted to find that the model, tokenizer and generation_config all load. But when trying to run inference I get the following error that's a bit hard for me to debug:
I'm using bumblebee 0.4.2
Here's the model spec
And here's the tokenizer
It looks like the vocab size is not correct in the model spec, for example.
I think the tokenizer uses the correct vocabulary, because I can run this:
and it correctly returns <|fim▁hole|>, which is a DeepSeek-specific token.
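(Something along these lines; a hypothetical reconstruction rather than the exact snippet used:)

# Hypothetical check: tokenize the FIM marker and decode the ids back
# to confirm the tokenizer knows the special token.
inputs = Bumblebee.apply_tokenizer(tokenizer, "<|fim▁hole|>")
Bumblebee.Tokenizer.decode(tokenizer, inputs["input_ids"])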
It would be amazing if this model were supported, as deepseek-coder actually seems to be pretty good at Elixir out of the box 🙇
Thank you so much for your help!