Commit
local transformers rework, missing tests added, possibility to override inference function, quantize_4bit: bool parameter
Showing 3 changed files with 99 additions and 40 deletions.
@@ -1,66 +1,80 @@
import logging
import gc

from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch

from .local_llm import make_llm_functions as make_local_llm_functions
from ..configuration import Config
from ..types import LLMFunctionType, LLMAsyncFunctionType


def make_llm_functions(config: Config, env=None) -> tuple[LLMFunctionType, LLMAsyncFunctionType]:
def inference(prompt: str, model, tokenizer, **kwargs):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, **kwargs)
out = tokenizer.decode(outputs[0][len(inputs[0]) :], skip_special_tokens=True)
return out


def make_llm_functions(
config: Config, env
) -> tuple[LLMFunctionType, LLMAsyncFunctionType]:
logging.info(f"Loading local Transformers model {config.MODEL}...")
params = config.INIT_PARAMS
mc_param_names = ['model', 'tokenizer', 'device']
model_init_params = {k: v for k, v in params.items() if k not in mc_param_names}
mc_params = {k: v for k, v in params.items() if k in mc_param_names}
if 'model' not in mc_params or 'tokenizer' not in mc_params:

tokenizer = params.get("tokenizer") or transformers.AutoTokenizer.from_pretrained(
config.MODEL, trust_remote_code=True
)

if not (model := params.get("model")):
gc.collect()
torch.cuda.empty_cache()
if 'tokenizer' in mc_params:
tokenizer = mc_params['tokenizer']
else:
tokenizer = AutoTokenizer.from_pretrained(config.MODEL, trust_remote_code=True)
if 'model' in mc_params:
model = mc_params['model']
else:
model = AutoModelForCausalLM.from_pretrained(
model_config = transformers.AutoConfig.from_pretrained(
config.MODEL,
trust_remote_code=True,
)
mc_param_names = ["model", "tokenizer", "device", "quantize_4bit", "inference"]
model_init_params = dict(
**dict(
**dict(
trust_remote_code=True,
torch_dtype="auto",
device_map="auto",
offload_folder=config.STORAGE_PATH,
),
**model_init_params
trust_remote_code=True,
torch_dtype="auto",
config=model_config,
device_map="auto",
offload_folder=config.STORAGE_PATH,
),
**{k: v for k, v in params.items() if k not in mc_param_names},
)
if "quantize_4bit" in params:
model_init_params["quantization_config"] = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
config.MODEL, **model_init_params
)
if 'device' in mc_params:
model.to(mc_params['device'])
if "device" in params:
model.to(params["device"])

transformers_model_args = {
"max_new_tokens": 2048,
}
if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id:
transformers_model_args['eos_token_id'] = tokenizer.eos_token_id
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
transformers_model_args["eos_token_id"] = tokenizer.eos_token_id

setattr(env, "tokenizer", tokenizer)
setattr(env, "model", model)
setattr(env, "inference", params.get("inference") or inference)

def inference(prompt: dict | str, **kwargs):
def wrapped_inference(prompt: dict | str, **kwargs):
if config.CHAT_MODE:
prompt = tokenizer.apply_chat_template(
prompt,
add_generation_prompt=True,
tokenize=False
prompt, add_generation_prompt=True, tokenize=False
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, **{**transformers_model_args, **kwargs})
out = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
return out

if env:
setattr(env, 'tokenizer', tokenizer)
setattr(env, 'model', model)
args = {**transformers_model_args, **kwargs}
return env.inference(prompt, model=env.model, tokenizer=env.tokenizer, **args)

logging.debug(f"Local Transformers model loaded: {config.MODEL}")

return make_local_llm_functions(config, inference)
return make_local_llm_functions(config, wrapped_inference)
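For reference, a rough usage sketch of the two features this commit adds to INIT_PARAMS (quantize_4bit and an overridable inference function). It assumes the microcore configuration API used by the tests below; my_inference and the chosen model name are illustrative, not part of this commit:

import microcore as mc


def my_inference(prompt: str, model, tokenizer, **kwargs):
    # Hypothetical replacement for the default module-level inference():
    # it receives the loaded model and tokenizer plus generation kwargs.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, **kwargs)
    return tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)


mc.configure(
    LLM_API_TYPE=mc.ApiType.LOCAL_TRANSFORMERS,
    MODEL="deepseek-ai/deepseek-coder-1.3b-instruct",
    CHAT_MODE=False,
    INIT_PARAMS={
        "quantize_4bit": True,      # load with a 4-bit BitsAndBytesConfig (nf4, bf16 compute)
        "inference": my_inference,  # override the built-in inference function
    },
)
print(mc.llm("Count from 1 to 3: 1..., 2..."))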
@@ -0,0 +1,45 @@
import microcore as mc
import pytest
from microcore.configuration import Config, LLMConfigError


def test_no_model():
    # should raise an error
    with pytest.raises(LLMConfigError):
        # No model name
        mc.configure(LLM_API_TYPE=mc.ApiType.LOCAL_TRANSFORMERS, CHAT_MODE=True)


def test():
    import os
    os.environ['NVIDIA_VISIBLE_DEVICES'] = 'all'
    defaults = dict(
        LLM_API_TYPE=mc.ApiType.LOCAL_TRANSFORMERS,
        LLM_DEFAULT_ARGS={
            "max_new_tokens": 30,
        },
        INIT_PARAMS={
            'quantize_4bit': True,
        }
    )
    configs = [
        Config(
            MODEL='microsoft/phi-1_5',
            CHAT_MODE=False,
            **defaults
        ),
        Config(
            MODEL='deepseek-ai/deepseek-coder-1.3b-instruct',
            CHAT_MODE=False,
            **defaults
        ),
        Config(
            MODEL='google/gemma-2b-it',
            CHAT_MODE=True,
            **defaults
        ),
    ]
    for config in configs:
        mc.configure(**dict(config))
        mc.use_logging()
        assert '3' in mc.llm('Count from 1 to 3: 1..., 2...')
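The reworked loader also accepts an already-loaded model and tokenizer through INIT_PARAMS (and a "device" key to move the model), falling back to loading config.MODEL only when they are absent. A minimal sketch under the same assumptions as above; the object names are illustrative:

import microcore as mc
import transformers

# Reuse objects that were loaded elsewhere instead of loading them twice.
tok = transformers.AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
mdl = transformers.AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)

mc.configure(
    LLM_API_TYPE=mc.ApiType.LOCAL_TRANSFORMERS,
    MODEL="microsoft/phi-1_5",
    CHAT_MODE=False,
    INIT_PARAMS={"model": mdl, "tokenizer": tok},  # skips AutoModel/AutoTokenizer loading
)
print(mc.llm("Count from 1 to 3: 1..., 2..."))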