
Support Adept Persimmon 8b #3410

Merged
Commits (30)
7cdc3ea
Produces garbage output
phillip-kravtsov Sep 21, 2023
4bcf412
wip: correct tensors up to RoPE
phillip-kravtsov Sep 26, 2023
c9e1446
correct tensors thru RoPE
phillip-kravtsov Sep 26, 2023
d1b40ef
Correct outputs through masked & softmax'd KQ
phillip-kravtsov Sep 26, 2023
db2181a
fp32 works
phillip-kravtsov Sep 26, 2023
3f31799
Rename adept->persimmon
phillip-kravtsov Sep 28, 2023
720503b
Merge branch 'master' of github.com:phillip-kravtsov/llama.cpp into p…
phillip-kravtsov Sep 28, 2023
d61eed0
Produces correct outputs
phillip-kravtsov Sep 29, 2023
d0a7143
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 29, 2023
fa92f6e
clean up convert scripts
phillip-kravtsov Sep 29, 2023
c28a6c5
remove printing logic from ggml.c
phillip-kravtsov Sep 29, 2023
47dcb9f
remove prints from llama.cpp & fix merge
phillip-kravtsov Sep 29, 2023
7473773
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 29, 2023
d904aff
trivial cleanups
phillip-kravtsov Sep 29, 2023
ec0ce97
Add offload funcs
phillip-kravtsov Sep 29, 2023
3db04db
update conversion script to directly take adept artifacts rather than…
phillip-kravtsov Sep 29, 2023
f28f52c
Fix norm eps bug
phillip-kravtsov Sep 29, 2023
d93cf1e
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 29, 2023
574a9e1
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 30, 2023
2b56591
Support sqr and concat on metal, persimmon-8b-q4 runs correctly
phillip-kravtsov Sep 30, 2023
e6bf87f
Small changes from review
phillip-kravtsov Oct 2, 2023
cd4d3df
Formatting changes
phillip-kravtsov Oct 2, 2023
422b110
Minor changes to conversion script
phillip-kravtsov Oct 2, 2023
5a0990c
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Oct 2, 2023
7a279fe
Remove old script
phillip-kravtsov Oct 2, 2023
c90ed9f
Fix editorconfig formatting
phillip-kravtsov Oct 3, 2023
5d259d3
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Oct 5, 2023
1d518d6
Fix build
phillip-kravtsov Oct 5, 2023
0c1a8f6
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Oct 6, 2023
485a471
add overlooked offload code ggml-ci
phillip-kravtsov Oct 6, 2023
130 changes: 130 additions & 0 deletions convert-persimmon-to-gguf.py
@@ -0,0 +1,130 @@
import torch
import os
from pprint import pprint
import sys
import argparse
from pathlib import Path
from sentencepiece import SentencePieceProcessor
if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf

def _flatten_dict(dct, tensors, prefix=None):
    assert isinstance(dct, dict)
    for key in dct.keys():
        new_prefix = prefix + '.' + key if prefix is not None else key
        if isinstance(dct[key], torch.Tensor):
            tensors[new_prefix] = dct[key]
        elif isinstance(dct[key], dict):
            _flatten_dict(dct[key], tensors, new_prefix)
        else:
            raise ValueError(type(dct[key]))
    return None

def _get_sentencepiece_tokenizer_info(dir_model: Path):
    tokenizer_path = dir_model / 'adept_vocab.model'
    print('gguf: getting sentencepiece tokenizer from', tokenizer_path)
    tokenizer = SentencePieceProcessor(str(tokenizer_path))
    print('gguf: adding tokens')
    tokens: list[bytes] = []
    scores: list[float] = []
    toktypes: list[int] = []

    for i in range(tokenizer.vocab_size()):
        text: bytes
        score: float

        piece = tokenizer.id_to_piece(i)
        text = piece.encode("utf-8")
        score = tokenizer.get_score(i)

        toktype = 1
        if tokenizer.is_unknown(i):
            toktype = 2
        if tokenizer.is_control(i):
            toktype = 3
        if tokenizer.is_unused(i):
            toktype = 5
        if tokenizer.is_byte(i):
            toktype = 6

        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype)
    return tokens, scores, toktypes
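
# The integer toktype codes above (1 = normal, 2 = unknown, 3 = control,
# 5 = unused, 6 = byte) follow the GGUF token-type convention. A minimal,
# hedged sketch of the same branch expressed with the gguf-py enum, assuming
# the bundled gguf-py exposes gguf.TokenType with these members; the helper
# name _toktype is illustrative and not part of this PR:
def _toktype(tokenizer: SentencePieceProcessor, i: int) -> int:
    # Later checks in the loop above override earlier ones, so byte wins over
    # unused, which wins over control, which wins over unknown.
    if tokenizer.is_byte(i):
        return int(gguf.TokenType.BYTE)     # 6
    if tokenizer.is_unused(i):
        return int(gguf.TokenType.UNUSED)   # 5
    if tokenizer.is_control(i):
        return int(gguf.TokenType.CONTROL)  # 3
    if tokenizer.is_unknown(i):
        return int(gguf.TokenType.UNKNOWN)  # 2
    return int(gguf.TokenType.NORMAL)       # 1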

def main():
    parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file")
    parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release")
    parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory")
    args = parser.parse_args()
    sys.path.append(str(args.adept_inference_dir))
    persimmon_model = torch.load(args.ckpt_path)
    hparams = persimmon_model['args']
    pprint(hparams)
    tensors = {}
    _flatten_dict(persimmon_model['model'], tensors, None)

    arch = gguf.MODEL_ARCH.PERSIMMON
    gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch])

    block_count = hparams.num_layers
    head_count = hparams.num_attention_heads
    head_count_kv = head_count
    ctx_length = hparams.seq_length
    hidden_size = hparams.hidden_size

    gguf_writer.add_name('persimmon-8b-chat')
    gguf_writer.add_context_length(ctx_length)
    gguf_writer.add_embedding_length(hidden_size)
    gguf_writer.add_block_count(block_count)
    gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size)
    gguf_writer.add_rope_dimension_count(hidden_size // head_count)
    gguf_writer.add_head_count(head_count)
    gguf_writer.add_head_count_kv(head_count_kv)
    gguf_writer.add_rope_freq_base(hparams.rotary_emb_base)
    gguf_writer.add_layer_norm_eps(hparams.layernorm_epsilon)

    tokens, scores, toktypes = _get_sentencepiece_tokenizer_info(args.model_dir)
    gguf_writer.add_tokenizer_model('llama')
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_scores(scores)
    gguf_writer.add_token_types(toktypes)
    gguf_writer.add_bos_token_id(71013)
    gguf_writer.add_eos_token_id(71013)

    tensor_map = gguf.get_tensor_name_map(arch, block_count)
    print(tensor_map)
    for name in tensors.keys():
        data = tensors[name]
        if name.endswith(".self_attention.rotary_emb.inv_freq"):
            continue
        old_dtype = data.dtype
        # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
        data = data.to(torch.float32).squeeze().numpy()
        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
        if new_name is None:
            print("Can not map tensor '" + name + "'")
            sys.exit()
        n_dims = len(data.shape)
        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
        gguf_writer.add_tensor(new_name, data)
    print("gguf: write header")
    gguf_writer.write_header_to_file()
    print("gguf: write metadata")
    gguf_writer.write_kv_data_to_file()
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

    gguf_writer.close()

    print(f"gguf: model successfully exported to '{args.outfile}'")
    print("")


if __name__ == '__main__':
    main()
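
For reference, a hypothetical end-to-end invocation of the converter; the paths below are placeholders rather than files shipped with this PR, and only the flags defined by the argparse setup above are used:

import subprocess
import sys

# Placeholder paths -- substitute the real checkpoint, model directory and
# adept-inference checkout before running.
subprocess.run([
    sys.executable, "convert-persimmon-to-gguf.py",
    "--ckpt-path", "path/to/persimmon_checkpoint.pt",
    "--model-dir", "path/to/8b_chat_model_release",
    "--adept-inference-dir", "path/to/adept-inference",
    "--outfile", "persimmon-8b-chat.gguf",
], check=True)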
54 changes: 54 additions & 0 deletions ggml-metal.m
@@ -109,6 +109,8 @@
    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
    GGML_METAL_DECL_KERNEL(cpy_f32_f32);
    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
    GGML_METAL_DECL_KERNEL(concat);
    GGML_METAL_DECL_KERNEL(sqr);

#undef GGML_METAL_DECL_KERNEL
};
@@ -300,6 +302,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
        GGML_METAL_ADD_KERNEL(concat);
        GGML_METAL_ADD_KERNEL(sqr);

#undef GGML_METAL_ADD_KERNEL
    }
@@ -375,6 +379,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
    GGML_METAL_DEL_KERNEL(concat);
    GGML_METAL_DEL_KERNEL(sqr);

#undef GGML_METAL_DEL_KERNEL

@@ -766,6 +772,43 @@ void ggml_metal_graph_compute(
                        {
                            // noop
                        } break;
                    case GGML_OP_CONCAT:
                        {

                            int64_t nb = ne00;
                            [encoder setComputePipelineState:ctx->pipeline_concat];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
                            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
                            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
                            [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
                            [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
                            [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
                            [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                            [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                            [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
                            [encoder setBytes:&nb length:sizeof(nb) atIndex:27];

                            const int nth = MIN(1024, ne0);
                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                        } break;
                    case GGML_OP_ADD:
                        {
                            GGML_ASSERT(ggml_is_contiguous(src0));
@@ -903,6 +946,17 @@ void ggml_metal_graph_compute(
                                    GGML_ASSERT(false);
                            }
                        } break;
                    case GGML_OP_SQR:
                        {
                            GGML_ASSERT(ggml_is_contiguous(src0));

                            [encoder setComputePipelineState:ctx->pipeline_sqr];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

                            const int64_t n = ggml_nelements(dst);
                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_SOFT_MAX:
                        {
                            const int nth = MIN(32, ne00);
63 changes: 63 additions & 0 deletions ggml-metal.metal
@@ -132,6 +132,13 @@ kernel void kernel_relu(
    dst[tpig] = max(0.0f, src0[tpig]);
}

kernel void kernel_sqr(
        device const float * src0,
        device float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src0[tpig];
}

constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

@@ -1098,6 +1105,62 @@ kernel void kernel_cpy_f32_f32(
    }
}

kernel void kernel_concat(
        device const char * src0,
        device const char * src1,
        device char * dst,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & ne2,
        constant int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {

    const int64_t i03 = tgpig.z;
    const int64_t i02 = tgpig.y;
    const int64_t i01 = tgpig.x;

    const int64_t i13 = i03 % ne13;
    const int64_t i12 = i02 % ne12;
    const int64_t i11 = i01 % ne11;

    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;

    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        if (i02 < ne02) {
            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
            src0_ptr += ntg.x*nb00;
        } else {
            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
            src1_ptr += ntg.x*nb10;
        }
        dst_ptr += ntg.x*nb0;
    }
}
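
For intuition: each threadgroup handles one (i1, i2, i3) destination row, threads stride along dim 0, and the source is chosen by whether the dim-2 index still lies inside src0's extent, so the kernel concatenates src0 and src1 along ggml dim 2; kernel_sqr above is a plain element-wise square (presumably to express Persimmon's squared-ReLU feed-forward activation as relu followed by sqr, an inference not stated in the diff). A rough NumPy sketch of the concat result, assuming float32 inputs with equal extent along the concat axis (names are illustrative):

import numpy as np

def concat_dim2(src0: np.ndarray, src1: np.ndarray) -> np.ndarray:
    # ggml dims (ne0, ne1, ne2, ne3) correspond to NumPy shape (ne3, ne2, ne1, ne0),
    # so ggml dim 2 is NumPy axis 1.
    # Example: two arrays of shape (1, 8, 64, 64) concatenate to (1, 16, 64, 64).
    return np.concatenate([src0, src1], axis=1)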

//============================================ k-quants ======================================================

#ifndef QK_K