Add StarCoder/SantaCoder example #146

Merged 14 commits on May 13, 2023
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -11,3 +11,4 @@ add_subdirectory(whisper)
add_subdirectory(mnist)
add_subdirectory(stablelm)
add_subdirectory(dolly-v2)
add_subdirectory(starcoder)
13 changes: 13 additions & 0 deletions examples/starcoder/CMakeLists.txt
@@ -0,0 +1,13 @@
#
# starcoder

set(TEST_TARGET starcoder)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

#
# starcoder-quantize

set(TEST_TARGET starcoder-quantize)
add_executable(${TEST_TARGET} quantize.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
112 changes: 112 additions & 0 deletions examples/starcoder/README.md
@@ -0,0 +1,112 @@
# 💫 StarCoder

This is a C++ example running 💫 StarCoder inference using the [ggml](https://github.com/ggerganov/ggml) library.

The program runs on the CPU - no video card is required.

The example supports the following 💫 StarCoder models:

- `bigcode/starcoder`
- `bigcode/gpt_bigcode-santacoder` aka the smol StarCoder

Sample performance on MacBook M1 Pro:

TODO


Sample output:

```
$ ./bin/starcoder -h
usage: ./bin/starcoder [options]

options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 8)
-p PROMPT, --prompt PROMPT
prompt to start generation with (default: random)
-n N, --n_predict N number of tokens to predict (default: 200)
--top_k N top-k sampling (default: 40)
--top_p N top-p sampling (default: 0.9)
--temp N temperature (default: 1.0)
-b N, --batch_size N batch size for prompt processing (default: 8)
-m FNAME, --model FNAME
model path (default: models/starcoder-117M/ggml-model.bin)

$ ./bin/starcoder -m ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin -p "def fibonnaci(" -t 4 --top_k 0 --top_p 0.95 --temp 0.2
main: seed = 1683881276
gpt2_model_load: loading model from '../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin'
gpt2_model_load: n_vocab = 49280
gpt2_model_load: n_ctx = 2048
gpt2_model_load: n_embd = 2048
gpt2_model_load: n_head = 16
gpt2_model_load: n_layer = 24
gpt2_model_load: ftype = 3
gpt2_model_load: ggml ctx size = 1794.90 MB
gpt2_model_load: memory size = 768.00 MB, n_mem = 49152
gpt2_model_load: model size = 1026.83 MB
main: prompt: 'def fibonnaci('
main: number of tokens in prompt = 7, first 8 tokens: 563 24240 78 2658 64 2819 7

def fibonnaci(n):
if n == 0:
return 0
elif n == 1:
return 1
else:
return fibonacci(n-1) + fibonacci(n-2)

print(fibo(10))

main: mem per token = 9597928 bytes
main: load time = 480.43 ms
main: sample time = 26.21 ms
main: predict time = 3987.95 ms / 19.36 ms per token
main: total time = 4580.56 ms
```

## Quick start
```bash
git clone https://github.com/ggerganov/ggml
cd ggml

# Convert HF model to ggml
python examples/starcoder/convert-hf-to-ggml.py bigcode/gpt_bigcode-santacoder

# Build ggml + examples
mkdir build && cd build
cmake .. && make -j4 starcoder starcoder-quantize

# quantize the model
./bin/starcoder-quantize ../models/bigcode/gpt_bigcode-santacoder-ggml.bin ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin 3

# run inference
./bin/starcoder -m ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin -p "def fibonnaci(" --top_k 0 --top_p 0.95 --temp 0.2
```


## Downloading and converting the original models (💫 StarCoder)

You can download the original model and convert it to `ggml` format using the script `convert-hf-to-ggml.py`:

```
# Convert HF model to ggml
python examples/starcoder/convert-hf-to-ggml.py bigcode/gpt_bigcode-santacoder
```
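
The same script handles the full `bigcode/starcoder` model as well (this mirrors the usage example printed by the script itself):

```
python examples/starcoder/convert-hf-to-ggml.py bigcode/starcoder
```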

This conversion requires Python with the `torch`, `transformers` and `numpy` packages installed.
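
If they are missing, something like the following should work (a minimal, unpinned install; the example does not prescribe exact versions):

```bash
pip install torch transformers numpy
```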

## Quantizing the models

You can also try to quantize the `ggml` models via 4-bit integer quantization.

```
# quantize the model (the last argument selects the quantization type: 3 = q4_1)
./bin/starcoder-quantize ../models/bigcode/gpt_bigcode-santacoder-ggml.bin ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin 3
```

| Model | Original size | Quantized size | Quantization type |
| --- | --- | --- | --- |
| `bigcode/gpt_bigcode-santacoder` | 5396.45 MB | 1026.83 MB | 4-bit integer (q4_1) |
| `bigcode/starcoder` | 71628.23 MB | 13596.23 MB | 4-bit integer (q4_1) |
212 changes: 212 additions & 0 deletions examples/starcoder/convert-hf-to-ggml.py
@@ -0,0 +1,212 @@
# Convert HF models to ggml format
#

import sys
import struct
import json
import torch
import numpy as np
import re
import os

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
"""
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))

if len(sys.argv) < 2:
print("Usage: python convert-hf-to-ggml.py hf-model-name [use-f32]")
print("Example: python convert-hf-to-ggml.py bigcode/gpt_bigcode-santacoder")
print("Example: python convert-hf-to-ggml.py bigcode/starcoder")
sys.exit(1)

model_name = sys.argv[1].strip()
fname_out = "models/" + sys.argv[1].strip() + "-ggml.bin"
os.makedirs(os.path.dirname(fname_out), exist_ok=True)



# use 16-bit floats by default; pass any extra argument (e.g. "use-f32") to store 32-bit floats
use_f16 = True
if len(sys.argv) > 2:
use_f16 = False

print("Loading model: ", model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
hparams = config.to_dict()
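# load the weights in fp16 (or fp32) with low_cpu_mem_usage=True to keep peak RAM down while loading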
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16 if use_f16 else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True, offload_state_dict=True)
print("Model loaded: ", model_name)

#print (model)

list_vars = model.state_dict()
#print (list_vars)

encoder = tokenizer.vocab
# Add added_tokens (special tokens) to the encoder
encoder.update(tokenizer.get_added_vocab())
print(hparams)

print("Saving ggml model to: ", fname_out)
fout = open(fname_out, "wb")

fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
vocab_size = hparams["vocab_size"]
fout.write(struct.pack("i", vocab_size))
# fout.write(struct.pack("i", len(encoder)))
fout.write(struct.pack("i", hparams["n_positions"]))
fout.write(struct.pack("i", hparams["n_embd"]))
fout.write(struct.pack("i", hparams["n_head"]))
fout.write(struct.pack("i", hparams["n_layer"]))
fout.write(struct.pack("i", use_f16))

byte_encoder = bytes_to_unicode()
byte_decoder = {v:k for k, v in byte_encoder.items()}

fout.write(struct.pack("i", vocab_size))

counter = 0
# sort by value
for key in sorted(encoder, key=encoder.get):
text = bytearray([byte_decoder[c] for c in key])
fout.write(struct.pack("i", len(text)))
fout.write(text)
counter += 1

# TODO: Repeat last token until vocab_size
while counter < vocab_size:
fout.write(struct.pack("i", len(text)))
fout.write(text)
counter += 1
# assert counter == config.vocab_size

for name in list_vars.keys():
data = list_vars[name].squeeze().numpy()
print("Processing variable: " + name + " with shape: ", data.shape)

# rename headers to keep compatibility
if name == "transformer.ln_f.weight":
name = "model/ln_f/g"
elif name == "transformer.ln_f.bias":
name = "model/ln_f/b"
elif name == "transformer.wte.weight":
name = "model/wte"
elif name == "transformer.wpe.weight":
name = "model/wpe"
elif name == "lm_head.weight":
name = "model/lm_head"
elif re.match(r"transformer.h\.\d+\.ln_1\.weight", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/ln_1/g"
elif re.match(r"transformer.h\.\d+\.ln_1\.bias", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/ln_1/b"
elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.weight", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/attn/c_attn/w"
elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.bias", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/attn/c_attn/b"
elif re.match(r"transformer.h\.\d+\.attn\.c_proj\.weight", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/attn/c_proj/w"
elif re.match(r"transformer.h.\d+.attn.c_proj.bias", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/attn/c_proj/b"
elif re.match(r"transformer.h.\d+.ln_2.weight", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/ln_2/g"
elif re.match(r"transformer.h.\d+.ln_2.bias", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/ln_2/b"
elif re.match(r"transformer.h.\d+.mlp.c_fc.weight", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/mlp/c_fc/w"
elif re.match(r"transformer.h.\d+.mlp.c_fc.bias", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/mlp/c_fc/b"
elif re.match(r"transformer.h.\d+.mlp.c_proj.weight", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/mlp/c_proj/w"
elif re.match(r"transformer.h.\d+.mlp.c_proj.bias", name):
i = re.findall("\d+", name)[0]
name = f"model/h{i}/mlp/c_proj/b"
else:
print("Unrecognized variable name. %s", name)

# we don't need these
if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
print(" Skipping variable: " + name)
continue

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype = 0
if use_f16:
if (name == "model/wte" or name == "model/lm_head" or name[-2:] == "/g" or name[-2:] == "/w") and n_dims == 2:
print(" Converting to float16")
data = data.astype(np.float16)
ftype = 1
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype = 0

"model/h.*/attn/c_attn/w"
"model/h.*/attn/c_proj/w"
"model/h.*/mlp/c_fc/w"
"model/h.*/mlp/c_proj/w"
if name[-14:] == "/attn/c_attn/w" or name[-14:] == "/attn/c_attn/b":
print(" Duplicate K,V heads to use MHA instead of MQA")

embed_dim = hparams["n_embd"]
head_dim = embed_dim // hparams["n_head"]

# ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = np.split(data, (hparams["n_head"] * head_dim, (hparams["n_head"] + 1) * head_dim), axis=0)
# duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
if len(k.shape) == 2:
k = np.tile(k, (hparams["n_head"], 1))
v = np.tile(v, (hparams["n_head"], 1))
elif len(k.shape) == 1:
k = np.tile(k, (hparams["n_head"]))
v = np.tile(v, (hparams["n_head"]))
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
data = np.concatenate((q, k, v), axis=0)

    # header
    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)

# data
data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")