2024-01-02 23:42:00 +08:00
|
|
|
import logging
|
2024-06-26 17:33:29 +08:00
|
|
|
from typing import Optional
|
2024-01-02 23:42:00 +08:00
|
|
|
|
2024-04-08 18:51:46 +08:00
|
|
|
from core.app.app_config.entities import AppConfig
|
2024-01-02 23:42:00 +08:00
|
|
|
from core.moderation.base import ModerationAction, ModerationException
|
|
|
|
from core.moderation.factory import ModerationFactory
|
2024-08-09 15:22:16 +08:00
|
|
|
from core.ops.entities.trace_entity import TraceTaskName
|
|
|
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
2024-06-26 17:33:29 +08:00
|
|
|
from core.ops.utils import measure_time
|
2024-01-02 23:42:00 +08:00
|
|
|
|
|
|
|
# Module-level logger named after this module; not referenced by the code visible here.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
2024-04-08 18:51:46 +08:00
|
|
|
class InputModeration:
    def check(
        self,
        app_id: str,
        tenant_id: str,
        app_config: AppConfig,
        inputs: dict,
        query: str,
        message_id: str,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[bool, dict, str]:
        """
        Run the app's sensitive_word_avoidance moderation over the user inputs and query.

        When the app has no sensitive_word_avoidance configured, nothing is checked and
        the inputs/query are passed through unchanged.

        :param app_id: app id
        :param tenant_id: tenant id
        :param app_config: app config; its sensitive_word_avoidance entry selects the moderation type and config
        :param inputs: inputs
        :param query: query
        :param message_id: message id, attached to the moderation trace
        :param trace_manager: optional trace manager; when given, a moderation trace task is queued
        :return: tuple ``(flagged, inputs, query)``. ``flagged`` is False when moderation is not
            configured or did not flag the input. When flagged with the OVERRIDDEN action, the
            returned inputs/query are the ones produced by the moderation result; for any other
            non-raising action the original inputs/query are returned.
        :raises ModerationException: when the input is flagged and the action is DIRECT_OUTPUT;
            the exception carries the moderation result's preset response.
        """
        avoidance = app_config.sensitive_word_avoidance
        if not avoidance:
            # Moderation is not configured for this app: pass everything through.
            return False, inputs, query

        factory = ModerationFactory(
            name=avoidance.type,
            app_id=app_id,
            tenant_id=tenant_id,
            config=avoidance.config,
        )

        # Only the moderation call itself is timed; the timer is forwarded to the trace.
        with measure_time() as timer:
            result = factory.moderation_for_inputs(inputs, query)

        if trace_manager:
            moderation_trace = TraceTask(
                TraceTaskName.MODERATION_TRACE,
                message_id=message_id,
                moderation_result=result,
                inputs=inputs,
                timer=timer,
            )
            trace_manager.add_trace_task(moderation_trace)

        if not result.flagged:
            return False, inputs, query

        if result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationException(result.preset_response)

        if result.action == ModerationAction.OVERRIDDEN:
            # Moderation rewrote the content; hand back its replacements.
            return True, result.inputs, result.query

        return True, inputs, query
|