Merge branch 'fix/chore-fix' of github.com:langgenius/dify into fix/chore-fix

This commit is contained in:
Yeuoly 2024-11-12 18:53:45 +08:00
commit 21fd58caf9
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
3 changed files with 24 additions and 8 deletions

View File

@ -24,6 +24,8 @@ class HostingModerationFeature:
if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation(model_config, text)
moderation_result = moderation.check_moderation(
tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
)
return moderation_result

View File

@ -1,27 +1,32 @@
import logging
import random
from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType
logger = logging.getLogger(__name__)
def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
moderation_config = hosting_configuration.moderation_config
openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
if (
moderation_config
and moderation_config.enabled is True
and "openai" in hosting_configuration.provider_map
and hosting_configuration.provider_map["openai"].enabled is True
and openai_provider_name in hosting_configuration.provider_map
and hosting_configuration.provider_map[openai_provider_name].enabled is True
):
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
provider_name = model_config.provider
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map["openai"]
hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]
if hosting_openai_config.credentials is None:
return False
@ -36,9 +41,15 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
text_chunk = random.choice(text_chunks)
try:
model_type_instance = OpenAIModerationModel()
model_provider_factory = ModelProviderFactory(tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)
moderation_result = model_type_instance.invoke(
model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
)
if moderation_result is True:

View File

@ -22,6 +22,9 @@ def handle(sender, **kwargs):
system_configuration = provider_configuration.system_configuration
if not system_configuration.current_quota_type:
return
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type: