From f9082104eda4b7cd2bd35a2822ea53344f72cc60 Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 12 Sep 2023 10:26:12 +0800
Subject: [PATCH] feat: add hosted moderation (#1158)

---
 api/config.py                                 |  5 ++
 api/core/agent/agent_executor.py              | 18 +++++-
 .../agent_loop_gather_callback_handler.py     | 29 ++++++---
 .../callback_handler/entity/llm_message.py    |  1 -
 .../callback_handler/llm_callback_handler.py  |  9 ---
 .../chain/sensitive_word_avoidance_chain.py   | 60 ++++++++++++++---
 api/core/completion.py                        |  9 +--
 api/core/conversation_message_task.py         | 23 +++----
 api/core/helper/moderation.py                 | 32 ++++++++++
 api/core/model_providers/model_factory.py     |  3 +-
 api/core/model_providers/models/llm/base.py   | 10 ++++
 .../model_providers/models/moderation/base.py | 29 +++++++++
 .../models/moderation/openai_moderation.py    | 30 ++++++----
 api/core/orchestrator_rule_parser.py          | 44 ++++++++++----
 .../moderation/test_openai_moderation.py      |  7 +--
 15 files changed, 240 insertions(+), 69 deletions(-)
 create mode 100644 api/core/helper/moderation.py
 create mode 100644 api/core/model_providers/models/moderation/base.py

diff --git a/api/config.py b/api/config.py
index 87aee41180..694179c074 100644
--- a/api/config.py
+++ b/api/config.py
@@ -61,6 +61,8 @@ DEFAULTS = {
     'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
     'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
     'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
+    'HOSTED_MODERATION_ENABLED': 'False',
+    'HOSTED_MODERATION_PROVIDERS': '',
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -230,6 +232,9 @@ class Config:
         self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
         self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
 
+        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
+        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
+
         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
 
diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py
index 17d8ecf6bd..903203d87b 100644
--- a/api/core/agent/agent_executor.py
+++ b/api/core/agent/agent_executor.py
@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.helper import moderation
+from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -116,6 +118,18 @@ class AgentExecutor:
         return self.agent.should_use_agent(query)
 
     def run(self, query: str) -> AgentExecuteResult:
+        moderation_result = moderation.check_moderation(
+            self.configuration.model_instance.model_provider,
+            query
+        )
+
+        if not moderation_result:
+            return AgentExecuteResult(
+                output="I apologize for any confusion, but I'm an AI assistant designed to be helpful, harmless, and honest.",
+                strategy=self.configuration.strategy,
+                configuration=self.configuration
+            )
+
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -128,7 +142,9 @@ class AgentExecutor:
 
         try:
             output = agent_executor.run(query)
-        except Exception:
+        except LLMError as ex:
+            raise ex
+        except Exception:
             logging.exception("agent_executor run failed")
             output = None
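The run() gate above short-circuits before any tool or LLM call: when the hosted moderation check flags the query, the executor returns a canned AgentExecuteResult instead of running the agent. A minimal standalone sketch of the pattern (check_moderation here is a local stand-in for core.helper.moderation.check_moderation, not the real import):

CANNED_RESPONSE = ("I apologize for any confusion, but I'm an AI assistant "
                   "designed to be helpful, harmless, and honest.")


def check_moderation(text: str) -> bool:
    # stand-in: the real helper calls the OpenAI moderation endpoint and
    # returns False when any chunk of the input is flagged
    return "forbidden" not in text.lower()


def run(query: str) -> str:
    if not check_moderation(query):
        # flagged input short-circuits; the agent and its tools never run
        return CANNED_RESPONSE
    return f"agent output for: {query}"


print(run("hello"))            # normal agent path
print(run("forbidden topic"))  # canned response path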
diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py
index c8cc043478..218a9e4e07 100644
--- a/api/core/callback_handler/agent_loop_gather_callback_handler.py
+++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
 
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_instant = model_instant
+        self.model_instance = model_instance
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        if not self._current_loop:
+            # the agent run starts with an LLM query
+            self._current_loop = AgentLoop(
+                position=len(self._agent_loops) + 1,
+                prompt="\n".join([message.content for message in messages[0]]),
+                status='llm_started',
+                started_at=time.perf_counter()
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
             else:
-                self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.prompt)]
                 )
             completion_generation = response.generations[0][0]
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
             else:
-                self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.completion)]
                 )
 
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
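The new on_chat_model_start receives lists of chat messages rather than plain prompt strings, so the handler flattens the first batch into one prompt for the AgentLoop record. A minimal sketch of that flattening (SimpleMessage is a hypothetical stand-in for langchain's BaseMessage):

from dataclasses import dataclass


@dataclass
class SimpleMessage:
    # stand-in for langchain.schema.BaseMessage
    type: str
    content: str


batches = [[
    SimpleMessage(type="system", content="You are a helpful assistant."),
    SimpleMessage(type="human", content="Summarize this document."),
]]

# mirrors the handler: join the contents of the first message batch
prompt = "\n".join([message.content for message in batches[0]])
print(prompt)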
diff --git a/api/core/callback_handler/entity/llm_message.py b/api/core/callback_handler/entity/llm_message.py
index 0f53295ae9..61967340c7 100644
--- a/api/core/callback_handler/entity/llm_message.py
+++ b/api/core/callback_handler/entity/llm_message.py
@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
     prompt_tokens: int = 0
     completion: str = ''
     completion_tokens: int = 0
-    latency: float = 0.0
diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py
index 3756e8ec61..b1f3ec393e 100644
--- a/api/core/callback_handler/llm_callback_handler.py
+++ b/api/core/callback_handler/llm_callback_handler.py
@@ -1,5 +1,4 @@
 import logging
-import time
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         messages: List[List[BaseMessage]],
         **kwargs: Any
     ) -> Any:
-        self.start_at = time.perf_counter()
         real_prompts = []
         for message in messages[0]:
             if message.type == 'human':
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self.start_at = time.perf_counter()
-
         self.llm_message.prompt = [{
             "role": 'user',
             "text": prompts[0]
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        end_at = time.perf_counter()
-        self.llm_message.latency = end_at - self.start_at
-
         if not self.conversation_message_task.streaming:
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text
@@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Do nothing."""
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
-                end_at = time.perf_counter()
-                self.llm_message.latency = end_at - self.start_at
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )
diff --git a/api/core/chain/sensitive_word_avoidance_chain.py b/api/core/chain/sensitive_word_avoidance_chain.py
index 3820840912..5fc20c5cea 100644
--- a/api/core/chain/sensitive_word_avoidance_chain.py
+++ b/api/core/chain/sensitive_word_avoidance_chain.py
@@ -1,15 +1,38 @@
+import enum
+import logging
 from typing import List, Dict, Optional, Any
 
+import openai
+from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
+from openai import InvalidRequestError
+from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
+    AuthenticationError, OpenAIError
+from pydantic import BaseModel
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation import openai_moderation
+
+
+class SensitiveWordAvoidanceRule(BaseModel):
+    class Type(enum.Enum):
+        MODERATION = "moderation"
+        KEYWORDS = "keywords"
+
+    type: Type
+    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
+    extra_params: dict = {}
 
 
 class SensitiveWordAvoidanceChain(Chain):
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:
 
-    sensitive_words: List[str] = []
-    canned_response: str = None
+    model_instance: BaseLLM
+    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
 
     @property
     def _chain_type(self) -> str:
@@ -31,11 +54,24 @@
         """
         return [self.output_key]
 
-    def _check_sensitive_word(self, text: str) -> str:
-        for word in self.sensitive_words:
+    def _check_sensitive_word(self, text: str) -> bool:
+        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
             if word in text:
-                return self.canned_response
-        return text
+                return False
+        return True
+
+    def _check_moderation(self, text: str) -> bool:
+        moderation_model_instance = ModelFactory.get_moderation_model(
+            tenant_id=self.model_instance.model_provider.provider.tenant_id,
+            model_provider_name='openai',
+            model_name=openai_moderation.DEFAULT_MODEL
+        )
+
+        try:
+            return moderation_model_instance.run(text=text)
+        except Exception as ex:
+            logging.exception(ex)
+            raise LLMBadRequestError('Rate limit exceeded, please try again later.')
 
     def _call(
             self,
             inputs: Dict[str, Any],
             run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        output = self._check_sensitive_word(text)
-        return {self.output_key: output}
+
+        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
+            result = self._check_sensitive_word(text)
+        else:
+            result = self._check_moderation(text)
+
+        if not result:
+            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+
+        return {self.output_key: text}
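The chain now evaluates one of two rule types: KEYWORDS does a plain substring scan, MODERATION delegates to the moderation model, and either failure raises the rule's canned response as an LLMBadRequestError. A self-contained sketch of the keyword path, using only the pydantic/enum shapes from the hunk above:

import enum

from pydantic import BaseModel


class RuleType(enum.Enum):
    MODERATION = "moderation"
    KEYWORDS = "keywords"


class Rule(BaseModel):
    type: RuleType
    canned_response: str = "Your content violates our usage policy. Please revise and try again."
    extra_params: dict = {}


def check_keywords(rule: Rule, text: str) -> bool:
    # mirrors _check_sensitive_word: fail as soon as any configured word matches
    return not any(word in text for word in rule.extra_params.get("sensitive_words", []))


rule = Rule(type=RuleType.KEYWORDS, extra_params={"sensitive_words": ["badword"]})
assert check_keywords(rule, "hello") is True
assert check_keywords(rule, "this has badword in it") is False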
diff --git a/api/core/completion.py b/api/core/completion.py
index 2635b77c5d..bb2da1e8ec 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -1,9 +1,7 @@
 import json
 import logging
-import re
-from typing import Optional, List, Union, Tuple
+from typing import Optional, List, Union
 
-from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError
 
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
@@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@@ -81,7 +78,7 @@ class Completion:
 
         # parse sensitive_word_avoidance_chain
         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
 
diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py
index 8b675d2144..41db2d16c5 100644
--- a/api/core/conversation_message_task.py
+++ b/api/core/conversation_message_task.py
@@ -1,5 +1,5 @@
-import decimal
 import json
+import time
 from typing import Optional, Union, List
 
 from core.callback_handler.entity.agent_loop import AgentLoop
@@ -23,6 +23,8 @@ class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                  inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
+        self.start_at = time.perf_counter()
+
         self.task_id = task_id
 
         self.app = app
@@ -61,6 +63,7 @@ class ConversationMessageTask:
         )
 
     def init(self):
+        override_model_configs = None
         if self.is_override:
             override_model_configs = self.app_model_config.to_dict()
@@ -165,7 +168,7 @@ class ConversationMessageTask:
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
         self.message.answer_price_unit = answer_price_unit
-        self.message.provider_response_latency = llm_message.latency
+        self.message.provider_response_latency = time.perf_counter() - self.start_at
         self.message.total_price = total_price
 
         db.session.commit()
@@ -220,18 +223,18 @@ class ConversationMessageTask:
 
         return message_agent_thought
 
-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
-        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
-        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
 
-        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
-        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
         loop_total_price = loop_message_total_price + loop_answer_total_price
 
         message_agent_thought.observation = agent_loop.tool_output
@@ -245,7 +248,7 @@ class ConversationMessageTask:
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instant.get_currency()
+        message_agent_thought.currency = agent_model_instance.get_currency()
         db.session.flush()
 
     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
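With LLMMessage.latency removed, provider_response_latency is now measured from task construction to save: the task records time.perf_counter() in __init__ and subtracts it when the message is persisted. A minimal sketch of the same measurement (TaskTimer is a hypothetical stand-in for the timing embedded in ConversationMessageTask):

import time


class TaskTimer:
    def __init__(self) -> None:
        self.start_at = time.perf_counter()

    def latency(self) -> float:
        # equivalent of: message.provider_response_latency = perf_counter() - start_at
        return time.perf_counter() - self.start_at


timer = TaskTimer()
time.sleep(0.05)
print(f"latency: {timer.latency():.3f}s")  # roughly 0.05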
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
new file mode 100644
index 0000000000..e567e9ed22
--- /dev/null
+++ b/api/core/helper/moderation.py
@@ -0,0 +1,32 @@
+import logging
+
+import openai
+from flask import current_app
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from models.provider import ProviderType
+
+
+def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
+    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
+        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
+
+        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
+                and model_provider.provider_name in moderation_providers:
+            # split the text into chunks of 2000 characters
+            length = 2000
+            chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+            try:
+                moderation_result = openai.Moderation.create(input=chunks,
+                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
+            except Exception as ex:
+                logging.exception(ex)
+                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False
+
+    return True
diff --git a/api/core/model_providers/model_factory.py b/api/core/model_providers/model_factory.py
index ae5951457d..f7577b392f 100644
--- a/api/core/model_providers/model_factory.py
+++ b/api/core/model_providers/model_factory.py
@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.base import BaseEmbedding
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.models.speech2text.base import BaseSpeech2Text
 from extensions.ext_database import db
 from models.provider import TenantDefaultModel
@@ -180,7 +181,7 @@ class ModelFactory:
     def get_moderation_model(cls,
                              tenant_id: str,
                              model_provider_name: str,
-                             model_name: str) -> Optional[BaseProviderModel]:
+                             model_name: str) -> Optional[BaseModeration]:
         """
         get moderation model.
 
diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py
index 13b302cce5..2fcf7ee96e 100644
--- a/api/core/model_providers/models/llm/base.py
+++ b/api/core/model_providers/models/llm/base.py
@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
+from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
         :param callbacks:
         :return:
         """
+        moderation_result = moderation.check_moderation(
+            self.model_provider,
+            "\n".join([message.content for message in messages])
+        )
+
+        if not moderation_result:
+            kwargs['fake_response'] = "I apologize for any confusion, " \
+                                      "but I'm an AI assistant designed to be helpful, harmless, and honest."
+
         if self.deduct_quota:
             self.model_provider.check_quota_over_limit()
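Both the hosted helper above and OpenAIModeration below use the same chunk-and-flag loop: split the input into 2000-character chunks, moderate them in a single call, and treat any flagged chunk as a failure for the whole text. A standalone sketch with the API call replaced by an injected predicate:

from typing import Callable


def check(text: str, is_flagged: Callable[[str], bool]) -> bool:
    # split the text into chunks of 2000 characters, as in the diff above
    length = 2000
    chunks = [text[i:i + length] for i in range(0, len(text), length)]

    # the real code sends all chunks in one openai.Moderation.create call;
    # any flagged chunk fails the whole text
    return not any(is_flagged(chunk) for chunk in chunks)


assert check("a" * 5000, lambda chunk: False) is True    # 3 chunks, none flagged
assert check("a" * 5000, lambda chunk: "a" in chunk) is False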
diff --git a/api/core/model_providers/models/moderation/base.py b/api/core/model_providers/models/moderation/base.py
new file mode 100644
index 0000000000..0d56739f1a
--- /dev/null
+++ b/api/core/model_providers/models/moderation/base.py
@@ -0,0 +1,29 @@
+from abc import abstractmethod
+from typing import Any
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseModeration(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.MODERATION
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    def run(self, text: str) -> bool:
+        try:
+            return self._run(text)
+        except Exception as ex:
+            raise self.handle_exceptions(ex)
+
+    @abstractmethod
+    def _run(self, text: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError
diff --git a/api/core/model_providers/models/moderation/openai_moderation.py b/api/core/model_providers/models/moderation/openai_moderation.py
index e7012a0438..9aeb6f0292 100644
--- a/api/core/model_providers/models/moderation/openai_moderation.py
+++ b/api/core/model_providers/models/moderation/openai_moderation.py
@@ -4,29 +4,35 @@ import openai
 
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, LLMAuthorizationError
-from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.providers.base import BaseModelProvider
 
-DEFAULT_AUDIO_MODEL = 'whisper-1'
+DEFAULT_MODEL = 'text-moderation-stable'
 
 
-class OpenAIModeration(BaseProviderModel):
-    type: ModelType = ModelType.MODERATION
-
+class OpenAIModeration(BaseModeration):
     def __init__(self, model_provider: BaseModelProvider, name: str):
-        super().__init__(model_provider, openai.Moderation)
+        super().__init__(model_provider, openai.Moderation, name)
 
-    def run(self, text):
+    def _run(self, text: str) -> bool:
         credentials = self.model_provider.get_model_credentials(
-            model_name=DEFAULT_AUDIO_MODEL,
+            model_name=self.name,
             model_type=self.type
         )
 
-        try:
-            return self._client.create(input=text, api_key=credentials['openai_api_key'])
-        except Exception as ex:
-            raise self.handle_exceptions(ex)
+        # split the text into chunks of 2000 characters
+        length = 2000
+        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+        moderation_result = self._client.create(input=chunks,
+                                                api_key=credentials['openai_api_key'])
+
+        for result in moderation_result.results:
+            if result['flagged'] is True:
+                return False
+
+        return True
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
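BaseModeration is a small template method: run() is the public entry point, _run() holds the provider-specific logic, and handle_exceptions() maps raw provider errors onto Dify's LLM error types. A sketch of that contract with a dummy provider (ModerationBase and DummyModeration are illustrative names, not part of the patch):

from abc import ABC, abstractmethod


class ModerationBase(ABC):
    # simplified stand-in for BaseModeration, without the provider plumbing
    def run(self, text: str) -> bool:
        try:
            return self._run(text)
        except Exception as ex:
            raise self.handle_exceptions(ex)

    @abstractmethod
    def _run(self, text: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        raise NotImplementedError


class DummyModeration(ModerationBase):
    def _run(self, text: str) -> bool:
        return "flagged" not in text

    def handle_exceptions(self, ex: Exception) -> Exception:
        return RuntimeError(f"moderation failed: {ex}")


assert DummyModeration().run("hello") is True
assert DummyModeration().run("flagged content") is False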
diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py
index cceb9db1a9..5110650ec3 100644
--- a/api/core/orchestrator_rule_parser.py
+++ b/api/core/orchestrator_rule_parser.py
@@ -1,6 +1,7 @@
 import math
 from typing import Optional
 
+from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
 from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
+from models.provider import ProviderType
 
 
 class OrchestratorRuleParser:
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
 
         # add agent callback to record agent thoughts
         agent_callback = AgentLoopGatherCallbackHandler(
-            model_instant=agent_model_instance,
+            model_instance=agent_model_instance,
             conversation_message_task=conversation_message_task
         )
 
@@ -123,23 +125,45 @@ class OrchestratorRuleParser:
 
         return chain
 
-    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
+    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
             -> Optional[SensitiveWordAvoidanceChain]:
         """
         Convert app sensitive word avoidance config to chain
 
+        :param model_instance: model instance
+        :param callbacks: callbacks for the chain
         :param kwargs:
         :return:
         """
-        if not self.app_model_config.sensitive_word_avoidance_dict:
-            return None
+        sensitive_word_avoidance_rule = None
 
-        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
-        sensitive_words = sensitive_word_avoidance_config.get("words", "")
-        if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
+        if self.app_model_config.sensitive_word_avoidance_dict:
+            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
+            if sensitive_word_avoidance_config.get("enabled", False):
+                if sensitive_word_avoidance_config.get('type') == 'moderation':
+                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
+                        canned_response=sensitive_word_avoidance_config.get("canned_response")
+                        if sensitive_word_avoidance_config.get("canned_response")
+                        else 'Your content violates our usage policy. Please revise and try again.',
+                    )
+                else:
+                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
+                    if sensitive_words:
+                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
+                            canned_response=sensitive_word_avoidance_config.get("canned_response")
+                            if sensitive_word_avoidance_config.get("canned_response")
+                            else 'Your content violates our usage policy. Please revise and try again.',
+                            extra_params={
+                                'sensitive_words': sensitive_words.split(','),
+                            }
+                        )
+
+        if sensitive_word_avoidance_rule:
             return SensitiveWordAvoidanceChain(
-                sensitive_words=sensitive_words.split(","),
-                canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
+                model_instance=model_instance,
+                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
                 output_key="sensitive_word_avoidance_output",
                 callbacks=callbacks,
                 **kwargs
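The parser above builds a rule only when the feature is enabled, preferring the app's configured canned_response and falling back to a default; a 'moderation' type wins over keyword matching. A condensed sketch of that selection logic (plain dicts stand in for app_model_config; build_rule is a hypothetical helper, not a function in the patch):

from typing import Optional

DEFAULT_RESPONSE = "Your content violates our usage policy. Please revise and try again."


def build_rule(config: dict) -> Optional[dict]:
    if not config or not config.get("enabled", False):
        return None
    canned = config.get("canned_response") or DEFAULT_RESPONSE
    if config.get("type") == "moderation":
        return {"type": "moderation", "canned_response": canned}
    words = config.get("words", "")
    if words:
        return {"type": "keywords", "canned_response": canned,
                "sensitive_words": words.split(",")}
    return None


assert build_rule({"enabled": False}) is None
assert build_rule({"enabled": True, "type": "moderation"})["canned_response"] == DEFAULT_RESPONSE
assert build_rule({"enabled": True, "words": "a,b"})["sensitive_words"] == ["a", "b"]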
diff --git a/api/tests/integration_tests/models/moderation/test_openai_moderation.py b/api/tests/integration_tests/models/moderation/test_openai_moderation.py
index c27f43e141..91027210bd 100644
--- a/api/tests/integration_tests/models/moderation/test_openai_moderation.py
+++ b/api/tests/integration_tests/models/moderation/test_openai_moderation.py
@@ -2,7 +2,7 @@ import json
 import os
 from unittest.mock import patch
 
-from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
+from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL
 from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.provider import Provider, ProviderType
 
@@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
     openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
     return OpenAIModeration(
         model_provider=openai_provider,
-        name=DEFAULT_AUDIO_MODEL
+        name=DEFAULT_MODEL
     )
 
 
@@ -36,5 +36,4 @@ def test_run(mock_decrypt):
     model = get_mock_openai_moderation_model()
 
     rst = model.run('hello')
-    assert isinstance(rst, dict)
-    assert 'id' in rst
+    assert rst is True
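To exercise the feature end to end, a deployment sets the two new variables and routes traffic through a system (hosted) provider. A sketch of how the settings are read (the values shown are illustrative, not defaults shipped by this patch; get_bool_env is assumed to treat 'true', case-insensitively, as truthy):

import os

# hypothetical deployment values, for illustration only
os.environ.setdefault("HOSTED_MODERATION_ENABLED", "true")
os.environ.setdefault("HOSTED_MODERATION_PROVIDERS", "openai")

enabled = os.environ["HOSTED_MODERATION_ENABLED"].lower() == "true"
providers = os.environ["HOSTED_MODERATION_PROVIDERS"].split(",")

# check_moderation only fires for system-quota providers named in this list
print(enabled, providers)  # True ['openai']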