[NeuralChat] enable RAG + ChatGPT flow (#1322)
* Enable openai service

* RAG+chatgpt example

Signed-off-by: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
Spycsh authored Feb 28, 2024
1 parent 2858ed1 commit de88006
Showing 7 changed files with 159 additions and 21 deletions.
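As a rough sketch of the RAG + ChatGPT flow this commit enables (the plugin-toggling style follows NeuralChat's usual usage; the `./docs` path, the question text, and the exact keyword arguments are illustrative assumptions, not part of this commit):

```python
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot, plugins

# Enable the retrieval (RAG) plugin and point it at the documents to index.
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "./docs"

# Use an OpenAI model name; the key is read from OPENAI_API_KEY unless an OpenAIConfig is passed.
config = PipelineConfig(model_name_or_path="gpt-3.5-turbo", plugins=plugins)
chatbot = build_chatbot(config)

# The OpenAI adapter expects chat-completions style messages (List[Dict]).
messages = [{"role": "user", "content": "What do the indexed documents say about installation?"}]
print(chatbot.predict(query=messages))
```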
7 changes: 6 additions & 1 deletion intel_extension_for_transformers/neural_chat/chatbot.py
@@ -20,6 +20,7 @@
 from .config import PipelineConfig
 from .config import BaseFinetuningConfig
 from .plugins import plugins
+from .utils.common import is_openai_model
 
 from .errorcode import ErrorCodes
 from .utils.error_utils import set_latest_error, get_latest_error, clear_latest_error
@@ -113,6 +114,7 @@ def build_chatbot(config: PipelineConfig=None):
     if not config:
         config = PipelineConfig()
 
+
     if config.hf_endpoint_url:
         if not config.hf_access_token:
             set_latest_error(ErrorCodes.ERROR_HF_TOKEN_NOT_PROVIDED)
@@ -123,7 +125,10 @@
         adapter = HuggingfaceModel(config.hf_endpoint_url, config.hf_access_token)
     else:
         # create model adapter
-        if "llama" in config.model_name_or_path.lower():
+        if is_openai_model(config.model_name_or_path.lower()):
+            from .models.openai_model import OpenAIModel
+            adapter = OpenAIModel(config.model_name_or_path, config.task, config.openai_config)
+        elif "llama" in config.model_name_or_path.lower():
             from .models.llama_model import LlamaModel
             adapter = LlamaModel(config.model_name_or_path, config.task)
         elif "mpt" in config.model_name_or_path.lower():
16 changes: 15 additions & 1 deletion intel_extension_for_transformers/neural_chat/config.py
@@ -24,6 +24,8 @@
 from .utils.common import get_device_type
 
 from .plugins import plugins
+from .utils.common import is_openai_model
+import os
 
 from enum import Enum, auto
 
@@ -450,6 +452,11 @@ class ServingConfig:
     framework: str = "vllm" # vllm/TGI
     framework_config: FrameworkConfig = None
 
+@dataclass
+class OpenAIConfig:
+    def __init__(self, api_key: str = None, organization: str = None):
+        self.api_key = api_key if api_key else os.environ.get("OPENAI_API_KEY")
+        self.organization = organization if organization else os.environ.get("OPENAI_ORG")
 class PipelineConfig:
     def __init__(self,
                  model_name_or_path="Intel/neural-chat-7b-v3-1",
@@ -462,8 +469,15 @@ def __init__(self,
                  loading_config=None,
                  optimization_config=None,
                  assistant_model=None,
-                 serving_config=None):
+                 serving_config=None,
+                 openai_config=None,):
         self.model_name_or_path = model_name_or_path
+
+        if is_openai_model(model_name_or_path.lower()):
+            self.openai_config = openai_config if openai_config else OpenAIConfig()
+            if self.openai_config.api_key is None:
+                raise Exception("Please provide the OpenAI API key if you are using an OpenAI model!")
+
         self.tokenizer_name_or_path = tokenizer_name_or_path
         self.hf_access_token = hf_access_token
         self.hf_endpoint_url = hf_endpoint_url
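A minimal sketch of how the new OpenAIConfig resolves credentials, assuming the import path of the file modified above; the key values are placeholders:

```python
import os

from intel_extension_for_transformers.neural_chat.config import OpenAIConfig, PipelineConfig

# Option 1: rely on environment variables (OPENAI_API_KEY / OPENAI_ORG).
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder, not a real key
config = PipelineConfig(model_name_or_path="gpt-3.5-turbo")

# Option 2: pass the credentials explicitly.
openai_config = OpenAIConfig(api_key="sk-...", organization=None)
config = PipelineConfig(model_name_or_path="gpt-3.5-turbo", openai_config=openai_config)
```

If neither an explicit key nor OPENAI_API_KEY is available, PipelineConfig raises immediately instead of failing later inside the OpenAI client.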
@@ -61,6 +61,13 @@ device: "cpu"
 # bnb_4bit_use_double_quant: True
 # bnb_4bit_compute_dtype: "bfloat16"
 
+# retrieval:
+#   enable: true
+#   args:
+#     input_path: "./docs"
+#     persist_directory: "./docs_persist"
+#     response_template: "We cannot find suitable content to answer your query at this moment."
+#     append: True
 
 # task choices = ['textchat', 'voicechat', 'retrieval', 'text2image', 'finetune']
 tasks_list: ['textchat']
@@ -0,0 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_model import BaseModel
from openai import OpenAI
import logging
from ..config import GenerationConfig
from ..plugins import is_plugin_enabled, get_plugin_instance, get_registered_plugins
import os
from ..utils.error_utils import set_latest_error
from ..errorcode import ErrorCodes


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class OpenAIModel(BaseModel):
    """
    Customized class to operate on OpenAI models in the pipeline.
    """
    def __init__(self, model_name, task, openai_config):
        self.api_key = openai_config.api_key
        self.organization = openai_config.organization
        self.model_name = model_name
        self.task = task

    def load_model(self, kwargs: dict):
        self.client = OpenAI(api_key=self.api_key, organization=self.organization)

    def predict(self, query, config: GenerationConfig = None):
        """Customized OpenAI model predict.

        Args:
            query: List[Dict], usually contains system prompt + user prompt.
            config: GenerationConfig, provides the needed inference parameters such as top_p, max_tokens.

        Returns:
            the result string of one single choice
        """
        if not config:
            config = GenerationConfig()

        # Only the retrieval plugin is supported for now
        plugin_name = "retrieval"
        if is_plugin_enabled(plugin_name):
            plugin_instance = get_plugin_instance(plugin_name)
            try:
                new_user_prompt, link = plugin_instance.pre_llm_inference_actions(self.model_name,
                                                                                  self.find_user_prompt(query))
                self.update_user_prompt(query, new_user_prompt)
            except Exception as e:
                if "[Rereieval ERROR] intent detection failed" in str(e):
                    set_latest_error(ErrorCodes.ERROR_INTENT_DETECT_FAIL)
                return
        assert query is not None, "Query cannot be None."
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=query,
            temperature=config.temperature,
            top_p=config.top_p,
            max_tokens=config.max_new_tokens,
        )
        return response.choices[0].message.content

    def find_user_prompt(self, query):
        """Find in the query List[Dict] the user prompt."""
        return [i['content'] for i in query if 'role' in i and i['role'] == 'user'][0]

    def update_user_prompt(self, query, new_user_prompt):
        """Update the user prompt in the query List[Dict]."""
        for i in query:
            if 'role' in i and i['role'] == 'user':
                i['content'] = new_user_prompt
        return query

    def predict_stream(self, query, config: GenerationConfig = None):
        raise Exception("Streaming is not supported yet! Will fix this in the future.")
@@ -19,14 +19,16 @@
 from intel_extension_for_transformers.neural_chat.pipeline.plugins.prompt.prompt_template \
     import generate_intent_prompt
 from intel_extension_for_transformers.neural_chat.models.model_utils import predict
 
+from intel_extension_for_transformers.neural_chat.utils.common import is_openai_model
 
 class IntentDetector:
     def __init__(self):
         pass
 
     def intent_detection(self, model_name, query):
         """Using the LLM to detect the intent of the user query."""
+        if is_openai_model(model_name):
+            return query
         prompt = generate_intent_prompt(query)
         params = {}
         params["model_name"] = model_name
@@ -49,6 +49,7 @@
 import json, types
 import tiktoken
 from ...plugins import plugins, is_plugin_enabled
+from ...utils.common import is_openai_model
 
 def check_requests(request) -> Optional[JSONResponse]:
     # Check all params
@@ -356,9 +357,9 @@ async def generate_completion(payload: Dict[str, Any], chatbot: BaseModel):
     config = GenerationConfig()
     for attr, value in payload.items():
         setattr(config, attr, value)
-    config.device = chatbot.device
+    config.device = chatbot.device if hasattr(chatbot, "device") else "auto"
     config.task = "chat"
-    if chatbot.device == "hpu":
+    if config.device == "hpu":
         config.use_hpu_graphs = True
     prompt = payload["prompt"]
     response = chatbot.predict(query=prompt, config=config)
@@ -489,20 +490,29 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     chatbot = router.get_chatbot()
 
-    gen_params = await get_generation_parameters(
-        request.model,
-        chatbot,
-        request.messages,
-        temperature=request.temperature,
-        top_p=request.top_p,
-        top_k=request.top_k,
-        repetition_penalty=request.repetition_penalty,
-        presence_penalty=request.presence_penalty,
-        frequency_penalty=request.frequency_penalty,
-        max_tokens=request.max_tokens if request.max_tokens else 512,
-        echo=False,
-        stop=request.stop,
-    )
+    if not is_openai_model(chatbot.model_name.lower()):
+        gen_params = await get_generation_parameters(
+            request.model,
+            chatbot,
+            request.messages,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            top_k=request.top_k,
+            repetition_penalty=request.repetition_penalty,
+            presence_penalty=request.presence_penalty,
+            frequency_penalty=request.frequency_penalty,
+            max_tokens=request.max_tokens if request.max_tokens else 512,
+            echo=False,
+            stop=request.stop,
+        )
+    else:
+        gen_params = {
+            "prompt": request.messages,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+            "repetition_penalty": request.repetition_penalty,
+            "max_new_tokens": request.max_tokens,
+        }
 
     if request.stream:
         generator = chat_completion_stream_generator(
@@ -524,12 +534,13 @@ async def create_chat_completion(request: ChatCompletionRequest):
         if isinstance(content, str):
             content = json.loads(content)
 
+        content_string = content["text"]
         if content["error_code"] != 0:
-            return create_error_response(content["error_code"], content["text"])
+            return create_error_response(content["error_code"], content_string)
         choices.append(
             ChatCompletionResponseChoice(
                 index=i,
-                message=ChatMessage(role="assistant", content=content["text"]),
+                message=ChatMessage(role="assistant", content=content_string),
                 finish_reason=content.get("finish_reason", "stop"),
             )
         )
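To exercise the updated create_chat_completion handler end to end, a request sketch; the host, port, and the /v1/chat/completions route are assumptions based on the OpenAI-compatible protocol this API mirrors:

```python
import requests

# Hypothetical local NeuralChat server exposing the OpenAI-compatible chat route.
payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Summarize the indexed docs in one sentence."}],
    "temperature": 0.7,
    "top_p": 0.95,
    "max_tokens": 256,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])
```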
4 changes: 4 additions & 0 deletions intel_extension_for_transformers/neural_chat/utils/common.py
@@ -48,3 +48,7 @@ def is_audio_file(filename):
         return True
     else:
         return False
+
+def is_openai_model(model_name_or_path):
+    # Check https://platform.openai.com/docs/models/model-endpoint-compatibility
+    return any(name in model_name_or_path for name in ["gpt-4", "gpt-3.5-turbo"])
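A quick illustration of the matching behaviour of the new helper; the model names below are just examples:

```python
from intel_extension_for_transformers.neural_chat.utils.common import is_openai_model

# Simple substring match against known OpenAI chat model families.
print(is_openai_model("gpt-3.5-turbo"))              # True
print(is_openai_model("gpt-4-turbo-preview"))        # True ("gpt-4" is a substring)
print(is_openai_model("intel/neural-chat-7b-v3-1"))  # False -> handled by the local adapters
```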
