import datetime
import json
import logging
import os
from collections import defaultdict
from typing import Optional

import requests

from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
    TenantDefaultModel


class ProviderService:
    """
    Service layer for managing a tenant's model providers: listing them, validating / saving /
    deleting custom credentials, switching the preferred provider type, default models,
    valid model lists and free quota applications.
    """

    # (connect, read) timeout in seconds for the free quota apply request, so an
    # unresponsive server cannot block the calling worker forever.
    FREE_QUOTA_REQUEST_TIMEOUT = (10, 60)

    @staticmethod
    def _supported_quota_types(model_provider_rule: dict) -> list:
        """
        Return the quota types listed under the rule's system_config.

        :param model_provider_rule: rule dict of a single model provider
        :return: value of system_config.supported_quota_types, or an empty list when
            system_config or supported_quota_types is missing or empty
        """
        system_config = model_provider_rule.get('system_config') or {}
        return system_config.get('supported_quota_types') or []

    def get_provider_list(self, tenant_id: str):
        """
        get provider list of tenant.

        :param tenant_id: id of the tenant whose provider records are queried
        :return: dict keyed by provider name; each value holds "preferred_provider_type",
            "model_flexibility" and "providers" (a list with one entry per supported system
            quota type plus, when supported, one entry for the custom provider)
        """
        # get rules for all providers
        model_provider_rules = ModelProviderFactory.get_provider_rules()
        model_provider_names = list(model_provider_rules)

        for model_provider_name, model_provider_rule in model_provider_rules.items():
            # The return value is ignored on purpose.
            # NOTE(review): presumably called for its side effect of initializing the
            # trial provider of the tenant - confirm against ModelProviderFactory.
            if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
                    and 'trial' in self._supported_quota_types(model_provider_rule):
                ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        configurable_model_provider_names = [
            model_provider_name
            for model_provider_name, model_provider_rule in model_provider_rules.items()
            if 'custom' in model_provider_rule['support_provider_types']
            and model_provider_rule['model_flexibility'] == 'configurable'
        ]

        # get all providers for the tenant
        providers = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name.in_(model_provider_names),
                Provider.is_valid == True
            ).order_by(Provider.created_at.desc()).all()

        provider_name_to_provider_dict = defaultdict(list)
        for provider in providers:
            provider_name_to_provider_dict[provider.provider_name].append(provider)

        # get all configurable provider models for the tenant
        provider_models = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name.in_(configurable_model_provider_names),
                ProviderModel.is_valid == True
            ).order_by(ProviderModel.created_at.desc()).all()

        provider_name_to_provider_model_dict = defaultdict(list)
        for provider_model in provider_models:
            provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)

        # get all preferred provider type for the tenant
        preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name.in_(model_provider_names)
            ).all()

        provider_name_to_preferred_provider_type_dict = {
            preferred_provider_type.provider_name: preferred_provider_type
            for preferred_provider_type in preferred_provider_types
        }

        providers_list = {}

        for model_provider_name, model_provider_rule in model_provider_rules.items():
            # get preferred provider type
            preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
            preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
                tenant_id,
                model_provider_name,
                preferred_model_provider
            )

            provider_config_dict = {
                "preferred_provider_type": preferred_provider_type,
                "model_flexibility": model_provider_rule['model_flexibility'],
            }

            # Build placeholder entries first; fields marked "need update" are overwritten
            # below from the stored Provider rows, when such rows exist.
            provider_parameter_dict = {}
            if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']:
                # Safe lookup: a rule without system_config simply yields no system entries.
                supported_quota_types = self._supported_quota_types(model_provider_rule)
                for quota_type_enum in ProviderQuotaType:
                    quota_type = quota_type_enum.value
                    if quota_type not in supported_quota_types:
                        continue

                    system_config = model_provider_rule['system_config']
                    key = ProviderType.SYSTEM.value + ':' + quota_type
                    provider_parameter_dict[key] = {
                        "provider_name": model_provider_name,
                        "provider_type": ProviderType.SYSTEM.value,
                        "config": None,
                        "is_valid": False,  # need update
                        "quota_type": quota_type,
                        "quota_unit": system_config['quota_unit'],  # need update
                        "quota_limit": 0 if quota_type != ProviderQuotaType.TRIAL.value else
                        system_config['quota_limit'],  # need update
                        "quota_used": 0,  # need update
                        "last_used": None  # need update
                    }

            if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']:
                provider_parameter_dict[ProviderType.CUSTOM.value] = {
                    "provider_name": model_provider_name,
                    "provider_type": ProviderType.CUSTOM.value,
                    "config": None,  # need update
                    "models": [],  # need update
                    "is_valid": False,
                    "last_used": None  # need update
                }

            model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)

            current_providers = provider_name_to_provider_dict[model_provider_name]
            for provider in current_providers:
                if provider.provider_type == ProviderType.SYSTEM.value:
                    quota_type = provider.quota_type
                    key = f'{ProviderType.SYSTEM.value}:{quota_type}'

                    if key in provider_parameter_dict:
                        provider_parameter_dict[key]['is_valid'] = provider.is_valid
                        provider_parameter_dict[key]['quota_used'] = provider.quota_used
                        provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
                        provider_parameter_dict[key]['last_used'] = provider.last_used
                elif provider.provider_type == ProviderType.CUSTOM.value \
                        and ProviderType.CUSTOM.value in provider_parameter_dict:
                    # if custom
                    key = ProviderType.CUSTOM.value
                    provider_parameter_dict[key]['last_used'] = provider.last_used
                    provider_parameter_dict[key]['is_valid'] = provider.is_valid

                    # One provider instance serves every credential lookup for this row.
                    model_provider = model_provider_class(provider=provider)

                    if model_provider_rule['model_flexibility'] == 'fixed':
                        provider_parameter_dict[key]['config'] = model_provider \
                            .get_provider_credentials(obfuscated=True)
                    else:
                        current_provider_models = provider_name_to_provider_model_dict[model_provider_name]
                        provider_parameter_dict[key]['models'] = [
                            {
                                "model_name": provider_model.model_name,
                                "model_type": provider_model.model_type,
                                "config": model_provider.get_model_credentials(
                                    provider_model.model_name,
                                    ModelType.value_of(provider_model.model_type),
                                    obfuscated=True
                                ),
                                "is_valid": provider_model.is_valid
                            }
                            for provider_model in current_provider_models
                        ]

            provider_config_dict['providers'] = list(provider_parameter_dict.values())
            providers_list[model_provider_name] = provider_config_dict

        return providers_list

    def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
        """
        validate custom provider config.

        :param provider_name: name of the model provider to validate against
        :param config: provider credentials to validate
        :return:
        :raises ValueError: When the provider is not a 'fixed' model provider or does not
            support the CUSTOM provider type.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        # get model provider rules
        model_provider_rule = ModelProviderFactory.get_provider_rule(provider_name)

        if model_provider_rule['model_flexibility'] != 'fixed':
            raise ValueError('Only support fixed model provider')

        # only support provider type CUSTOM
        if ProviderType.CUSTOM.value not in model_provider_rule['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        # validate provider config
        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_provider_credentials_valid_or_raise(config)

    def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
        """
        save custom provider config.

        The config is validated first, then encrypted by the provider class and stored as JSON
        on the tenant's CUSTOM provider record (updated when it exists, created otherwise).

        :param tenant_id: id of the tenant owning the provider record
        :param provider_name: name of the model provider
        :param config: provider credentials to validate, encrypt and store
        :return:
        """
        # validate custom provider config
        self.custom_provider_config_validate(provider_name, config)

        # get provider
        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config = model_provider_class.encrypt_provider_credentials(tenant_id, config)

        # save provider
        if provider:
            provider.encrypted_config = json.dumps(encrypted_config)
            provider.is_valid = True
            provider.updated_at = datetime.datetime.utcnow()
        else:
            provider = Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(encrypted_config),
                is_valid=True
            )
            db.session.add(provider)

        db.session.commit()

    def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
        """
        delete custom provider.

        Does nothing when the tenant has no CUSTOM record for the provider.

        :param tenant_id: id of the tenant owning the provider record
        :param provider_name: name of the model provider
        :return:
        """
        # get provider
        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        if provider:
            try:
                self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
            except ValueError:
                # Best effort: switch_preferred_provider raises ValueError when the provider
                # does not support the SYSTEM type; the deletion must proceed regardless.
                pass

            db.session.delete(provider)
            db.session.commit()

    def custom_provider_model_config_validate(self,
                                              provider_name: str,
                                              model_name: str,
                                              model_type: str,
                                              config: dict) -> None:
        """
        validate custom provider model config.

        :param provider_name: name of the model provider to validate against
        :param model_name: name of the model the credentials belong to
        :param model_type: model type string, converted with ModelType.value_of
        :param config: model credentials to validate
        :return:
        :raises ValueError: When the provider is not a 'configurable' model provider or does
            not support the CUSTOM provider type.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        # get model provider rules
        model_provider_rule = ModelProviderFactory.get_provider_rule(provider_name)

        if model_provider_rule['model_flexibility'] != 'configurable':
            raise ValueError('Only support configurable model provider')

        # only support provider type CUSTOM
        if ProviderType.CUSTOM.value not in model_provider_rule['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        # validate provider model config
        model_type_enum = ModelType.value_of(model_type)
        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_model_credentials_valid_or_raise(model_name, model_type_enum, config)

    def add_or_save_custom_provider_model_config(self,
                                                 tenant_id: str,
                                                 provider_name: str,
                                                 model_name: str,
                                                 model_type: str,
                                                 config: dict) -> None:
        """
        Add or save custom provider model config.

        Ensures a valid CUSTOM provider record exists for the tenant, then stores the
        encrypted model credentials on the matching provider model record (updated when it
        exists, created otherwise).

        :param tenant_id: id of the tenant owning the records
        :param provider_name: name of the model provider
        :param model_name: name of the model the credentials belong to
        :param model_type: model type string, stored as given on the provider model record
        :param config: model credentials to validate, encrypt and store
        :return:
        """
        # validate custom provider model config
        self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

        # get provider
        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        if not provider:
            provider = Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=True
            )
            db.session.add(provider)
            db.session.commit()
        elif not provider.is_valid:
            # Re-activate an invalidated record and clear its stored provider-level config.
            provider.is_valid = True
            provider.encrypted_config = None
            db.session.commit()

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config = model_provider_class.encrypt_model_credentials(
            tenant_id,
            model_name,
            ModelType.value_of(model_type),
            config
        )

        # get provider model
        provider_model = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name == provider_name,
                ProviderModel.model_name == model_name,
                ProviderModel.model_type == model_type
            ).first()

        if provider_model:
            provider_model.encrypted_config = json.dumps(encrypted_config)
            provider_model.is_valid = True
        else:
            provider_model = ProviderModel(
                tenant_id=tenant_id,
                provider_name=provider_name,
                model_name=model_name,
                model_type=model_type,
                encrypted_config=json.dumps(encrypted_config),
                is_valid=True
            )
            db.session.add(provider_model)

        db.session.commit()

    def delete_custom_provider_model(self,
                                     tenant_id: str,
                                     provider_name: str,
                                     model_name: str,
                                     model_type: str) -> None:
        """
        delete custom provider model.

        Does nothing when no matching provider model record exists.

        :param tenant_id: id of the tenant owning the record
        :param provider_name: name of the model provider
        :param model_name: name of the model
        :param model_type: model type string as stored on the record
        :return:
        """
        # get provider model
        provider_model = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name == provider_name,
                ProviderModel.model_name == model_name,
                ProviderModel.model_type == model_type
            ).first()

        if provider_model:
            db.session.delete(provider_model)
            db.session.commit()

    def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
        """
        switch preferred provider.

        Returns without changing anything when the provider class reports that the system
        provider type is not supported.

        :param tenant_id: id of the tenant whose preference is changed
        :param provider_name: name of the model provider
        :param preferred_provider_type: provider type string to prefer
        :return:
        :raises ValueError: When the provider type is invalid or not supported by the provider.
        """
        provider_type = ProviderType.value_of(preferred_provider_type)
        if not provider_type:
            raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

        model_provider_rule = ModelProviderFactory.get_provider_rule(provider_name)
        if preferred_provider_type not in model_provider_rule['support_provider_types']:
            raise ValueError(f'Not support provider type: {preferred_provider_type}')

        model_provider = ModelProviderFactory.get_model_provider_class(provider_name)
        if not model_provider.is_provider_type_system_supported():
            return

        # get preferred provider
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name == provider_name
            ).first()

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = preferred_provider_type
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                preferred_provider_type=preferred_provider_type
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
        """
        get default model of model type.

        :param tenant_id: id of the tenant
        :param model_type: model type string, converted with ModelType.value_of
        :return: result of ModelFactory.get_default_model for the tenant and model type
        """
        return ModelFactory.get_default_model(tenant_id, ModelType.value_of(model_type))

    def update_default_model_of_model_type(self,
                                           tenant_id: str,
                                           model_type: str,
                                           provider_name: str,
                                           model_name: str) -> TenantDefaultModel:
        """
        update default model of model type.

        :param tenant_id: id of the tenant
        :param model_type: model type string, converted with ModelType.value_of
        :param provider_name: name of the model provider of the new default model
        :param model_name: name of the new default model
        :return: result of ModelFactory.update_default_model
        """
        return ModelFactory.update_default_model(tenant_id, ModelType.value_of(model_type), provider_name, model_name)

    def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
        """
        get valid model list.

        Providers for which no preferred model provider is returned are skipped.

        :param tenant_id: id of the tenant
        :param model_type: model type string, converted with ModelType.value_of
        :return: list of dicts with "model_name", "model_type", "model_provider" and
            "features"; quota fields are added to "model_provider" for SYSTEM providers
        """
        valid_model_list = []

        # get model provider rules
        model_provider_rules = ModelProviderFactory.get_provider_rules()
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
            if not model_provider:
                continue

            model_list = model_provider.get_supported_model_list(ModelType.value_of(model_type))
            provider = model_provider.provider
            for model in model_list:
                model_provider_info = {
                    "provider_name": provider.provider_name,
                    "provider_type": provider.provider_type
                }

                if provider.provider_type == ProviderType.SYSTEM.value:
                    model_provider_info['quota_type'] = provider.quota_type
                    model_provider_info['quota_unit'] = model_provider_rule['system_config']['quota_unit']
                    model_provider_info['quota_limit'] = provider.quota_limit
                    model_provider_info['quota_used'] = provider.quota_used

                valid_model_list.append({
                    "model_name": model['id'],
                    "model_type": model_type,
                    "model_provider": model_provider_info,
                    'features': model['features'] if 'features' in model else []
                })

        return valid_model_list

    def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
            -> ModelKwargsRules:
        """
        get model parameter rules.
        It depends on preferred provider in use.

        :param tenant_id: id of the tenant
        :param model_provider_name: name of the model provider
        :param model_name: name of the model
        :param model_type: model type string, converted with ModelType.value_of
        :return: the preferred provider's parameter rules for the model, or an empty
            ModelKwargsRules when no preferred model provider is returned
        """
        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
        if not model_provider:
            # get empty model provider
            return ModelKwargsRules()

        # get model parameter rules
        return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))

    def free_quota_submit(self, tenant_id: str, provider_name: str):
        """
        submit a free quota application to the free quota apply server.

        The server URL and API key are read from the FREE_QUOTA_APPLY_URL and
        FREE_QUOTA_APPLY_API_KEY environment variables.

        :param tenant_id: id of the tenant, sent as 'workspace_id'
        :param provider_name: name of the model provider, sent as 'provider_name'
        :return: {'type': ..., 'redirect_url': ...} when the server answers with type
            'redirect', otherwise {'type': ..., 'result': 'success'}
        :raises ValueError: When the HTTP status is not ok or the response code is not 'success'.
        """
        api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
        api_url = os.environ.get("FREE_QUOTA_APPLY_URL")

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f"Bearer {api_key}"
        }
        # requests never times out by default; bound the call explicitly.
        response = requests.post(
            api_url,
            headers=headers,
            json={'workspace_id': tenant_id, 'provider_name': provider_name},
            timeout=self.FREE_QUOTA_REQUEST_TIMEOUT
        )
        if not response.ok:
            logging.error("Request FREE QUOTA APPLY SERVER Error: %s ", response.status_code)
            raise ValueError(f"Error: {response.status_code} ")

        # Parse the body once and reuse it for every check below.
        rst = response.json()
        if rst["code"] != 'success':
            raise ValueError(
                f"error: {rst['message']}"
            )

        if rst['type'] == 'redirect':
            return {
                'type': rst['type'],
                'redirect_url': rst['redirect_url']
            }

        return {
            'type': rst['type'],
            'result': 'success'
        }