feat: add hosted moderation (#1158)

This commit is contained in:
takatost 2023-09-12 10:26:12 +08:00 committed by GitHub
parent 983834cd52
commit f9082104ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 240 additions and 69 deletions

View File

@ -61,6 +61,8 @@ DEFAULTS = {
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
@ -230,6 +232,9 @@ class Config:
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@ -116,6 +118,18 @@ class AgentExecutor:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_instance.model_provider,
query
)
if not moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
configuration=self.configuration
)
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
@ -128,7 +142,9 @@ class AgentExecutor:
try:
output = agent_executor.run(query)
except Exception:
except LLMError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")
output = None

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_instant = model_instant
self.model_instance = model_instance
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Whether to ignore chain callbacks."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
if not self._current_loop:
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([message.content for message in messages[0]]),
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)
@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)

View File

@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0
latency: float = 0.0

View File

@ -1,5 +1,4 @@
import logging
import time
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
self.start_at = time.perf_counter()
real_prompts = []
for message in messages[0]:
if message.type == 'human':
@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)

View File

@ -1,15 +1,38 @@
import enum
import logging
from typing import List, Dict, Optional, Any
import openai
from flask import current_app
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from openai import InvalidRequestError
from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
AuthenticationError, OpenAIError
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
class Type(enum.Enum):
MODERATION = "moderation"
KEYWORDS = "keywords"
type: Type
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
extra_params: dict = {}
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
sensitive_words: List[str] = []
canned_response: str = None
model_instance: BaseLLM
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
@property
def _chain_type(self) -> str:
@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain):
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> str:
for word in self.sensitive_words:
def _check_sensitive_word(self, text: str) -> bool:
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
if word in text:
return self.canned_response
return text
return False
return True
def _check_moderation(self, text: str) -> bool:
moderation_model_instance = ModelFactory.get_moderation_model(
tenant_id=self.model_instance.model_provider.provider.tenant_id,
model_provider_name='openai',
model_name=openai_moderation.DEFAULT_MODEL
)
try:
return moderation_model_instance.run(text=text)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
def _call(
self,
@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
result = self._check_sensitive_word(text)
else:
result = self._check_moderation(text)
if not result:
raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}

View File

@ -1,9 +1,7 @@
import json
import logging
import re
from typing import Optional, List, Union, Tuple
from typing import Optional, List, Union
from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@ -81,7 +78,7 @@ class Completion:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)

View File

@ -1,5 +1,5 @@
import decimal
import json
import time
from typing import Optional, Union, List
from core.callback_handler.entity.agent_loop import AgentLoop
@ -23,6 +23,8 @@ class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.start_at = time.perf_counter()
self.task_id = task_id
self.app = app
@ -61,6 +63,7 @@ class ConversationMessageTask:
)
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = self.app_model_config.to_dict()
@ -165,7 +168,7 @@ class ConversationMessageTask:
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = llm_message.latency
self.message.provider_response_latency = time.perf_counter() - self.start_at
self.message.total_price = total_price
db.session.commit()
@ -220,18 +223,18 @@ class ConversationMessageTask:
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
@ -245,7 +248,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = agent_model_instant.get_currency()
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):

View File

@ -0,0 +1,32 @@
import logging
import openai
from flask import current_app
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from models.provider import ProviderType
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and model_provider.provider_name in moderation_providers:
# 2000 text per chunk
length = 2000
chunks = [text[i:i + length] for i in range(0, len(text), length)]
try:
moderation_result = openai.Moderation.create(input=chunks,
api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True

View File

@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
@ -180,7 +181,7 @@ class ModelFactory:
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseProviderModel]:
model_name: str) -> Optional[BaseModeration]:
"""
get moderation model.

View File

@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
:param callbacks:
:return:
"""
moderation_result = moderation.check_moderation(
self.model_provider,
"\n".join([message.content for message in messages])
)
if not moderation_result:
kwargs['fake_response'] = "I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest."
if self.deduct_quota:
self.model_provider.check_quota_over_limit()

View File

@ -0,0 +1,29 @@
from abc import abstractmethod
from typing import Any
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseModeration(BaseProviderModel):
name: str
type: ModelType = ModelType.MODERATION
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def run(self, text: str) -> bool:
try:
return self._run(text)
except Exception as ex:
raise self.handle_exceptions(ex)
@abstractmethod
def _run(self, text: str) -> bool:
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError

View File

@ -4,29 +4,35 @@ import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.providers.base import BaseModelProvider
DEFAULT_AUDIO_MODEL = 'whisper-1'
DEFAULT_MODEL = 'whisper-1'
class OpenAIModeration(BaseProviderModel):
type: ModelType = ModelType.MODERATION
class OpenAIModeration(BaseModeration):
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Moderation)
super().__init__(model_provider, openai.Moderation, name)
def run(self, text):
def _run(self, text: str) -> bool:
credentials = self.model_provider.get_model_credentials(
model_name=DEFAULT_AUDIO_MODEL,
model_name=self.name,
model_type=self.type
)
try:
return self._client.create(input=text, api_key=credentials['openai_api_key'])
except Exception as ex:
raise self.handle_exceptions(ex)
# 2000 text per chunk
length = 2000
chunks = [text[i:i + length] for i in range(0, len(text), length)]
moderation_result = self._client.create(input=chunks,
api_key=credentials['openai_api_key'])
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):

View File

@ -1,6 +1,7 @@
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
@ -63,7 +65,7 @@ class OrchestratorRuleParser:
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_instant=agent_model_instance,
model_instance=agent_model_instance,
conversation_message_task=conversation_message_task
)
@ -123,23 +125,45 @@ class OrchestratorRuleParser:
return chain
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
-> Optional[SensitiveWordAvoidanceChain]:
"""
Convert app sensitive word avoidance config to chain
:param model_instance: model instance
:param callbacks: callbacks for the chain
:param kwargs:
:return:
"""
if not self.app_model_config.sensitive_word_avoidance_dict:
return None
sensitive_word_avoidance_rule = None
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
if self.app_model_config.sensitive_word_avoidance_dict:
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_config.get("enabled", False):
if sensitive_word_avoidance_config.get('type') == 'moderation':
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.MODERATION,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
)
else:
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_words:
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
extra_params={
'sensitive_words': sensitive_words.split(','),
}
)
if sensitive_word_avoidance_rule:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
model_instance=model_instance,
sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
output_key="sensitive_word_avoidance_output",
callbacks=callbacks,
**kwargs

View File

@ -2,7 +2,7 @@ import json
import os
from unittest.mock import patch
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import Provider, ProviderType
@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIModeration(
model_provider=openai_provider,
name=DEFAULT_AUDIO_MODEL
name=DEFAULT_MODEL
)
@ -36,5 +36,4 @@ def test_run(mock_decrypt):
model = get_mock_openai_moderation_model()
rst = model.run('hello')
assert isinstance(rst, dict)
assert 'id' in rst
assert rst is True