2024-10-15 12:51:13 +08:00
|
|
|
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
|
|
|
from services.auth.auth_type import AuthType
|
2024-06-15 02:46:02 +08:00
|
|
|
|
|
|
|
|
|
|
|
class ApiKeyAuthFactory:
|
|
|
|
def __init__(self, provider: str, credentials: dict):
|
2024-10-15 12:51:13 +08:00
|
|
|
auth_factory = self.get_apikey_auth_factory(provider)
|
|
|
|
self.auth = auth_factory(credentials)
|
2024-06-15 02:46:02 +08:00
|
|
|
|
|
|
|
def validate_credentials(self):
|
|
|
|
return self.auth.validate_credentials()
|
2024-10-15 12:51:13 +08:00
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
|
|
|
|
match provider:
|
|
|
|
case AuthType.FIRECRAWL:
|
|
|
|
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
|
|
|
|
|
|
|
return FirecrawlAuth
|
|
|
|
case AuthType.JINA:
|
|
|
|
from services.auth.jina.jina import JinaAuth
|
|
|
|
|
|
|
|
return JinaAuth
|
|
|
|
case _:
|
|
|
|
raise ValueError("Invalid provider")
|