Support Chinese for Docsum (#799)
* support Chinese for DocSum

Signed-off-by: Xinyao Wang <xinyao.wang@intel.com>

* refine readme for docsum

Signed-off-by: Xinyao Wang <xinyao.wang@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xinyao Wang <xinyao.wang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: chen, suyue <suyue.chen@intel.com>
3 people authored Oct 17, 2024
1 parent 2710115 commit 9a00a3e
Showing 5 changed files with 63 additions and 39 deletions.
1 change: 1 addition & 0 deletions comps/cores/mega/gateway.py
@@ -362,6 +362,7 @@ async def handle_request(self, request: Request):
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
language=chat_request.language if chat_request.language else "auto",
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"query": prompt}, llm_parameters=parameters
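With this change a DocSum megaservice deployment can accept a language hint end to end. A hedged client-side sketch is below; the host, port, and `/v1/docsum` route are assumptions about a typical local deployment, and the payload follows the OpenAI-style `ChatCompletionRequest` schema extended in the next file.

```python
# Sketch: send a Chinese-language summarization request to a DocSum gateway.
# The URL is an assumption about a local deployment; adjust host/port/route as needed.
import requests

payload = {
    "messages": "2024年9月26日,北京——英特尔正式发布英特尔® 至强® 6性能核处理器。",
    "max_tokens": 32,
    "language": "zh",  # "en", "zh", or "auto" (default)
}
resp = requests.post("http://localhost:8888/v1/docsum", json=payload, timeout=600)
print(resp.status_code, resp.text)
```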
1 change: 1 addition & 0 deletions comps/cores/proto/api_protocol.py
@@ -175,6 +175,7 @@ class ChatCompletionRequest(BaseModel):
tool_choice: Optional[Union[Literal["none"], ChatCompletionNamedToolChoiceParam]] = "none"
parallel_tool_calls: Optional[bool] = True
user: Optional[str] = None
language: str = "auto" # can be "en", "zh"

# Ordered by official OpenAI API documentation
# default values are same with
2 changes: 2 additions & 0 deletions comps/cores/proto/docarray.py
@@ -175,6 +175,7 @@ class LLMParamsDoc(BaseDoc):
presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True
language: str = "auto" # can be "en", "zh"

chat_template: Optional[str] = Field(
default=None,
@@ -212,6 +213,7 @@ class LLMParams(BaseDoc):
presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True
language: str = "auto" # can be "en", "zh"

chat_template: Optional[str] = Field(
default=None,
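Both `LLMParamsDoc` and `LLMParams` pick up the same default, so the field survives the hop from gateway to microservice. A small hedged check (assumes `comps.cores.proto.docarray` is importable in your environment):

```python
# Sketch: the new "language" field defaults to "auto" and accepts "en" / "zh".
from comps.cores.proto.docarray import LLMParamsDoc

doc = LLMParamsDoc(query="请总结这篇文档")
print(doc.language)  # -> "auto"

doc_zh = LLMParamsDoc(query="请总结这篇文档", language="zh")
print(doc_zh.language)  # -> "zh"
```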
10 changes: 8 additions & 2 deletions comps/llms/summarization/tgi/langchain/README.md
@@ -92,12 +92,18 @@ curl http://${your_ip}:9000/v1/health_check\
# Enable streaming to receive a streaming response. By default, this is set to True.
curl http://${your_ip}:9000/v1/chat/docsum \
-X POST \
-d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."}' \
-d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en"}' \
-H 'Content-Type: application/json'

# Disable streaming to receive a non-streaming response.
curl http://${your_ip}:9000/v1/chat/docsum \
-X POST \
-d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "streaming":false}' \
-d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "streaming":false}' \
-H 'Content-Type: application/json'

# Use Chinese mode. By default, language is set to "auto", which uses the English prompt template; set it to "zh" for Chinese.
curl http://${your_ip}:9000/v1/chat/docsum \
-X POST \
-d '{"query":"2024年9月26日,北京——今日,英特尔正式发布英特尔® 至强® 6性能核处理器(代号Granite Rapids),为AI、数据分析、科学计算等计算密集型业务提供卓越性能。", "max_tokens":32, "language":"zh", "streaming":false}' \
-H 'Content-Type: application/json'
```
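Because the streaming branch in `llm.py` (below) emits server-sent events whose `data:` payload is the `repr()` of each UTF-8-encoded chunk, a client has to strip that framing. A hedged Python sketch of consuming the stream, mirroring the curl examples above (`your_ip` is a placeholder for your host):

```python
# Sketch: read the SSE stream produced by the docsum microservice.
import requests

resp = requests.post(
    "http://your_ip:9000/v1/chat/docsum",
    json={"query": "Text Embeddings Inference (TEI) is a toolkit ...", "max_tokens": 32, "language": "en"},
    stream=True,
    timeout=600,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    print(payload)  # e.g. b'TEI' -- repr() of the encoded text chunk
```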
88 changes: 51 additions & 37 deletions comps/llms/summarization/tgi/langchain/llm.py
@@ -4,26 +4,35 @@
import os

from fastapi.responses import StreamingResponse
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint
from huggingface_hub import AsyncInferenceClient
from langchain.prompts import PromptTemplate

from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice

logger = CustomLogger("llm_docsum")
logflag = os.getenv("LOGFLAG", False)

llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(
model=llm_endpoint,
timeout=600,
)

templ_en = """Write a concise summary of the following:
"{text}"
CONCISE SUMMARY:"""

def post_process_text(text: str):
if text == " ":
return "data: @#$\n\n"
if text == "\n":
return "data: <br/>\n\n"
if text.isspace():
return None
new_text = text.replace(" ", "@#$")
return f"data: {new_text}\n\n"
templ_zh = """请简要概括以下内容:
"{text}"
概况:"""


@register_microservice(
@@ -37,46 +46,51 @@ async def llm_generate(input: LLMParamsDoc):
if logflag:
logger.info(input)

llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
if input.language in ["en", "auto"]:
templ = templ_en
elif input.language in ["zh"]:
templ = templ_zh
else:
raise NotImplementedError('Please specify the input language as one of "en", "zh", or "auto"')

prompt_template = PromptTemplate.from_template(templ)
prompt = prompt_template.format(text=input.query)

if logflag:
logger.info("After prompting:")
logger.info(prompt)

text_generation = await llm.text_generation(
prompt=prompt,
stream=input.streaming,
max_new_tokens=input.max_tokens,
repetition_penalty=input.repetition_penalty,
temperature=input.temperature,
top_k=input.top_k,
top_p=input.top_p,
typical_p=input.typical_p,
temperature=input.temperature,
repetition_penalty=input.repetition_penalty,
streaming=input.streaming,
)
llm_chain = load_summarize_chain(llm=llm, chain_type="map_reduce")
texts = text_splitter.split_text(input.query)

# Create multiple documents
docs = [Document(page_content=t) for t in texts]

if input.streaming:

async def stream_generator():
from langserve.serialization import WellKnownLCSerializer

_serializer = WellKnownLCSerializer()
async for chunk in llm_chain.astream_log(docs):
data = _serializer.dumps({"ops": chunk.ops}).decode("utf-8")
chat_response = ""
async for text in text_generation:
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
logger.info(f"[docsum - text_summarize] data: {data}")
yield f"data: {data}\n\n"
logger.info(f"[ docsum - text_summarize ] chunk:{chunk_repr}")
yield f"data: {chunk_repr}\n\n"
if logflag:
logger.info(f"[ docsum - text_summarize ] stream response: {chat_response}")
yield "data: [DONE]\n\n"

return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
response = await llm_chain.ainvoke(docs)
response = response["output_text"]
if logflag:
logger.info(response)
return GeneratedDoc(text=response, prompt=input.query)
logger.info(text_generation)
return GeneratedDoc(text=text_generation, prompt=input.query)


if __name__ == "__main__":
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
# Split text
text_splitter = CharacterTextSplitter()
opea_microservices["opea_service@llm_docsum"].start()
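The core of the Chinese support is the prompt-template switch near the top of `llm_generate`. A standalone sketch of that selection logic, reusing the same templates (no TGI endpoint required):

```python
# Sketch: reproduce the language-to-template selection used in llm_generate.
from langchain.prompts import PromptTemplate

templ_en = """Write a concise summary of the following:
"{text}"
CONCISE SUMMARY:"""

templ_zh = """请简要概括以下内容:
"{text}"
概况:"""

def build_prompt(text: str, language: str = "auto") -> str:
    if language in ("en", "auto"):
        templ = templ_en
    elif language == "zh":
        templ = templ_zh
    else:
        raise NotImplementedError('Please specify the input language as one of "en", "zh", or "auto"')
    return PromptTemplate.from_template(templ).format(text=text)

print(build_prompt("TEI is a toolkit for serving text embeddings.", "en"))
print(build_prompt("英特尔发布至强6性能核处理器。", "zh"))
```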
