Skip to content

Commit

Permalink
refactor: Add an enumeration type and use the factory pattern to obta…
Browse files Browse the repository at this point in the history
…in the corresponding class (langgenius#9356)
  • Loading branch information
hwzhuhao authored and lau-td committed Oct 23, 2024
1 parent f9d074e commit 781b516
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 19 deletions.
25 changes: 14 additions & 11 deletions api/core/rag/datasource/keyword/keyword_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any

from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.datasource.keyword.keyword_type import KeyWordType
from core.rag.models.document import Document
from models.dataset import Dataset

Expand All @@ -13,16 +13,19 @@ def __init__(self, dataset: Dataset):
self._keyword_processor = self._init_keyword()

def _init_keyword(self) -> BaseKeyword:
config = dify_config
keyword_type = config.KEYWORD_STORE

if not keyword_type:
raise ValueError("Keyword store must be specified.")

if keyword_type == "jieba":
return Jieba(dataset=self._dataset)
else:
raise ValueError(f"Keyword store {keyword_type} is not supported.")
keyword_type = dify_config.KEYWORD_STORE
keyword_factory = self.get_keyword_factory(keyword_type)
return keyword_factory(self._dataset)

@staticmethod
def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]:
match keyword_type:
case KeyWordType.JIEBA:
from core.rag.datasource.keyword.jieba.jieba import Jieba

return Jieba
case _:
raise ValueError(f"Keyword store {keyword_type} is not supported.")

def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs)
Expand Down
5 changes: 5 additions & 0 deletions api/core/rag/datasource/keyword/keyword_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from enum import Enum


class KeyWordType(str, Enum):
JIEBA = "jieba"
26 changes: 18 additions & 8 deletions api/services/auth/api_key_auth_factory.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
from services.auth.firecrawl import FirecrawlAuth
from services.auth.jina import JinaAuth
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.auth_type import AuthType


class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
if provider == "firecrawl":
self.auth = FirecrawlAuth(credentials)
elif provider == "jinareader":
self.auth = JinaAuth(credentials)
else:
raise ValueError("Invalid provider")
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

def validate_credentials(self):
return self.auth.validate_credentials()

@staticmethod
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
match provider:
case AuthType.FIRECRAWL:
from services.auth.firecrawl.firecrawl import FirecrawlAuth

return FirecrawlAuth
case AuthType.JINA:
from services.auth.jina.jina import JinaAuth

return JinaAuth
case _:
raise ValueError("Invalid provider")
6 changes: 6 additions & 0 deletions api/services/auth/auth_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class AuthType(str, Enum):
FIRECRAWL = "firecrawl"
JINA = "jinareader"
Empty file.
File renamed without changes.
Empty file.
File renamed without changes.

0 comments on commit 781b516

Please sign in to comment.