From bc0724b4997e52ea9a29e41d4ef330e5f7a64387 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 11 Nov 2024 19:50:39 +0800 Subject: [PATCH 1/2] chore: fix typo --- api/events/event_handlers/deduct_quota_when_message_created.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index 843a232096..f1479c58a4 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -22,6 +22,9 @@ def handle(sender, **kwargs): system_configuration = provider_configuration.system_configuration + if not system_configuration.current_quota_type: + return + quota_unit = None for quota_configuration in system_configuration.quota_configurations: if quota_configuration.quota_type == system_configuration.current_quota_type: From 1d2118fc5dfa6b93bf642ea62c49962aab8a42c3 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 11 Nov 2024 20:31:11 +0800 Subject: [PATCH 2/2] fix: hosted moderation --- .../hosting_moderation/hosting_moderation.py | 4 ++- api/core/helper/moderation.py | 25 +++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index ba14b61201..a5a5486581 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -24,6 +24,8 @@ class HostingModerationFeature: if isinstance(prompt_message.content, str): text += prompt_message.content + "\n" - moderation_result = moderation.check_moderation(model_config, text) + moderation_result = moderation.check_moderation( + tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text + ) return moderation_result diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index f3144039e3..434f4205e8 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -1,27 +1,32 @@ import logging import random +from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities import DEFAULT_PLUGIN_ID +from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeBadRequestError -from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel +from core.model_runtime.model_providers.__base.moderation_model import ModerationModel +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_hosting_provider import hosting_configuration from models.provider import ProviderType logger = logging.getLogger(__name__) -def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: +def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config + openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai" if ( moderation_config and moderation_config.enabled is True - and "openai" in hosting_configuration.provider_map - and hosting_configuration.provider_map["openai"].enabled is True + and openai_provider_name in hosting_configuration.provider_map + and hosting_configuration.provider_map[openai_provider_name].enabled is True ): using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type provider_name = model_config.provider if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: - hosting_openai_config = hosting_configuration.provider_map["openai"] + hosting_openai_config = hosting_configuration.provider_map[openai_provider_name] if hosting_openai_config.credentials is None: return False @@ -36,9 +41,15 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) text_chunk = random.choice(text_chunks) try: - model_type_instance = OpenAIModerationModel() + model_provider_factory = ModelProviderFactory(tenant_id) + + # Get model instance of LLM + model_type_instance = model_provider_factory.get_model_type_instance( + provider=openai_provider_name, model_type=ModelType.MODERATION + ) + model_type_instance = cast(ModerationModel, model_type_instance) moderation_result = model_type_instance.invoke( - model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk + model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk ) if moderation_result is True: