Skip to content

Commit

Permalink
refactor(knowledge-retrieval): improve error handling with custom exceptions (langgenius#10385)
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 authored and JunXu01 committed Nov 9, 2024
1 parent 9a6da66 commit 730a59f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
18 changes: 18 additions & 0 deletions api/core/workflow/nodes/knowledge_retrieval/exc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
class KnowledgeRetrievalNodeError(ValueError):
    """Common ancestor for every error raised by the knowledge-retrieval node.

    Subclasses ValueError so existing callers that caught ValueError from the
    node keep working unchanged.
    """


class ModelNotExistError(KnowledgeRetrievalNodeError):
    """Raised when the requested model cannot be found."""


class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
    """Raised when credentials for the model have not been configured."""


class ModelNotSupportedError(KnowledgeRetrievalNodeError):
    """Raised when the selected model is not supported."""


class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
    """Raised when the model provider's quota has been exhausted."""
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
Expand All @@ -18,11 +17,19 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus

from .entities import KnowledgeRetrievalNodeData
from .exc import (
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
)

logger = logging.getLogger(__name__)

default_retrieval_model = {
Expand Down Expand Up @@ -61,8 +68,8 @@ def _run(self) -> NodeRunResult:
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
)

except Exception as e:
logger.exception("Error when running knowledge retrieval node")
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))

def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -295,14 +302,14 @@ def _fetch_model_config(
)

if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")

# model config
completion_params = node_data.single_retrieval_config.model.completion_params
Expand All @@ -314,12 +321,12 @@ def _fetch_model_config(
# get model mode
model_mode = node_data.single_retrieval_config.model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
raise ModelNotExistError("LLM mode is required.")

model_schema = model_type_instance.get_model_schema(model_name, model_credentials)

if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")

return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
Expand Down

0 comments on commit 730a59f

Please sign in to comment.