Skip to content

Commit

Permalink
fix: fix linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jun 5, 2024
1 parent 4bce30c commit 3b44f8f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
2 changes: 1 addition & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
# don't edit the hashes manually!
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
res = "llama-bpe"
if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
# ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
res = "deepseek-llm"
Expand Down
41 changes: 20 additions & 21 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4653,16 +4653,7 @@ static void llm_load_vocab(

// for now, only BPE models have pre-tokenizers
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
if (tokenizer_pre.empty()) {
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
LLAMA_LOG_WARN("%s: \n", __func__);
LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__);
LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__);
LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
LLAMA_LOG_WARN("%s: \n", __func__);
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
if (
tokenizer_pre == "default") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
Expand Down Expand Up @@ -4715,7 +4706,8 @@ static void llm_load_vocab(
tokenizer_pre == "smaug-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
Expand Down Expand Up @@ -5569,7 +5561,7 @@ static bool llm_load_tensors(
layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);

layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});

layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
Expand Down Expand Up @@ -6631,7 +6623,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
}
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
return -1;
throw;
}

return 0;
Expand Down Expand Up @@ -16254,16 +16246,23 @@ struct llama_model * llama_load_model_from_file(
}
model->rpc_servers.push_back(servers);
}
int status = llama_model_load(path_model, *model, params);
GGML_ASSERT(status <= 0);
if (status < 0) {
if (status == -1) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
} else if (status == -2) {
LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);

try {
int status = llama_model_load(path_model, *model, params);
GGML_ASSERT(status <= 0);
if (status < 0) {
if (status == -1) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
} else if (status == -2) {
LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
}
delete model;
return nullptr;
}
} catch (...) {
LLAMA_LOG_ERROR("%s: exception loading model\n", __func__);
delete model;
return nullptr;
throw;
}

return model;
Expand Down

0 comments on commit 3b44f8f

Please sign in to comment.