"""Provider configuration entities.

Defines the per-tenant configuration of a single model provider
(``ProviderConfiguration``), a dict-like collection of those configurations
(``ProviderConfigurations``) and a small bundle pairing a configuration with
its runtime instances (``ProviderModelBundle``).
"""

import datetime
import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional

from pydantic import BaseModel, ConfigDict

from constants import HIDDEN_VALUE
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
    CustomConfiguration,
    ModelSettings,
    SystemConfiguration,
    SystemConfigurationStatus,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
    CredentialFormSchema,
    FormType,
    ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from extensions.ext_database import db
from models.provider import (
    LoadBalancingModelConfig,
    Provider,
    ProviderModel,
    ProviderModelSetting,
    ProviderType,
    TenantPreferredModelProvider,
)

logger = logging.getLogger(__name__)

# Provider name -> configurate methods as first seen for that provider.
# ``ProviderConfiguration.__init__`` may append PREDEFINED_MODEL to
# ``provider.configurate_methods``; this module-level snapshot is taken before
# that happens so later checks can compare against the unmodified list.
original_provider_configurate_methods = {}


class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.
    """

    tenant_id: str
    provider: ProviderEntity
    preferred_provider_type: ProviderType
    using_provider_type: ProviderType
    system_configuration: SystemConfiguration
    custom_configuration: CustomConfiguration
    model_settings: list[ModelSettings]

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def __init__(self, **data):
        """
        Build the configuration and record the provider's configurate methods.

        When the provider originally supports only customizable models but at
        least one system quota configuration restricts models, PREDEFINED_MODEL
        is appended to the provider's configurate methods (once).
        """
        super().__init__(**data)

        provider_name = self.provider.provider
        if provider_name not in original_provider_configurate_methods:
            # Snapshot (copy) the methods before they can be extended below.
            original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

        only_customizable = original_provider_configurate_methods[provider_name] == [
            ConfigurateMethod.CUSTOMIZABLE_MODEL
        ]
        if not only_customizable:
            return

        has_restricted_models = any(
            len(quota_configuration.restrict_models) > 0
            for quota_configuration in self.system_configuration.quota_configurations
        )
        if has_restricted_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
            self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        :param model_type: model type
        :param model: model name
        :return: credentials for the model under the provider type in use, or None
        :raises ValueError: if a model setting marks the model as disabled
        """
        # check if model is disabled by admin
        for model_setting in self.model_settings or []:
            if model_setting.model_type == model_type and model_setting.model == model:
                if not model_setting.enabled:
                    raise ValueError(f"Model {model} is disabled.")

        if self.using_provider_type == ProviderType.SYSTEM:
            restrict_models = []
            for quota_configuration in self.system_configuration.quota_configurations:
                if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                    continue

                restrict_models = quota_configuration.restrict_models

            # Copy so that adding ``base_model_name`` never alters the shared system credentials.
            copy_credentials = self.system_configuration.credentials.copy()
            for restrict_model in restrict_models or []:
                if (
                    restrict_model.model_type == model_type
                    and restrict_model.model == model
                    and restrict_model.base_model_name
                ):
                    copy_credentials["base_model_name"] = restrict_model.base_model_name

            return copy_credentials

        # Custom provider: model-level credentials win over provider-level ones.
        credentials = None
        for model_configuration in self.custom_configuration.models or []:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                break

        if not credentials and self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        return credentials

    def get_system_configuration_status(self) -> SystemConfigurationStatus:
        """
        Get system configuration status.

        :return: UNSUPPORTED when the system configuration is disabled or has no
            quota configuration for the current quota type, ACTIVE when that quota
            configuration is valid, QUOTA_EXCEEDED otherwise
        """
        if self.system_configuration.enabled is False:
            return SystemConfigurationStatus.UNSUPPORTED

        current_quota_type = self.system_configuration.current_quota_type
        current_quota_configuration = next(
            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
        )

        # ``next`` falls back to None when nothing matches; dereferencing it
        # would raise AttributeError, so report the configuration as unusable.
        if current_quota_configuration is None:
            return SystemConfigurationStatus.UNSUPPORTED

        return (
            SystemConfigurationStatus.ACTIVE
            if current_quota_configuration.is_valid
            else SystemConfigurationStatus.QUOTA_EXCEEDED
        )

    def is_custom_configuration_available(self) -> bool:
        """
        Check custom configuration available.

        :return: True if provider-level credentials or at least one custom model exist
        """
        return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0

    def _provider_credential_form_schemas(self) -> list[CredentialFormSchema]:
        """Return the provider credential form schemas, or an empty list if the provider has none."""
        schema = self.provider.provider_credential_schema
        return schema.credential_form_schemas if schema else []

    def _model_credential_form_schemas(self) -> list[CredentialFormSchema]:
        """Return the model credential form schemas, or an empty list if the provider has none."""
        schema = self.provider.model_credential_schema
        return schema.credential_form_schemas if schema else []

    def _restore_hidden_secrets(
        self, credentials: dict, original_credentials: dict, secret_variables: list[str]
    ) -> dict:
        """Return a copy of ``credentials`` where HIDDEN_VALUE secrets are replaced by the decrypted stored value."""
        restored = credentials.copy()
        for key, value in credentials.items():
            # if send [__HIDDEN__] in secret input, it will be same as original value
            if key in secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                restored[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
        return restored

    def _encrypt_secrets(self, credentials: dict, secret_variables: list[str]) -> dict:
        """Return a new dict with every secret variable of ``credentials`` encrypted for this tenant."""
        return {
            key: encrypter.encrypt_token(self.tenant_id, value) if key in secret_variables else value
            for key, value in credentials.items()
        }

    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Get custom credentials.

        :param obfuscated: obfuscated secret data in credentials
        :return: provider-level custom credentials, or None if none are configured
        """
        if self.custom_configuration.provider is None:
            return None

        credentials = self.custom_configuration.provider.credentials
        if not obfuscated:
            return credentials

        # Obfuscate credentials
        return self.obfuscated_credentials(
            credentials=credentials,
            credential_form_schemas=self._provider_credential_form_schemas(),
        )

    def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
        """
        Validate custom credentials.

        :param credentials: provider credentials
        :return: the existing custom provider record (None if there is none) and the
            validated credentials with secret variables encrypted
        """
        # get provider
        provider_record = (
            db.session.query(Provider)
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value,
            )
            .first()
        )

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self._provider_credential_form_schemas()
        )

        if provider_record:
            try:
                # fix origin data
                if provider_record.encrypted_config:
                    if not provider_record.encrypted_config.startswith("{"):
                        # Stored value is not a JSON object: treat it as a bare key.
                        original_credentials = {"openai_api_key": provider_record.encrypted_config}
                    else:
                        original_credentials = json.loads(provider_record.encrypted_config)
                else:
                    original_credentials = {}
            except JSONDecodeError:
                original_credentials = {}

            # Works on a copy so the caller's dict never receives decrypted secrets.
            credentials = self._restore_hidden_secrets(
                credentials, original_credentials, provider_credential_secret_variables
            )

        credentials = model_provider_factory.provider_credentials_validate(
            provider=self.provider.provider, credentials=credentials
        )

        # encrypt credentials
        credentials = self._encrypt_secrets(credentials, provider_credential_secret_variables)

        return provider_record, credentials

    def add_or_update_custom_credentials(self, credentials: dict) -> None:
        """
        Add or update custom provider credentials.

        :param credentials: provider credentials
        :return:
        """
        # validate custom provider config
        provider_record, credentials = self.custom_credentials_validate(credentials)

        # save provider
        if provider_record:
            provider_record.encrypted_config = json.dumps(credentials)
            provider_record.is_valid = True
            provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        else:
            provider_record = Provider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_record)
            db.session.commit()

        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
        )
        provider_model_credentials_cache.delete()

        # NOTE(review): an earlier comment in this method said the preferred
        # provider is not switched (so quotas are used first), yet the code
        # switches to CUSTOM here. Behavior is kept as-is; confirm the intent.
        self.switch_preferred_provider_type(ProviderType.CUSTOM)

    def delete_custom_credentials(self) -> None:
        """
        Delete custom provider credentials.

        :return:
        """
        # get provider
        provider_record = (
            db.session.query(Provider)
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value,
            )
            .first()
        )

        # delete provider
        if not provider_record:
            return

        self.switch_preferred_provider_type(ProviderType.SYSTEM)

        # Read the id before deleting so the cache key does not rely on
        # attribute access on an already-deleted ORM instance.
        provider_record_id = provider_record.id

        db.session.delete(provider_record)
        db.session.commit()

        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_record_id,
            cache_type=ProviderCredentialsCacheType.PROVIDER,
        )
        provider_model_credentials_cache.delete()

    def get_custom_model_credentials(
        self, model_type: ModelType, model: str, obfuscated: bool = False
    ) -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param obfuscated: obfuscated secret data in credentials
        :return: credentials of the matching custom model, or None if there is none
        """
        if not self.custom_configuration.models:
            return None

        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                if not obfuscated:
                    return credentials

                # Obfuscate credentials
                return self.obfuscated_credentials(
                    credentials=credentials,
                    credential_form_schemas=self._model_credential_form_schemas(),
                )

        return None

    def custom_model_credentials_validate(
        self, model_type: ModelType, model: str, credentials: dict
    ) -> tuple[Optional[ProviderModel], dict]:
        """
        Validate custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return: the existing provider model record (None if there is none) and the
            validated credentials with secret variables encrypted
        """
        # get provider model
        provider_model_record = (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(self._model_credential_form_schemas())

        if provider_model_record:
            try:
                original_credentials = (
                    json.loads(provider_model_record.encrypted_config)
                    if provider_model_record.encrypted_config
                    else {}
                )
            except JSONDecodeError:
                original_credentials = {}

            # decrypt credentials (on a copy, so the caller's dict never receives decrypted secrets)
            credentials = self._restore_hidden_secrets(
                credentials, original_credentials, provider_credential_secret_variables
            )

        credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )

        credentials = self._encrypt_secrets(credentials, provider_credential_secret_variables)

        return provider_model_record, credentials

    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
        """
        Add or update custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # validate custom model config
        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

        # save provider model
        # Note: Do not switch the preferred provider, which allows users to use quotas first
        if provider_model_record:
            provider_model_record.encrypted_config = json.dumps(credentials)
            provider_model_record.is_valid = True
            provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        else:
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_model_record)
            db.session.commit()

        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record.id,
            cache_type=ProviderCredentialsCacheType.MODEL,
        )
        provider_model_credentials_cache.delete()

    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
        """
        Delete custom model credentials.

        :param model_type: model type
        :param model: model name
        :return:
        """
        # get provider model
        provider_model_record = (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

        # delete provider model
        if not provider_model_record:
            return

        # Read the id before deleting so the cache key does not rely on
        # attribute access on an already-deleted ORM instance.
        provider_model_record_id = provider_model_record.id

        db.session.delete(provider_model_record)
        db.session.commit()

        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record_id,
            cache_type=ProviderCredentialsCacheType.MODEL,
        )
        provider_model_credentials_cache.delete()

    def _upsert_model_setting(self, model_type: ModelType, model: str, **values) -> ProviderModelSetting:
        """Apply ``values`` to the model's ProviderModelSetting row, creating the row if it is missing, and commit."""
        model_setting = self.get_provider_model_setting(model_type, model)

        if model_setting:
            for attribute, value in values.items():
                setattr(model_setting, attribute, value)
            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                **values,
            )
            db.session.add(model_setting)

        db.session.commit()
        return model_setting

    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        return self._upsert_model_setting(model_type, model, enabled=True)

    def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        return self._upsert_model_setting(model_type, model, enabled=False)

    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
        """
        Get provider model setting.

        :param model_type: model type
        :param model: model name
        :return: the stored setting for this tenant/provider/model, or None
        """
        return (
            db.session.query(ProviderModelSetting)
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model,
            )
            .first()
        )

    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        :raises ValueError: if the model has one or zero load balancing configurations
        """
        load_balancing_config_count = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == self.tenant_id,
                LoadBalancingModelConfig.provider_name == self.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .count()
        )

        if load_balancing_config_count <= 1:
            raise ValueError("Model load balancing configuration must be more than 1.")

        return self._upsert_model_setting(model_type, model, load_balancing_enabled=True)

    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        return self._upsert_model_setting(model_type, model, load_balancing_enabled=False)

    def get_provider_instance(self) -> ModelProvider:
        """
        Get provider instance.

        :return: provider instance resolved from the model provider factory
        """
        return model_provider_factory.get_provider_instance(self.provider.provider)

    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
        """
        Get current model type instance.

        :param model_type: model type
        :return: model instance of the provider for the given model type
        """
        # Get provider instance
        provider_instance = self.get_provider_instance()

        # Get model instance for the requested model type
        return provider_instance.get_model_instance(model_type)

    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
        """
        Switch preferred provider type.

        Does nothing when the type is already preferred, or when switching to
        SYSTEM while the system configuration is not enabled.

        :param provider_type: provider type to prefer
        :return:
        """
        if provider_type == self.preferred_provider_type:
            return

        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
            return

        # get preferred provider
        preferred_model_provider = (
            db.session.query(TenantPreferredModelProvider)
            .filter(
                TenantPreferredModelProvider.tenant_id == self.tenant_id,
                TenantPreferredModelProvider.provider_name == self.provider.provider,
            )
            .first()
        )

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = provider_type.value
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value,
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
        """
        Extract secret input form variables.

        :param credential_form_schemas: credential form schemas
        :return: variable names of all SECRET_INPUT form fields, in schema order
        """
        return [
            credential_form_schema.variable
            for credential_form_schema in credential_form_schemas
            if credential_form_schema.type == FormType.SECRET_INPUT
        ]

    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
        """
        Obfuscated credentials.

        :param credentials: credentials
        :param credential_form_schemas: credential form schemas
        :return: a new dict with secret variables obfuscated; the input is not modified
        """
        # Get provider credential secret variables
        credential_secret_variables = self.extract_secret_variables(credential_form_schemas)

        # Obfuscate provider credentials
        return {
            key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
            for key, value in credentials.items()
        }

    def get_provider_model(
        self, model_type: ModelType, model: str, only_active: bool = False
    ) -> Optional[ModelWithProviderEntity]:
        """
        Get provider model.

        :param model_type: model type
        :param model: model name
        :param only_active: return active model only
        :return: the first provider model with the given name, or None
        """
        provider_models = self.get_provider_models(model_type, only_active)

        return next((provider_model for provider_model in provider_models if provider_model.model == model), None)

    def get_provider_models(
        self, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get provider models.

        :param model_type: model type; when omitted, all model types supported by the provider schema are used
        :param only_active: only active models
        :return: provider models sorted by model type value
        """
        provider_instance = self.get_provider_instance()

        if model_type:
            model_types = [model_type]
        else:
            model_types = provider_instance.get_provider_schema().supported_model_types

        # Group model settings by model type and model
        model_setting_map = defaultdict(dict)
        for model_setting in self.model_settings:
            model_setting_map[model_setting.model_type][model_setting.model] = model_setting

        if self.using_provider_type == ProviderType.SYSTEM:
            provider_models = self._get_system_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )
        else:
            provider_models = self._get_custom_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )

        if only_active:
            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

        # resort provider_models
        return sorted(provider_models, key=lambda x: x.model_type.value)

    @staticmethod
    def _find_model_setting(
        model_setting_map: dict[ModelType, dict[str, ModelSettings]], model_type: ModelType, model: str
    ) -> Optional[ModelSettings]:
        """Return the setting stored for (model_type, model), or None, without adding entries to the map."""
        return model_setting_map.get(model_type, {}).get(model)

    def _get_system_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get system provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: models available through the system configuration, with their status
        """
        provider_models = []
        for model_type in model_types:
            for m in provider_instance.models(model_type):
                model_setting = self._find_model_setting(model_setting_map, m.model_type, m.model)
                status = ModelStatus.ACTIVE
                if model_setting is not None and model_setting.enabled is False:
                    status = ModelStatus.DISABLED

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                    )
                )

        provider_name = self.provider.provider
        if provider_name not in original_provider_configurate_methods:
            original_provider_configurate_methods[provider_name] = list(
                provider_instance.get_provider_schema().configurate_methods
            )

        # True when the provider originally offered customizable models only.
        should_use_custom_model = original_provider_configurate_methods[provider_name] == [
            ConfigurateMethod.CUSTOMIZABLE_MODEL
        ]

        for quota_configuration in self.system_configuration.quota_configurations:
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_models = quota_configuration.restrict_models
            if len(restrict_models) == 0:
                break

            if should_use_custom_model:
                # only customizable model
                for restrict_model in restrict_models:
                    copy_credentials = self.system_configuration.credentials.copy()
                    if restrict_model.base_model_name:
                        copy_credentials["base_model_name"] = restrict_model.base_model_name

                    try:
                        custom_model_schema = provider_instance.get_model_instance(
                            restrict_model.model_type
                        ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
                    except Exception as ex:
                        # Best effort: a model whose schema cannot be built is skipped.
                        logger.warning(f"get custom model schema failed, {ex}")
                        continue

                    if not custom_model_schema:
                        continue

                    if custom_model_schema.model_type not in model_types:
                        continue

                    model_setting = self._find_model_setting(
                        model_setting_map, custom_model_schema.model_type, custom_model_schema.model
                    )
                    status = ModelStatus.ACTIVE
                    if model_setting is not None and model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                    provider_models.append(
                        ModelWithProviderEntity(
                            model=custom_model_schema.model,
                            label=custom_model_schema.label,
                            model_type=custom_model_schema.model_type,
                            features=custom_model_schema.features,
                            fetch_from=FetchFrom.PREDEFINED_MODEL,
                            model_properties=custom_model_schema.model_properties,
                            deprecated=custom_model_schema.deprecated,
                            provider=SimpleModelProviderEntity(self.provider),
                            status=status,
                        )
                    )

            # LLMs outside the restricted list get NO_PERMISSION; every other
            # model gets QUOTA_EXCEEDED when the quota is not valid.
            restrict_model_names = {rm.model for rm in restrict_models}
            for m in provider_models:
                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                    m.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    m.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models

    def _get_custom_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get custom provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: the provider's models followed by the tenant's custom models, with their status
        """
        provider_models = []

        credentials = None
        if self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        for model_type in model_types:
            if model_type not in self.provider.supported_model_types:
                continue

            for m in provider_instance.models(model_type):
                # Without provider-level credentials these models are not usable yet.
                status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
                load_balancing_enabled = False
                model_setting = self._find_model_setting(model_setting_map, m.model_type, m.model)
                if model_setting is not None:
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                    if len(model_setting.load_balancing_configs) > 1:
                        load_balancing_enabled = True

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                        load_balancing_enabled=load_balancing_enabled,
                    )
                )

        # custom models
        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type not in model_types:
                continue

            try:
                custom_model_schema = provider_instance.get_model_instance(
                    model_configuration.model_type
                ).get_customizable_model_schema_from_credentials(
                    model_configuration.model, model_configuration.credentials
                )
            except Exception as ex:
                # Best effort: a model whose schema cannot be built is skipped.
                logger.warning(f"get custom model schema failed, {ex}")
                continue

            if not custom_model_schema:
                continue

            status = ModelStatus.ACTIVE
            load_balancing_enabled = False
            model_setting = self._find_model_setting(
                model_setting_map, custom_model_schema.model_type, custom_model_schema.model
            )
            if model_setting is not None:
                if model_setting.enabled is False:
                    status = ModelStatus.DISABLED

                if len(model_setting.load_balancing_configs) > 1:
                    load_balancing_enabled = True

            provider_models.append(
                ModelWithProviderEntity(
                    model=custom_model_schema.model,
                    label=custom_model_schema.label,
                    model_type=custom_model_schema.model_type,
                    features=custom_model_schema.features,
                    fetch_from=custom_model_schema.fetch_from,
                    model_properties=custom_model_schema.model_properties,
                    deprecated=custom_model_schema.deprecated,
                    provider=SimpleModelProviderEntity(self.provider),
                    status=status,
                    load_balancing_enabled=load_balancing_enabled,
                )
            )

        return provider_models


class ProviderConfigurations(BaseModel):
    """
    Model class for provider configuration dict.
    """

    tenant_id: str
    configurations: dict[str, ProviderConfiguration] = {}

    def __init__(self, tenant_id: str):
        super().__init__(tenant_id=tenant_id)

    def get_models(
        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get available models.

        If preferred provider type is `system`:
          Get the current **system mode** if provider supported,
          if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
          If there is no model configured in custom mode, it is treated as no_configure.
        system > custom > no_configure

        If preferred provider type is `custom`:
          If custom credentials are configured, it is treated as custom mode.
          Otherwise, get the current **system mode** if supported,
          If all system modes are not available (no quota), it is treated as no_configure.
        custom > system > no_configure

        If real mode is `system`, use system credentials to get models,
          paid quotas > provider free quotas > system free quotas
          include pre-defined models (exclude GPT-4, status marked as `no_permission`).
        If real mode is `custom`, use workspace custom credentials to get models,
          include pre-defined models, custom models(manual append).
        If real mode is `no_configure`, only return pre-defined models from `model runtime`.
          (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
        model status marked as `active` is available.

        :param provider: provider name
        :param model_type: model type
        :param only_active: only active models
        :return: models of every configuration, or only of the named provider when ``provider`` is given
        """
        all_models = []
        for provider_configuration in self.values():
            if provider and provider_configuration.provider.provider != provider:
                continue

            all_models.extend(provider_configuration.get_provider_models(model_type, only_active))

        return all_models

    def to_list(self) -> list[ProviderConfiguration]:
        """
        Convert to list.

        :return: all provider configurations as a list
        """
        return list(self.values())

    def __getitem__(self, key):
        return self.configurations[key]

    def __setitem__(self, key, value):
        self.configurations[key] = value

    def __iter__(self):
        # Iterates over provider names (the dict keys), not over model fields.
        return iter(self.configurations)

    def values(self) -> Iterator[ProviderConfiguration]:
        return self.configurations.values()

    def get(self, key, default=None):
        return self.configurations.get(key, default)


class ProviderModelBundle(BaseModel):
    """
    Provider model bundle.
    """

    configuration: ProviderConfiguration
    provider_instance: ModelProvider
    model_type_instance: AIModel

    # pydantic configs
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())