gguf : add ftype meta info to the model #2710

Merged · 3 commits · Aug 22, 2023
convert.py: 29 changes (23 additions, 6 deletions)
@@ -69,7 +69,10 @@ class UnquantizedDataType:
'I32': DT_I32,
}

-class GGMLFileType(enum.Enum):
+# TODO: match this with `llama_ftype`
+# TODO: rename to LLAMAFileType
+# TODO: move to `gguf.py`
+class GGMLFileType(enum.IntEnum):
AllF32 = 0
MostlyF16 = 1 # except 1d tensors

@@ -101,6 +104,8 @@ class Params:
n_head_kv: int
f_norm_eps: float

+ftype: Optional[GGMLFileType] = None

@staticmethod
def find_n_mult(n_ff: int, n_embd: int) -> int:
# hardcoded magic range
@@ -738,6 +743,9 @@ def add_meta_arch(self, params: Params) -> None:
self.gguf.add_head_count_kv (params.n_head_kv)
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)

+if params.ftype:
+    self.gguf.add_file_type(params.ftype)

def add_meta_vocab(self, vocab: Vocab) -> None:
tokens = []
scores = []
@@ -1020,6 +1028,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
" - LLaMA v2: --ctx 4096\n")
params.n_ctx = args.ctx

+if args.outtype:
+    params.ftype = {
+        "f32": GGMLFileType.AllF32,
+        "f16": GGMLFileType.MostlyF16,
+    }[args.outtype]

print(f"params = {params}")

vocab: Vocab
@@ -1040,11 +1054,14 @@ def main(args_in: Optional[List[str]] = None) -> None:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = load_vocab(vocab_dir, args.vocabtype)

-model = model_plus.model
-model = convert_model_names(model, params)
-output_type = pick_output_type(model, args.outtype)
-model = convert_to_output_type(model, output_type)
-outfile = args.outfile or default_outfile(model_plus.paths, output_type)
+model = model_plus.model
+model = convert_model_names(model, params)
+ftype = pick_output_type(model, args.outtype)
+model = convert_to_output_type(model, ftype)
+outfile = args.outfile or default_outfile(model_plus.paths, ftype)

+params.ftype = ftype
+print(f"Writing {outfile}, format {ftype}")

OutputFile.write_all(outfile, params, model, vocab)
print(f"Wrote {outfile}")
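Taken together, the convert.py changes thread the requested output type through to the metadata writer. A minimal stand-alone sketch of that mapping (the enum mirrors the diff; `ftype_from_outtype` is an illustrative helper, not part of the patch):

```python
# Sketch of the ftype plumbing added to convert.py (GGMLFileType mirrors the diff).
# Because the enum is now an IntEnum, its value can be stored directly as the
# uint32 behind the new general.file_type key.
import enum
from typing import Optional

class GGMLFileType(enum.IntEnum):
    AllF32    = 0
    MostlyF16 = 1  # except 1d tensors

def ftype_from_outtype(outtype: Optional[str]) -> Optional[GGMLFileType]:
    # same mapping as in main(): --outtype f32/f16 -> metadata value
    if not outtype:
        return None
    return {
        "f32": GGMLFileType.AllF32,
        "f16": GGMLFileType.MostlyF16,
    }[outtype]

assert int(ftype_from_outtype("f16")) == 1  # written as uint32 1 into the GGUF KV section
```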
gguf.py: 4 changes (4 additions, 0 deletions)
@@ -26,6 +26,7 @@
KEY_GENERAL_LICENSE = "general.license"
KEY_GENERAL_SOURCE_URL = "general.source.url"
KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
+KEY_GENERAL_FILE_TYPE = "general.file_type"

# LLM
KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length"
@@ -595,6 +596,9 @@ def add_source_url(self, url: str):
def add_source_hf_repo(self, repo: str):
self.add_string(KEY_GENERAL_SOURCE_HF_REPO, repo)

+def add_file_type(self, ftype: int):
+    self.add_uint32(KEY_GENERAL_FILE_TYPE, ftype)

def add_name(self, name: str):
self.add_string(KEY_GENERAL_NAME, name)

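On the writer side the new helper is just a uint32 stored under `general.file_type`. For the reader side, a hypothetical consumer sketch (GGUF parsing itself is out of scope for this PR; a plain dict stands in for the parsed KV pairs):

```python
# Hypothetical consumer: once the GGUF KV pairs are parsed into a dict,
# the new key reports the file type without guessing from tensor dtypes.
KEY_GENERAL_FILE_TYPE = "general.file_type"

# Subset of llama_ftype values, for illustration only.
FTYPE_NAMES = {0: "all F32", 1: "mostly F16"}

def describe_file_type(kv: dict) -> str:
    ftype = kv.get(KEY_GENERAL_FILE_TYPE)
    if ftype is None:
        return "unknown (key missing; older file)"
    return FTYPE_NAMES.get(ftype, f"ftype {ftype}")

print(describe_file_type({KEY_GENERAL_FILE_TYPE: 1}))  # -> mostly F16
print(describe_file_type({}))                           # -> unknown (key missing; older file)
```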
llama.cpp: 21 changes (18 additions, 3 deletions)
@@ -1121,6 +1121,16 @@ struct llama_model_loader {
} break;
}

+// this is a way to mark that we have "guessed" the file type
+ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);

+{
+    const int kid = gguf_find_key(ctx_gguf, "general.file_type");
+    if (kid >= 0) {
+        ftype = (llama_ftype) gguf_get_val_u32(ctx_gguf, kid);
+    }
+}

for (int i = 0; i < n_kv; i++) {
const char * name = gguf_get_key(ctx_gguf, i);
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
@@ -1323,7 +1333,11 @@ struct llama_model_loader {
// load LLaMA models
//

-const char * llama_model_ftype_name(enum llama_ftype ftype) {
+std::string llama_model_ftype_name(enum llama_ftype ftype) {
+    if (ftype & LLAMA_FTYPE_GUESSED) {
+        return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
+    }

switch (ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16";
@@ -1552,7 +1566,7 @@ static void llama_model_load_internal(
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
-LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype));
+LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml->n_elements*1e-9);

// general kv
@@ -3620,6 +3634,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// copy the KV pairs from the input file
gguf_set_kv (ctx_out, model_loader->ctx_gguf);
gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
gguf_set_val_u32(ctx_out, "general.file_type", ftype);

#ifdef GGML_USE_K_QUANTS
int n_attention_wv = 0;
@@ -4471,7 +4486,7 @@ int llama_model_n_embd(const struct llama_model * model) {
}

int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size) {
-return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype));
+return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str());
}

int llama_model_quantize(
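The loader change is a simple flag-bit scheme: the ftype inferred from tensor types is first OR-ed with LLAMA_FTYPE_GUESSED, then replaced outright if the file carries `general.file_type`, and the name function strips the bit and appends "(guessed)". A small Python mirror of that logic (constants match llama.h; the helper names are illustrative, not part of the patch):

```python
# Python mirror of the llama.cpp flag-bit logic, for illustration only.
LLAMA_FTYPE_ALL_F32    = 0
LLAMA_FTYPE_MOSTLY_F16 = 1
LLAMA_FTYPE_GUESSED    = 1024  # flag bit; never written to the model file

FTYPE_NAMES = {LLAMA_FTYPE_ALL_F32: "all F32", LLAMA_FTYPE_MOSTLY_F16: "mostly F16"}

def load_ftype(guessed: int, kv: dict) -> int:
    # mirror of llama_model_loader: start from the guess, prefer the stored value
    ftype = guessed | LLAMA_FTYPE_GUESSED
    if "general.file_type" in kv:
        ftype = kv["general.file_type"]
    return ftype

def ftype_name(ftype: int) -> str:
    # mirror of llama_model_ftype_name(): strip the flag, append " (guessed)"
    if ftype & LLAMA_FTYPE_GUESSED:
        return ftype_name(ftype & ~LLAMA_FTYPE_GUESSED) + " (guessed)"
    return FTYPE_NAMES.get(ftype, f"unknown, ftype = {ftype}")

print(ftype_name(load_ftype(LLAMA_FTYPE_MOSTLY_F16, {})))                        # mostly F16 (guessed)
print(ftype_name(load_ftype(LLAMA_FTYPE_MOSTLY_F16, {"general.file_type": 0})))  # all F32
```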
llama.h: 2 changes (2 additions, 0 deletions)
@@ -103,6 +103,8 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors

+LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};

typedef struct llama_token_data {