2024-06-15 02:46:02 +08:00
|
|
|
import json
|
|
|
|
|
|
|
|
from core.helper import encrypter
|
|
|
|
from extensions.ext_database import db
|
|
|
|
from models.source import DataSourceApiKeyAuthBinding
|
|
|
|
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
|
|
|
|
|
|
|
|
|
|
|
class ApiKeyAuthService:
    """Manage per-tenant API-key credentials for data-source providers.

    Credentials are stored as ``DataSourceApiKeyAuthBinding`` rows. The
    credential payload is persisted as a JSON string in which the provider API
    key has been encrypted with ``encrypter.encrypt_token``.
    """

    @staticmethod
    def get_provider_auth_list(tenant_id: str) -> list:
        """Return every non-disabled API-key binding that belongs to *tenant_id*."""
        data_source_api_key_bindings = (
            db.session.query(DataSourceApiKeyAuthBinding)
            .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
            .all()
        )
        return data_source_api_key_bindings

    @staticmethod
    def _replace_api_key(credentials: dict, api_key: str) -> dict:
        """Return a copy of *credentials* whose ``config["api_key"]`` is *api_key*.

        Neither *credentials* nor its ``config`` sub-dict is modified. Existing
        keys keep their positions, so the JSON serialization matches what an
        in-place assignment would have produced.

        :raises KeyError: if *credentials* has no ``config`` entry.
        """
        return {**credentials, "config": {**credentials["config"], "api_key": api_key}}

    @staticmethod
    def create_provider_auth(tenant_id: str, args: dict) -> None:
        """Validate credentials with the provider and persist a new binding.

        :param tenant_id: tenant that will own the new binding.
        :param args: payload carrying ``category``, ``provider`` and
            ``credentials``; ``credentials["config"]["api_key"]`` is read as the
            plaintext key to encrypt.

        Nothing is stored when ``validate_credentials()`` returns a falsy value.
        The caller's ``args`` is not modified: the encrypted key is written only
        into a copy of the credentials, so re-using ``args`` cannot lead to the
        key being encrypted twice.
        """
        auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
        if not auth_result:
            return

        credentials = args["credentials"]
        # Encrypt the api key before persisting it.
        encrypted_api_key = encrypter.encrypt_token(tenant_id, credentials["config"]["api_key"])
        stored_credentials = ApiKeyAuthService._replace_api_key(credentials, encrypted_api_key)

        data_source_api_key_binding = DataSourceApiKeyAuthBinding()
        data_source_api_key_binding.tenant_id = tenant_id
        data_source_api_key_binding.category = args["category"]
        data_source_api_key_binding.provider = args["provider"]
        data_source_api_key_binding.credentials = json.dumps(stored_credentials, ensure_ascii=False)
        db.session.add(data_source_api_key_binding)
        db.session.commit()

    @staticmethod
    def get_auth_credentials(tenant_id: str, category: str, provider: str):
        """Return the stored credentials for one tenant/category/provider.

        Only non-disabled bindings are considered, and the first match is used.
        The JSON payload is decoded with ``json.loads`` and returned as stored;
        no decryption happens here. Returns ``None`` when no binding matches.
        """
        binding = (
            db.session.query(DataSourceApiKeyAuthBinding)
            .filter(
                DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
                DataSourceApiKeyAuthBinding.category == category,
                DataSourceApiKeyAuthBinding.provider == provider,
                DataSourceApiKeyAuthBinding.disabled.is_(False),
            )
            .first()
        )
        if not binding:
            return None
        return json.loads(binding.credentials)

    @staticmethod
    def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
        """Delete binding *binding_id* if it belongs to *tenant_id*.

        This is a silent no-op when no matching row exists. The ``disabled``
        flag is not checked, so disabled bindings can also be deleted.
        """
        data_source_api_key_binding = (
            db.session.query(DataSourceApiKeyAuthBinding)
            .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
            .first()
        )
        if data_source_api_key_binding:
            db.session.delete(data_source_api_key_binding)
            db.session.commit()

    @classmethod
    def validate_api_key_auth_args(cls, args) -> None:
        """Check that *args* has the fields ``create_provider_auth`` relies on.

        :param args: mapping with ``category``, ``provider`` and a
            ``credentials`` dict containing ``auth_type``.
        :raises ValueError: on the first missing or empty field (checked in the
            order category, provider, credentials), or when ``credentials`` is
            not a dict.
        """
        for field in ("category", "provider", "credentials"):
            if not args.get(field):
                raise ValueError(f"{field} is required")

        credentials = args["credentials"]
        if not isinstance(credentials, dict):
            raise ValueError("credentials must be a dictionary")
        if not credentials.get("auth_type"):
            raise ValueError("auth_type is required")
|