-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_model_type.py
50 lines (46 loc) · 1.82 KB
/
get_model_type.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from get_env import get_env
# Ordered (substrings, model_type) rules matched against the lower-cased
# filename. Every rule is checked; the LAST matching rule wins, so the order
# encodes precedence (e.g. a name containing both "star" and "llama" -> "llama").
# All needles must be lower-case because the filename is lower-cased first.
_FILENAME_MODEL_TYPE_RULES = (
    # StarCoder-format models; "star" also covers "starcoder" and "starchat".
    # https://huggingface.co/TheBloke/WizardCoder-15B-1.0-GGML
    # https://huggingface.co/TheBloke/minotaur-15B-GGML
    (("star", "wizardcoder", "minotaur-15"), "gpt_bigcode"),
    (("llama",), "llama"),
    (("mpt",), "mpt"),
    (("replit",), "replit"),
    (("falcon",), "falcon"),
    (("dolly",), "dolly-v2"),
    (("stablelm",), "gpt_neox"),
    # matching https://huggingface.co/stabilityai/stablecode-completion-alpha-3b
    (("stablecode",), "gpt_neox"),
    # matching https://huggingface.co/EleutherAI/pythia-70m
    (("pythia",), "gpt_neox"),
    # codegen family are in gptj, codegen2 isn't but not supported by
    # ggml/ctransformer yet
    # https://huggingface.co/Salesforce/codegen-2B-multi
    # https://huggingface.co/ravenscroftj/CodeGen-2B-multi-ggml-quant
    (("codegen",), "gptj"),
)


def get_model_type(
    filename: str,
    *,
    repo_id: "str | None" = None,
    mode_type: "str | None" = None,
) -> str:
    """Guess the ctransformers ``model_type`` for a model file.

    The lower-cased *filename* is matched against the substring rules in
    ``_FILENAME_MODEL_TYPE_RULES``. Rules are applied in order and the last
    matching rule wins. The result falls back to "llama" when nothing
    matches. Two overrides are then applied in this order:

    1. "gptq" if the repo id or the filename contains "gptq"
       (case-insensitive).
    2. A non-empty explicit *mode_type* replaces the result entirely.

    Args:
        filename: Model file name (or path) to inspect.
        repo_id: Repo id to check for "gptq". When ``None`` (default) it is
            obtained via ``get_env("DEFAULT_MODEL_HG_REPO_ID", "")``.
        mode_type: Explicit model-type override. When ``None`` (default) it
            is obtained via ``get_env("MODE_TYPE", "")``. An empty string
            means "no override".

    Returns:
        The model type string, e.g. "llama", "gpt_bigcode", "gptq".
    """
    name = filename.lower()

    model_type = "llama"
    for needles, candidate in _FILENAME_MODEL_TYPE_RULES:
        if any(needle in name for needle in needles):
            model_type = candidate

    if repo_id is None:
        repo_id = get_env("DEFAULT_MODEL_HG_REPO_ID", "")
    # str() guards against a non-string value coming back from get_env.
    if "gptq" in str(repo_id).lower() or "gptq" in name:
        model_type = "gptq"

    if mode_type is None:
        mode_type = get_env("MODE_TYPE", "")
    # A non-empty explicit override takes precedence over all heuristics.
    if mode_type:
        model_type = mode_type
    return model_type