fix: better memory usage from 800+ to 500+ (langgenius#11796)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
yihong0618 authored and 刘江波 committed Dec 20, 2024
1 parent 9b1316c commit 8802d8b
Showing 5 changed files with 56 additions and 26 deletions.
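
All five files apply the same pattern: heavyweight imports (vertexai, google-cloud-aiplatform, unstructured, nltk, jieba) move from module scope into the functions that actually use them, so the modules load on first call instead of at process start, while TYPE_CHECKING guards keep the names visible to type checkers. A minimal sketch of the pattern, with a hypothetical heavy_sdk standing in for any of these dependencies:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by type checkers (mypy, pyright); never at runtime.
        import heavy_sdk  # hypothetical heavyweight dependency

    def build_client(config: dict) -> "heavy_sdk.Client":
        # Deferred import: the module loads on the first call, not when this
        # file is imported. Later calls are a cheap sys.modules lookup.
        import heavy_sdk

        return heavy_sdk.Client(**config)

The quoted return annotation is what lets the signature reference a module that is not bound at runtime.
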
26 changes: 18 additions & 8 deletions api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -4,11 +4,10 @@
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast

 import google.auth.transport.requests
 import requests
-import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
 from anthropic.types import (
     ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@
     MessageStreamEvent,
 )
 from google.api_core import exceptions
-from google.cloud import aiplatform
-from google.oauth2 import service_account
 from PIL import Image

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+if TYPE_CHECKING:
+    import vertexai.generative_models as glm
+
 logger = logging.getLogger(__name__)


@@ -102,6 +102,8 @@ def _generate_anthropic(
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
+        from google.oauth2 import service_account
+
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
         service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:

         return text.rstrip()

-    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
         """
         Convert tool messages to glm tools

         :param tools: tool messages
         :return: glm tools
         """
+        import vertexai.generative_models as glm
+
         return glm.Tool(
             function_declarations=[
                 glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ def _generate(
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+        import vertexai.generative_models as glm
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+
         config_kwargs = model_parameters.copy()
         config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)

@@ -522,7 +530,7 @@
         return self._handle_generate_response(model, credentials, response, prompt_messages)

     def _handle_generate_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> LLMResult:
         """
         Handle llm response
@@ -554,7 +562,7 @@
         return result

     def _handle_generate_stream_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> Generator:
         """
         Handle llm stream response
@@ -638,13 +646,15 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:

         return message_text

-    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
         """
         Format a single message into glm.Content for Google API

         :param message: one PromptMessage
         :return: glm Content representation of message
         """
+        import vertexai.generative_models as glm
+
         if isinstance(message, UserPromptMessage):
             glm_content = glm.Content(role="user", parts=[])

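In this file, every annotation that mentions glm had to be quoted ("glm.Tool", "glm.GenerationResponse", "glm.Content") because the module is no longer bound at import time. A sketch of the failure this avoids, again with a hypothetical heavy_sdk:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import heavy_sdk

    # Unquoted, the def line below would raise NameError when this module is
    # imported, because annotations in a def are evaluated immediately:
    # def make() -> heavy_sdk.Client: ...

    # Quoted, the annotation is stored as a string and resolved only by tools:
    def make() -> "heavy_sdk.Client":
        import heavy_sdk
        return heavy_sdk.Client()

Alternatively, `from __future__ import annotations` (PEP 563) turns all annotations in a module into lazy strings, removing the need to quote each one.
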
18 changes: 14 additions & 4 deletions api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -2,12 +2,9 @@
 import json
 import time
 from decimal import Decimal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import tiktoken
-from google.cloud import aiplatform
-from google.oauth2 import service_account
-from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -24,6 +21,11 @@
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi

+if TYPE_CHECKING:
+    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+else:
+    VertexTextEmbeddingModel = None
+

 class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     """
@@ -48,6 +50,10 @@ def _invoke(
         :param input_type: input type
         :return: embeddings result
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         :param credentials: model credentials
         :return:
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         try:
             service_account_key = credentials.get("vertex_service_account_key", "")
             project_id = credentials["vertex_project_id"]
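
Unlike llm.py, this file also binds VertexTextEmbeddingModel = None when not type checking, so the name still exists at runtime and any unquoted reference to it evaluates without a NameError when the module is imported. A sketch of the idea under the same heavy_sdk assumption as above:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from heavy_sdk import Model as HeavyModel
    else:
        # Placeholder binding so the name exists at runtime; the real class
        # is only needed by type checkers and inside the methods below.
        HeavyModel = None

    class Wrapper:
        def load(self, name: str) -> "HeavyModel":
            # Real import deferred to first use; cached in sys.modules after.
            from heavy_sdk import Model

            return Model(name)
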
15 changes: 9 additions & 6 deletions api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,18 +1,19 @@
 import re
 from typing import Optional

-import jieba
-from jieba.analyse import default_tfidf
-
-from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
-

 class JiebaKeywordTableHandler:
     def __init__(self):
-        default_tfidf.stop_words = STOPWORDS
+        import jieba.analyse
+
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS

     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
+        import jieba
+
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10

     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
         results = set()
         for token in tokens:
             results.add(token)
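
One behavioral shift in this file: jieba's global TF-IDF stop-word configuration now runs in __init__, i.e. at first handler construction rather than at process start. Re-running it on every construction is harmless since the assignment is idempotent, but a class-level flag would make the one-time intent explicit — a sketch of that variant (the stop-word set here is illustrative, not the project's STOPWORDS):

    class KeywordHandler:
        _configured = False

        def __init__(self) -> None:
            # Cached in sys.modules after the first construction.
            import jieba.analyse

            if not KeywordHandler._configured:
                # One-time global configuration of jieba's default extractor.
                jieba.analyse.default_tfidf.stop_words = {"the", "a", "an"}
                KeywordHandler._configured = True
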
6 changes: 4 additions & 2 deletions api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -6,10 +6,8 @@
 from typing import Any

 import jieba.posseg as pseg
-import nltk
 import numpy
 import oracledb
-from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator

 from configs import dify_config
@@ -202,6 +200,10 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # lazy import
+        import nltk
+        from nltk.corpus import stopwords
+
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
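
Moving the nltk imports into search_by_full_text means the import statement executes on every call, but only the first call does real work: Python caches loaded modules in sys.modules, and repeat imports are a dictionary lookup. A standard-library-only way to observe this:

    import sys
    import time

    def timed_import(name: str) -> float:
        # Times one import statement; repeat imports of the same module hit
        # the sys.modules cache instead of loading from disk.
        start = time.perf_counter()
        __import__(name)
        return time.perf_counter() - start

    first = timed_import("json")   # cold load (unless json was already imported)
    second = timed_import("json")  # cache hit, typically microseconds
    print(f"first={first:.6f}s second={second:.6f}s")
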
17 changes: 11 additions & 6 deletions api/core/workflow/nodes/document_extractor/node.py
@@ -8,12 +8,6 @@
 import pandas as pd
 import pypdfium2  # type: ignore
 import yaml  # type: ignore
-from unstructured.partition.api import partition_via_api
-from unstructured.partition.email import partition_email
-from unstructured.partition.epub import partition_epub
-from unstructured.partition.msg import partition_msg
-from unstructured.partition.ppt import partition_ppt
-from unstructured.partition.pptx import partition_pptx

 from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
@@ -256,6 +250,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:


 def _extract_text_from_ppt(file_content: bytes) -> str:
+    from unstructured.partition.ppt import partition_ppt
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_ppt(file=file)
@@ -265,6 +261,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:


 def _extract_text_from_pptx(file_content: bytes) -> str:
+    from unstructured.partition.api import partition_via_api
+    from unstructured.partition.pptx import partition_pptx
+
     try:
         if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
             with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +286,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:


 def _extract_text_from_epub(file_content: bytes) -> str:
+    from unstructured.partition.epub import partition_epub
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_epub(file=file)
@@ -296,6 +297,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:


 def _extract_text_from_eml(file_content: bytes) -> str:
+    from unstructured.partition.email import partition_email
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_email(file=file)
@@ -305,6 +308,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:


 def _extract_text_from_msg(file_content: bytes) -> str:
+    from unstructured.partition.msg import partition_msg
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_msg(file=file)
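
The 800+/500+ figures in the commit title are presumably resident memory in MB for the API process at startup; the commit itself does not say how they were measured. One way to sanity-check numbers like these on Linux or macOS with only the standard library (note the platform-dependent ru_maxrss unit):

    import resource
    import sys

    def rss_mb() -> float:
        # ru_maxrss is kilobytes on Linux but bytes on macOS.
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        return peak / 1024 if sys.platform == "linux" else peak / (1024 * 1024)

    print(f"baseline: {rss_mb():.1f} MB")
    import numpy  # stand-in for any heavyweight dependency
    print(f"after import: {rss_mb():.1f} MB")

Python's -X importtime flag prints a per-module import cost breakdown and is another quick way to find candidates for this treatment.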
