Add NER plugin to neural chat (#436)
* add ner plugin code for neural chat

Signed-off-by: LetongHan <letong.han@intel.com>
Showing 20 changed files with 666 additions and 5 deletions.
intel_extension_for_transformers/neural_chat/pipeline/plugins/ner/README.md (101 additions, 0 deletions)
<div align="center">
<h1>NER: Named Entity Recognition</h1>
<div align="left">

# 🏠Introduction
The Named Entity Recognition (NER) plugin is a software component designed to enhance NER-related functionality in Neural Chat. NER is a subtask of information extraction that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, and percentages.

This plugin provides large language models (LLMs) at different precisions. The supported precisions are listed below:

1. Float32 (FP32)
2. BFloat16 (BF16)
3. INT8
4. INT4

As the precision decreases, model inference time decreases, but so does the accuracy of the inference results. Choose a precision that matches the accuracy and latency requirements of your NER task.

# 🔧Install dependencies
Before using the NER plugin, install the dependencies below. You also need to download a spaCy model; choose a suitable one from the spaCy official documentation [here](https://spacy.io/usage/models) according to your needs. The example model used here is `en_core_web_lg`.

```bash
# install required packages
pip install -r requirements.txt
# download the spaCy model
python -m spacy download en_core_web_lg
```
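
To verify the download, a quick check like the following minimal sketch should load the model and tag a sample sentence (the printed entity labels may vary with the model version):

```python
# sanity check: confirm the spaCy model is installed and loadable
import spacy

nlp = spacy.load("en_core_web_lg")
doc = nlp("Show me photos taken in Shanghai today.")
print([(ent.text, ent.label_) for ent in doc.ents])  # e.g. [('Shanghai', 'GPE'), ('today', 'DATE')]
```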


# 🚀Usage
Depending on the model configuration, the NER plugins of varying precisions can be used in the two ways shown below.

## Inference with FP32/BF16

```python
from intel_extension_for_transformers.neural_chat.pipeline.plugins.ner import NamedEntityRecognition
# pass bf16=True to the constructor for BF16 inference
ner_obj = NamedEntityRecognition()
# modify the query here for your customized NER task
query = "Show me photos taken in Shanghai today."
result = ner_obj.inference(query=query)
print("NER result: ", result)
```

## Inference with INT8/INT4

```python
from ner_int import NamedEntityRecognitionINT
# set compute_dtype='int8' and weight_dtype='int4' for INT4 inference
ner_obj = NamedEntityRecognitionINT(compute_dtype='fp32', weight_dtype='int8')
query = "Show me photos taken in Shanghai today."
result = ner_obj.inference(query=query)
print("NER result: ", result)
```

# 🚗Parameters
## Plugin Parameters
You can customize the NER inference parameters to fit your needs and improve performance. Set a specific parameter with `plugins.ner.args["xxx"]`. Below are descriptions of the available parameters.
```python
model_name_or_path [str]: The Hugging Face model name or the local path of the downloaded LLM. Defaults to "./Llama-2-7b-chat-hf/".

spacy_model [str]: The spaCy model used for NLP processing; specify it according to the downloaded spaCy model. Defaults to "en_core_web_lg".

bf16 [bool]: Whether to use BF16 precision for NER inference. Defaults to False.
```
For the INT8 and INT4 models, the plugin parameters are slightly different. Set a specific parameter with `plugins.ner_int.args["xxx"]`.
```python
model_name_or_path [str]: The Hugging Face model name or the local path of the downloaded LLM. Defaults to "./Llama-2-7b-chat-hf/".

spacy_model [str]: The spaCy model used for NLP processing; specify it according to the downloaded spaCy model. Defaults to "en_core_web_lg".

compute_dtype [str]: The dtype used during computation. Set to "int8" for INT4 inference to get better performance. Defaults to "fp32".

weight_dtype [str]: The dtype of the model weights. Set to "int4" for INT4 inference. Defaults to "int8".
```
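
As a minimal sketch of wiring these arguments into a chatbot (assuming the usual Neural Chat pattern with the shared `plugins` registry, `PipelineConfig`, and `build_chatbot`; adapt the names to your setup):

```python
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot, plugins

# enable the NER plugin and override its arguments before building the chatbot
plugins.ner.enable = True
plugins.ner.args["model_name_or_path"] = "./Llama-2-7b-chat-hf/"
plugins.ner.args["spacy_model"] = "en_core_web_lg"
plugins.ner.args["bf16"] = False

config = PipelineConfig(plugins=plugins)
chatbot = build_chatbot(config)
```

For INT8/INT4, enable `plugins.ner_int` instead and set `compute_dtype` and `weight_dtype` through `plugins.ner_int.args`.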

## Inference Parameters
### FP32/BF16 Inference
```python
query [str]: The query string from which to extract entities. This parameter is required.

prompt [str]: The inference prompt for the LLM. You can construct a customized prompt for specific needs. Defaults to 'construct_default_prompt' in '/neural_chat/pipeline/plugins/ner/utils/utils.py'.

max_new_tokens [int]: The maximum number of newly generated tokens. Defaults to 32.

temperature [float]: The temperature of the LLM. Defaults to 0.01.

top_k [int]: The top_k parameter of the LLM. Defaults to 3.

repetition_penalty [float]: The repetition penalty of the LLM. Defaults to 1.1.
```
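
For example, a call overriding the generation parameters might look like the sketch below (the query and values are illustrative):

```python
# illustrative FP32/BF16 inference call with customized generation parameters
result = ner_obj.inference(
    query="Remind me to call Alice next Monday at 10 am.",
    max_new_tokens=64,
    temperature=0.01,
    top_k=3,
    repetition_penalty=1.1,
)
print("NER result: ", result)
```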

### INT8/INT4 Inference
```python
query [str]: The query string from which to extract entities. This parameter is required.

prompt [str]: The inference prompt for the LLM. You can construct a customized prompt for specific needs. Defaults to 'construct_default_prompt' in '/neural_chat/pipeline/plugins/ner/utils/utils.py'.

threads [int]: The number of threads used for model inference. Set it to the core count of your server for minimal inference time. Defaults to 52.

max_new_tokens [int]: The maximum number of newly generated tokens. Defaults to 32.

seed [int]: The random seed of the LLM. Defaults to 1234.
```
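
Similarly, an INT8/INT4 call with customized parameters might look like this sketch (the query and values are illustrative; set `threads` to the core count of your machine):

```python
# illustrative INT8/INT4 inference call with customized parameters
result = ner_obj.inference(
    query="Book a meeting room in Shanghai for tomorrow afternoon.",
    threads=32,
    max_new_tokens=64,
    seed=1234,
)
print("NER result: ", result)
```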
intel_extension_for_transformers/neural_chat/pipeline/plugins/ner/__init__.py (16 additions, 0 deletions)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
intel_extension_for_transformers/neural_chat/pipeline/plugins/ner/ner.py (154 additions, 0 deletions)
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Thread
import re
import time
import torch
import spacy
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    AutoConfig,
)
import intel_extension_for_pytorch as intel_ipex
from .utils.utils import (
    set_cpu_running_env,
    enforce_stop_tokens,
    get_current_time
)
from .utils.process_text import process_time, process_entities
from intel_extension_for_transformers.neural_chat.prompts import PromptTemplate


class NamedEntityRecognition():
    """
    NER class to inference with fp32 or bf16 llm models
    Set bf16=True if you want to inference with bf16 model.
    """

    def __init__(self, model_path="./Llama-2-7b-chat-hf/", spacy_model="en_core_web_lg", bf16: bool=False) -> None:
        # set up cpu running environment
        if bf16:
            set_cpu_running_env()
        # initialize tokenizer and models
        self.nlp = spacy.load(spacy_model)
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        config.init_device = 'cuda:0' if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=False if (re.search("llama", model_path, re.IGNORECASE)
                or re.search("neural-chat-7b-v2", model_path, re.IGNORECASE)) else True,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            config=config,
            device_map="auto",
            trust_remote_code=True
        )
        self.bf16 = bf16
        # optimize model with ipex if bf16
        if bf16:
            self.model = intel_ipex.optimize(
                self.model.eval(),
                dtype=torch.bfloat16,
                inplace=True,
                level="O1",
                auto_kernel_selection=True,
            )
        for i in range(3):
            self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id = i
        print("[NER info] Spacy and LLM model initialized.")

    def inference(self,
                  query: str,
                  prompt: str=None,
                  max_new_tokens: int=32,
                  temperature: float=0.01,
                  top_k: int=3,
                  repetition_penalty: float=1.1):
        start_time = time.time()
        cur_time = get_current_time()
        print("[NER info] Current time is:{}".format(cur_time))
        if not prompt:
            pt = PromptTemplate("ner")
            pt.append_message(pt.conv.roles[0], cur_time)
            pt.append_message(pt.conv.roles[1], query)
            prompt = pt.get_prompt()
        print("[NER info] Prompt is: ", prompt)
        inputs = self.tokenizer(prompt, return_token_type_ids=False, return_tensors="pt")
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=False)

        # define generate kwargs and construct inferencing thread
        if self.bf16:
            generate_kwargs = dict(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
            )

            def generate_output():
                dtype = self.model.dtype if hasattr(self.model, 'dtype') else torch.bfloat16
                try:
                    with torch.no_grad():
                        context = torch.cpu.amp.autocast(enabled=True, dtype=dtype, cache_enabled=True)
                        with context:
                            output_token = self.model.generate(
                                **inputs,
                                **generate_kwargs,
                                streamer=streamer,
                                return_dict_in_generate=True,
                            )
                        return output_token
                except Exception as e:
                    raise Exception(e)

            thread = Thread(target=generate_output)

        else:
            generate_kwargs = dict(
                inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                streamer=streamer,
            )
            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)

        # inference with thread
        thread.start()
        text = ""
        for new_text in streamer:
            text += new_text

        # process inferenced tokens and texts
        text = enforce_stop_tokens(text)
        doc = self.nlp(text)
        mentioned_time = process_time(text, doc)

        new_doc = self.nlp(query)
        result = process_entities(query, new_doc, mentioned_time)
        print("[NER info] Inference time consumption: ", time.time() - start_time)

        return result
intel_extension_for_transformers/neural_chat/pipeline/plugins/ner/ner_int.py (85 additions, 0 deletions)
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import spacy
from transformers import AutoTokenizer, TextIteratorStreamer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig
)
from .utils.utils import (
    enforce_stop_tokens,
    get_current_time
)
from .utils.process_text import process_time, process_entities
from intel_extension_for_transformers.neural_chat.prompts import PromptTemplate


class NamedEntityRecognitionINT():
    """
    Initialize spacy model and llm model
    If you want to inference with int8 model, set compute_dtype='fp32' and weight_dtype='int8';
    If you want to inference with int4 model, set compute_dtype='int8' and weight_dtype='int4'.
    """

    def __init__(self,
                 model_path="/home/tme/Llama-2-7b-chat-hf/",
                 spacy_model="en_core_web_lg",
                 compute_dtype="fp32",
                 weight_dtype="int8") -> None:
        self.nlp = spacy.load(spacy_model)
        config = WeightOnlyQuantConfig(compute_dtype=compute_dtype, weight_dtype=weight_dtype)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path,
                                                          quantization_config=config,
                                                          trust_remote_code=True)
        print("[NER info] Spacy and LLM model initialized.")

    def inference(self, query: str, prompt: str=None, threads: int=52, max_new_tokens: int=32, seed: int=1234):
        start_time = time.time()
        cur_time = get_current_time()
        print("[NER info] Current time is:{}".format(cur_time))
        if not prompt:
            pt = PromptTemplate("ner")
            pt.append_message(pt.conv.roles[0], cur_time)
            pt.append_message(pt.conv.roles[1], query)
            prompt = pt.get_prompt()
        print("[NER info] Prompt is: ", prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").input_ids
        streamer = TextIteratorStreamer(self.tokenizer)

        result_tokens = self.model.generate(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            seed=seed,
            threads=threads
        )
        self.model.model.reinit()
        gen_text = self.tokenizer.batch_decode(result_tokens)

        result_text = enforce_stop_tokens(gen_text[0])
        doc = self.nlp(result_text)
        mentioned_time = process_time(result_text, doc)

        new_doc = self.nlp(query)
        result = process_entities(query, new_doc, mentioned_time)
        print("[NER info] Inference time consumption: ", time.time() - start_time)

        return result
intel_extension_for_transformers/neural_chat/pipeline/plugins/ner/requirements.txt (7 additions, 0 deletions)
torch
torchvision
torchaudio
spacy
neural-compressor
transformers
intel_extension_for_pytorch
intel_extension_for_transformers/neural_chat/pipeline/plugins/ner/utils/__init__.py (16 additions, 0 deletions)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.