diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 3c99f33be6..f1a6ade91f 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -1,8 +1,8 @@ from typing import Any from configs import dify_config -from core.rag.datasource.keyword.jieba.jieba import Jieba from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.datasource.keyword.keyword_type import KeyWordType from core.rag.models.document import Document from models.dataset import Dataset @@ -13,16 +13,19 @@ class Keyword: self._keyword_processor = self._init_keyword() def _init_keyword(self) -> BaseKeyword: - config = dify_config - keyword_type = config.KEYWORD_STORE + keyword_type = dify_config.KEYWORD_STORE + keyword_factory = self.get_keyword_factory(keyword_type) + return keyword_factory(self._dataset) - if not keyword_type: - raise ValueError("Keyword store must be specified.") + @staticmethod + def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]: + match keyword_type: + case KeyWordType.JIEBA: + from core.rag.datasource.keyword.jieba.jieba import Jieba - if keyword_type == "jieba": - return Jieba(dataset=self._dataset) - else: - raise ValueError(f"Keyword store {keyword_type} is not supported.") + return Jieba + case _: + raise ValueError(f"Keyword store {keyword_type} is not supported.") def create(self, texts: list[Document], **kwargs): self._keyword_processor.create(texts, **kwargs) diff --git a/api/core/rag/datasource/keyword/keyword_type.py b/api/core/rag/datasource/keyword/keyword_type.py new file mode 100644 index 0000000000..d6deba3fb0 --- /dev/null +++ b/api/core/rag/datasource/keyword/keyword_type.py @@ -0,0 +1,5 @@ +from enum import Enum + + +class KeyWordType(str, Enum): + JIEBA = "jieba" diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 36387e9c2e..f91c448fb9 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,15 +1,25 @@ -from services.auth.firecrawl import FirecrawlAuth -from services.auth.jina import JinaAuth +from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.auth_type import AuthType class ApiKeyAuthFactory: def __init__(self, provider: str, credentials: dict): - if provider == "firecrawl": - self.auth = FirecrawlAuth(credentials) - elif provider == "jinareader": - self.auth = JinaAuth(credentials) - else: - raise ValueError("Invalid provider") + auth_factory = self.get_apikey_auth_factory(provider) + self.auth = auth_factory(credentials) def validate_credentials(self): return self.auth.validate_credentials() + + @staticmethod + def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]: + match provider: + case AuthType.FIRECRAWL: + from services.auth.firecrawl.firecrawl import FirecrawlAuth + + return FirecrawlAuth + case AuthType.JINA: + from services.auth.jina.jina import JinaAuth + + return JinaAuth + case _: + raise ValueError("Invalid provider") diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py new file mode 100644 index 0000000000..2d6e901447 --- /dev/null +++ b/api/services/auth/auth_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class AuthType(str, Enum): + FIRECRAWL = "firecrawl" + JINA = "jinareader" diff --git a/api/services/auth/firecrawl/__init__.py b/api/services/auth/firecrawl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py similarity index 100% rename from api/services/auth/firecrawl.py rename to api/services/auth/firecrawl/firecrawl.py diff --git a/api/services/auth/jina/__init__.py b/api/services/auth/jina/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/auth/jina.py b/api/services/auth/jina/jina.py similarity index 100% rename from api/services/auth/jina.py rename to api/services/auth/jina/jina.py