mirror of https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00

chore: apply flake8-comprehensions Ruff rules to improve collection comprehensions (#5652)

Co-authored-by: -LAN- <laipz8200@outlook.com>

This commit is contained in:
parent 2e718b85e9
commit dcb72e0067
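
The hunks below are mechanical rewrites produced by enabling flake8-comprehensions (and a few related simplification rules) in Ruff. As a minimal runnable sketch of the recurring before/after patterns, with invented variable names and rule codes added as my own annotations rather than taken from the diff:

    # Illustrative only; `tool` and `docs` are invented names.
    tool = {'provider_id': 'p1'}
    docs = ['alpha', 'beta']

    # Before: forms the new rules flag
    params_old = tool['tool_parameters'] if 'tool_parameters' in tool else {}  # verbose membership test
    chars_old = set(['Cc', 'Cf'])             # set() wrapping a list literal (C405)
    payload_old = dict()                      # dict() call where a literal works (C408)
    total_old = sum([len(d) for d in docs])   # builds a throwaway list just to sum it

    # After: the equivalents this commit standardizes on
    params_new = tool.get('tool_parameters', {})
    chars_new = {'Cc', 'Cf'}
    payload_new = {}
    total_new = sum(len(d) for d in docs)     # generator expression, no intermediate list

    assert (params_old, chars_old, payload_old, total_old) == \
           (params_new, chars_new, payload_new, total_new)
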
@@ -40,7 +40,7 @@ class AgentConfigManager:
     'provider_type': tool['provider_type'],
     'provider_id': tool['provider_id'],
     'tool_name': tool['tool_name'],
-    'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
+    'tool_parameters': tool.get('tool_parameters', {})
 }

 agent_tools.append(AgentToolEntity(**agent_tool_properties))

@@ -59,7 +59,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
 inputs = args['inputs']

 extras = {
-    "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False
+    "auto_generate_conversation_name": args.get('auto_generate_name', False)
 }

 # get conversation

@@ -57,7 +57,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
 inputs = args['inputs']

 extras = {
-    "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
+    "auto_generate_conversation_name": args.get('auto_generate_name', True)
 }

 # get conversation

@@ -203,7 +203,7 @@ class AgentChatAppRunner(AppRunner):
 llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
 model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)

-if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
+if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
     agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

 conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()

@@ -55,7 +55,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
 inputs = args['inputs']

 extras = {
-    "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
+    "auto_generate_conversation_name": args.get('auto_generate_name', True)
 }

 # get conversation

@@ -66,8 +66,8 @@ class ProviderConfiguration(BaseModel):
 original_provider_configurate_methods[self.provider.provider].append(configurate_method)

 if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
-    if (any([len(quota_configuration.restrict_models) > 0
-             for quota_configuration in self.system_configuration.quota_configurations])
+    if (any(len(quota_configuration.restrict_models) > 0
+            for quota_configuration in self.system_configuration.quota_configurations)
             and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
         self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

@@ -397,7 +397,7 @@ class IndexingRunner:
 document_id=dataset_document.id,
 after_indexing_status="splitting",
 extra_update_params={
-    DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
+    DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
     DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
 }
 )

@@ -83,7 +83,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
 max_workers = self._get_model_workers_limit(model, credentials)
 try:
     sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-    audio_bytes_list = list()
+    audio_bytes_list = []

     # Create a thread pool and map the function to the list of sentences
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:

@@ -175,8 +175,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
 # - https://github.com/anthropics/anthropic-sdk-python
 client = AnthropicBedrock(
-    aws_access_key=credentials.get("aws_access_key_id", None),
-    aws_secret_key=credentials.get("aws_secret_access_key", None),
+    aws_access_key=credentials.get("aws_access_key_id"),
+    aws_secret_key=credentials.get("aws_secret_access_key"),
     aws_region=credentials["aws_region"],
 )

@@ -576,7 +576,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 """
 Create payload for bedrock api call depending on model provider
 """
-payload = dict()
+payload = {}
 model_prefix = model.split('.')[0]
 model_name = model.split('.')[1]

@@ -648,8 +648,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 runtime_client = boto3.client(
     service_name='bedrock-runtime',
     config=client_config,
-    aws_access_key_id=credentials.get("aws_access_key_id", None),
-    aws_secret_access_key=credentials.get("aws_secret_access_key", None)
+    aws_access_key_id=credentials.get("aws_access_key_id"),
+    aws_secret_access_key=credentials.get("aws_secret_access_key")
 )

 model_prefix = model.split('.')[0]

@@ -49,8 +49,8 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
 bedrock_runtime = boto3.client(
     service_name='bedrock-runtime',
     config=client_config,
-    aws_access_key_id=credentials.get("aws_access_key_id", None),
-    aws_secret_access_key=credentials.get("aws_secret_access_key", None)
+    aws_access_key_id=credentials.get("aws_access_key_id"),
+    aws_secret_access_key=credentials.get("aws_secret_access_key")
 )

 embeddings = []

@@ -148,7 +148,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
 """
 Create payload for bedrock api call depending on model provider
 """
-payload = dict()
+payload = {}

 if model_prefix == "amazon":
     payload['inputText'] = texts

@@ -696,12 +696,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
     en_US=model
 ),
 model_type=ModelType.LLM,
-features=[feature for feature in base_model_schema_features],
+features=list(base_model_schema_features),
 fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-model_properties={
-    key: property for key, property in base_model_schema_model_properties.items()
-},
-parameter_rules=[rule for rule in base_model_schema_parameters_rules],
+model_properties=dict(base_model_schema_model_properties.items()),
+parameter_rules=list(base_model_schema_parameters_rules),
 pricing=base_model_schema.pricing
 )

@@ -277,10 +277,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 type='function',
 function=AssistantPromptMessage.ToolCall.ToolCallFunction(
     name=part.function_call.name,
-    arguments=json.dumps({
-        key: value
-        for key, value in part.function_call.args.items()
-    })
+    arguments=json.dumps(dict(part.function_call.args.items()))
 )
 )
 ]

@@ -88,9 +88,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):

 def _add_function_call(self, model: str, credentials: dict) -> None:
     model_schema = self.get_model_schema(model, credentials)
-    if model_schema and set([
+    if model_schema and {
         ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
-    ]).intersection(model_schema.features or []):
+    }.intersection(model_schema.features or []):
         credentials['function_calling_type'] = 'tool_call'

 def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:

@@ -100,10 +100,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
 if api_key:
     headers["Authorization"] = f"Bearer {api_key}"

-endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
+endpoint_url = credentials.get('endpoint_url')
 if endpoint_url and not endpoint_url.endswith('/'):
     endpoint_url += '/'
-server_url = credentials['server_url'] if 'server_url' in credentials else None
+server_url = credentials.get('server_url')

 # prepare the payload for a simple ping to the model
 data = {

@@ -182,10 +182,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
 if stream:
     headers['Accept'] = 'text/event-stream'

-endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
+endpoint_url = credentials.get('endpoint_url')
 if endpoint_url and not endpoint_url.endswith('/'):
     endpoint_url += '/'
-server_url = credentials['server_url'] if 'server_url' in credentials else None
+server_url = credentials.get('server_url')

 data = {
     "model": model,

@@ -1073,12 +1073,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
     en_US=model
 ),
 model_type=ModelType.LLM,
-features=[feature for feature in base_model_schema_features],
+features=list(base_model_schema_features),
 fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-model_properties={
-    key: property for key, property in base_model_schema_model_properties.items()
-},
-parameter_rules=[rule for rule in base_model_schema_parameters_rules],
+model_properties=dict(base_model_schema_model_properties.items()),
+parameter_rules=list(base_model_schema_parameters_rules),
 pricing=base_model_schema.pricing
 )

@@ -80,7 +80,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
 max_workers = self._get_model_workers_limit(model, credentials)
 try:
     sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-    audio_bytes_list = list()
+    audio_bytes_list = []

     # Create a thread pool and map the function to the list of sentences
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:

@@ -275,14 +275,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):

 @classmethod
 def _get_parameter_type(cls, param_type: str) -> str:
-    if param_type == 'integer':
-        return 'int'
-    elif param_type == 'number':
-        return 'float'
-    elif param_type == 'boolean':
-        return 'boolean'
-    elif param_type == 'string':
-        return 'string'
+    type_mapping = {
+        'integer': 'int',
+        'number': 'float',
+        'boolean': 'boolean',
+        'string': 'string'
+    }
+    return type_mapping.get(param_type)

 def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
     messages = messages.copy()  # don't mutate the original list

@@ -80,7 +80,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
 max_workers = self._get_model_workers_limit(model, credentials)
 try:
     sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-    audio_bytes_list = list()
+    audio_bytes_list = []

     # Create a thread pool and map the function to the list of sentences
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:

@@ -579,10 +579,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 type='function',
 function=AssistantPromptMessage.ToolCall.ToolCallFunction(
     name=part.function_call.name,
-    arguments=json.dumps({
-        key: value
-        for key, value in part.function_call.args.items()
-    })
+    arguments=json.dumps(dict(part.function_call.args.items()))
 )
 )
 ]

@@ -102,7 +102,7 @@ class Signer:
 body_hash = Util.sha256(request.body)
 request.headers['X-Content-Sha256'] = body_hash

-signed_headers = dict()
+signed_headers = {}
 for key in request.headers:
     if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'):
         signed_headers[key.lower()] = request.headers[key]

@@ -150,7 +150,7 @@ class Request:
 self.headers = OrderedDict()
 self.query = OrderedDict()
 self.body = ''
-self.form = dict()
+self.form = {}
 self.connection_timeout = 0
 self.socket_timeout = 0

@@ -147,7 +147,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
     return self._get_num_tokens_by_gpt2(text)

 if is_completion_model:
-    return sum([tokens(str(message.content)) for message in messages])
+    return sum(tokens(str(message.content)) for message in messages)

 tokens_per_message = 3
 tokens_per_name = 1

@@ -18,7 +18,7 @@ class _CommonZhipuaiAI:
 """
 credentials_kwargs = {
     "api_key": credentials['api_key'] if 'api_key' in credentials else
-               credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None,
+               credentials.get("zhipuai_api_key"),
 }

 return credentials_kwargs

@@ -148,7 +148,7 @@ class SimplePromptTransform(PromptTransform):
     special_variable_keys.append('#histories#')

 if query_in_prompt:
-    prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}'
+    prompt += prompt_rules.get('query_prompt', '{{#query#}}')
     special_variable_keys.append('#query#')

 return {

@@ -234,8 +234,8 @@ class SimplePromptTransform(PromptTransform):
     )
 ),
 max_token_limit=rest_tokens,
-human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+human_prefix=prompt_rules.get('human_prefix', 'Human'),
+ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant')
 )

 # get prompt

@@ -417,7 +417,7 @@ class ProviderManager:
 model_load_balancing_enabled = cache_result == 'True'

 if not model_load_balancing_enabled:
-    return dict()
+    return {}

 provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
     .filter(

@@ -451,7 +451,7 @@ class ProviderManager:
 if not provider_records:
     provider_records = []

-provider_quota_to_provider_record_dict = dict()
+provider_quota_to_provider_record_dict = {}
 for provider_record in provider_records:
     if provider_record.provider_type != ProviderType.SYSTEM.value:
         continue

@@ -661,7 +661,7 @@ class ProviderManager:
 provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)

 # Convert provider_records to dict
-quota_type_to_provider_records_dict = dict()
+quota_type_to_provider_records_dict = {}
 for provider_record in provider_records:
     if provider_record.provider_type != ProviderType.SYSTEM.value:
         continue

@@ -197,7 +197,7 @@ class Jieba(BaseKeyword):
     chunk_indices_count[node_id] += 1

 sorted_chunk_indices = sorted(
-    list(chunk_indices_count.keys()),
+    chunk_indices_count.keys(),
     key=lambda x: chunk_indices_count[x],
     reverse=True,
 )

@@ -201,7 +201,7 @@ class ReactMultiDatasetRouter:
 tool_strings.append(
     f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
 formatted_tools = "\n".join(tool_strings)
-unique_tool_names = set(tool.name for tool in tools)
+unique_tool_names = {tool.name for tool in tools}
 tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
 format_instructions = format_instructions.format(tool_names=tool_names)
 template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])

@@ -105,15 +105,15 @@ class BingSearchTool(BuiltinTool):


 def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
-    key = credentials.get('subscription_key', None)
+    key = credentials.get('subscription_key')
     if not key:
         raise Exception('subscription_key is required')

-    server_url = credentials.get('server_url', None)
+    server_url = credentials.get('server_url')
     if not server_url:
         server_url = self.url

-    query = tool_parameters.get('query', None)
+    query = tool_parameters.get('query')
     if not query:
         raise Exception('query is required')

@@ -170,7 +170,7 @@ class BingSearchTool(BuiltinTool):
 if not server_url:
     server_url = self.url

-query = tool_parameters.get('query', None)
+query = tool_parameters.get('query')
 if not query:
     raise Exception('query is required')

@@ -16,12 +16,12 @@ class BarChartTool(BuiltinTool):
 data = data.split(';')

 # if all data is int, convert to int
-if all([i.isdigit() for i in data]):
+if all(i.isdigit() for i in data):
     data = [int(i) for i in data]
 else:
     data = [float(i) for i in data]

-axis = tool_parameters.get('x_axis', None) or None
+axis = tool_parameters.get('x_axis') or None
 if axis:
     axis = axis.split(';')
     if len(axis) != len(data):

@@ -17,14 +17,14 @@ class LinearChartTool(BuiltinTool):
     return self.create_text_message('Please input data')
 data = data.split(';')

-axis = tool_parameters.get('x_axis', None) or None
+axis = tool_parameters.get('x_axis') or None
 if axis:
     axis = axis.split(';')
     if len(axis) != len(data):
         axis = None

 # if all data is int, convert to int
-if all([i.isdigit() for i in data]):
+if all(i.isdigit() for i in data):
     data = [int(i) for i in data]
 else:
     data = [float(i) for i in data]

@@ -16,10 +16,10 @@ class PieChartTool(BuiltinTool):
 if not data:
     return self.create_text_message('Please input data')
 data = data.split(';')
-categories = tool_parameters.get('categories', None) or None
+categories = tool_parameters.get('categories') or None

 # if all data is int, convert to int
-if all([i.isdigit() for i in data]):
+if all(i.isdigit() for i in data):
     data = [int(i) for i in data]
 else:
     data = [float(i) for i in data]

@@ -37,10 +37,10 @@ class GaodeRepositoriesTool(BuiltinTool):
     apikey=self.runtime.credentials.get('api_key')))
 weatherInfo_data = weatherInfo_response.json()
 if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK':
-    contents = list()
+    contents = []
     if len(weatherInfo_data.get('forecasts')) > 0:
         for item in weatherInfo_data['forecasts'][0]['casts']:
-            content = dict()
+            content = {}
             content['date'] = item.get('date')
             content['week'] = item.get('week')
             content['dayweather'] = item.get('dayweather')

@@ -39,10 +39,10 @@ class GihubRepositoriesTool(BuiltinTool):
     f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc")
 response_data = response.json()
 if response.status_code == 200 and isinstance(response_data.get('items'), list):
-    contents = list()
+    contents = []
     if len(response_data.get('items')) > 0:
         for item in response_data.get('items'):
-            content = dict()
+            content = {}
             updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ")
             content['owner'] = item['owner']['login']
             content['name'] = item['name']

@@ -26,11 +26,11 @@ class JinaReaderTool(BuiltinTool):
 if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'):
     headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key')

-target_selector = tool_parameters.get('target_selector', None)
+target_selector = tool_parameters.get('target_selector')
 if target_selector is not None and target_selector != '':
     headers['X-Target-Selector'] = target_selector

-wait_for_selector = tool_parameters.get('wait_for_selector', None)
+wait_for_selector = tool_parameters.get('wait_for_selector')
 if wait_for_selector is not None and wait_for_selector != '':
     headers['X-Wait-For-Selector'] = wait_for_selector

|
|||
if tool_parameters.get('gather_all_images_at_the_end', False):
|
||||
headers['X-With-Images-Summary'] = 'true'
|
||||
|
||||
proxy_server = tool_parameters.get('proxy_server', None)
|
||||
proxy_server = tool_parameters.get('proxy_server')
|
||||
if proxy_server is not None and proxy_server != '':
|
||||
headers['X-Proxy-Url'] = proxy_server
|
||||
|
||||
|
|
|
@@ -33,7 +33,7 @@ class JinaSearchTool(BuiltinTool):
 if tool_parameters.get('gather_all_images_at_the_end', False):
     headers['X-With-Images-Summary'] = 'true'

-proxy_server = tool_parameters.get('proxy_server', None)
+proxy_server = tool_parameters.get('proxy_server')
 if proxy_server is not None and proxy_server != '':
     headers['X-Proxy-Url'] = proxy_server

@@ -94,7 +94,7 @@ class GoogleTool(BuiltinTool):
 google_domain = tool_parameters.get("google_domain", "google.com")
 gl = tool_parameters.get("gl", "us")
 hl = tool_parameters.get("hl", "en")
-location = tool_parameters.get("location", None)
+location = tool_parameters.get("location")

 api_key = self.runtime.credentials['searchapi_api_key']
 result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location)

@@ -72,11 +72,11 @@ class GoogleJobsTool(BuiltinTool):
 """
 query = tool_parameters['query']
 result_type = tool_parameters['result_type']
-is_remote = tool_parameters.get("is_remote", None)
+is_remote = tool_parameters.get("is_remote")
 google_domain = tool_parameters.get("google_domain", "google.com")
 gl = tool_parameters.get("gl", "us")
 hl = tool_parameters.get("hl", "en")
-location = tool_parameters.get("location", None)
+location = tool_parameters.get("location")

 ltype = 1 if is_remote else None

@@ -82,7 +82,7 @@ class GoogleNewsTool(BuiltinTool):
 google_domain = tool_parameters.get("google_domain", "google.com")
 gl = tool_parameters.get("gl", "us")
 hl = tool_parameters.get("hl", "en")
-location = tool_parameters.get("location", None)
+location = tool_parameters.get("location")

 api_key = self.runtime.credentials['searchapi_api_key']
 result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location)

@@ -107,7 +107,7 @@ class SearXNGSearchTool(BuiltinTool):
 if not host:
     raise Exception('SearXNG api is required')

-query = tool_parameters.get('query', None)
+query = tool_parameters.get('query')
 if not query:
     return self.create_text_message('Please input query')

@@ -43,7 +43,7 @@ class GetMarkdownTool(BuiltinTool):
 Invoke the SerplyApi tool.
 """
 url = tool_parameters["url"]
-location = tool_parameters.get("location", None)
+location = tool_parameters.get("location")

 api_key = self.runtime.credentials["serply_api_key"]
 result = SerplyApi(api_key).run(url, location=location)

@@ -55,7 +55,7 @@ class SerplyApi:
 f"Employer: {job['employer']}",
 f"Location: {job['location']}",
 f"Link: {job['link']}",
-f"""Highest: {", ".join([h for h in job["highlights"]])}""",
+f"""Highest: {", ".join(list(job["highlights"]))}""",
 "---",
 ])
 )

@@ -78,7 +78,7 @@ class JobSearchTool(BuiltinTool):
 query = tool_parameters["query"]
 gl = tool_parameters.get("gl", "us")
 hl = tool_parameters.get("hl", "en")
-location = tool_parameters.get("location", None)
+location = tool_parameters.get("location")

 api_key = self.runtime.credentials["serply_api_key"]
 result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location)

@@ -80,7 +80,7 @@ class NewsSearchTool(BuiltinTool):
 query = tool_parameters["query"]
 gl = tool_parameters.get("gl", "us")
 hl = tool_parameters.get("hl", "en")
-location = tool_parameters.get("location", None)
+location = tool_parameters.get("location")

 api_key = self.runtime.credentials["serply_api_key"]
 result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location)

@@ -83,7 +83,7 @@ class ScholarSearchTool(BuiltinTool):
 query = tool_parameters["query"]
 gl = tool_parameters.get("gl", "us")
 hl = tool_parameters.get("hl", "en")
-location = tool_parameters.get("location", None)
+location = tool_parameters.get("location")

 api_key = self.runtime.credentials["serply_api_key"]
 result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location)

@@ -38,7 +38,7 @@ class BuiltinToolProviderController(ToolProviderController):

 super().__init__(**{
     'identity': provider_yaml['identity'],
-    'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None,
+    'credentials_schema': provider_yaml.get('credentials_for_provider', None),
 })

 def _get_builtin_tools(self) -> list[Tool]:

@@ -159,8 +159,8 @@ class ApiTool(Tool):
 for content_type in self.api_bundle.openapi['requestBody']['content']:
     headers['Content-Type'] = content_type
     body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
-    required = body_schema['required'] if 'required' in body_schema else []
-    properties = body_schema['properties'] if 'properties' in body_schema else {}
+    required = body_schema.get('required', [])
+    properties = body_schema.get('properties', {})
     for name, property in properties.items():
         if name in parameters:
             # convert type

@@ -90,7 +90,7 @@ class DatasetRetrieverTool(Tool):
 """
 invoke dataset retriever tool
 """
-query = tool_parameters.get('query', None)
+query = tool_parameters.get('query')
 if not query:
     return self.create_text_message(text='please input query')

@@ -209,7 +209,7 @@ class ToolManager:

 if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
     # check if tool_parameter_config in options
-    options = list(map(lambda x: x.value, parameter_rule.options))
+    options = [x.value for x in parameter_rule.options]
     if parameter_value is not None and parameter_value not in options:
         raise ValueError(
             f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")

@@ -21,10 +21,7 @@ class ApiBasedToolSchemaParser:
 extra_info = extra_info if extra_info is not None else {}

 # set description to extra_info
-if 'description' in openapi['info']:
-    extra_info['description'] = openapi['info']['description']
-else:
-    extra_info['description'] = ''
+extra_info['description'] = openapi['info'].get('description', '')

 if len(openapi['servers']) == 0:
     raise ToolProviderNotFoundError('No server found in the openapi yaml.')

|
@ -95,8 +92,8 @@ class ApiBasedToolSchemaParser:
|
|||
# parse body parameters
|
||||
if 'schema' in interface['operation']['requestBody']['content'][content_type]:
|
||||
body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
|
||||
required = body_schema['required'] if 'required' in body_schema else []
|
||||
properties = body_schema['properties'] if 'properties' in body_schema else {}
|
||||
required = body_schema.get('required', [])
|
||||
properties = body_schema.get('properties', {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
|
@ -105,14 +102,14 @@ class ApiBasedToolSchemaParser:
|
|||
zh_Hans=name
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=property['description'] if 'description' in property else '',
|
||||
zh_Hans=property['description'] if 'description' in property else ''
|
||||
en_US=property.get('description', ''),
|
||||
zh_Hans=property.get('description', '')
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property['description'] if 'description' in property else '',
|
||||
default=property['default'] if 'default' in property else None,
|
||||
llm_description=property.get('description', ''),
|
||||
default=property.get('default', None),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
|
@@ -149,7 +146,7 @@ class ApiBasedToolSchemaParser:
 server_url=server_url + interface['path'],
 method=interface['method'],
 summary=interface['operation']['description'] if 'description' in interface['operation'] else
-    interface['operation']['summary'] if 'summary' in interface['operation'] else None,
+    interface['operation'].get('summary', None),
 operation_id=interface['operation']['operationId'],
 parameters=parameters,
 author='',

@@ -283,7 +283,7 @@ def strip_control_characters(text):
 # [Cn]: Other, Not Assigned
 # [Co]: Other, Private Use
 # [Cs]: Other, Surrogate
-control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
+control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'}
 retained_chars = ['\t', '\n', '\r', '\f']

 # Remove non-printing control characters

@@ -93,7 +93,7 @@ class ParameterExtractorNode(LLMNode):
 # fetch memory
 memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

-if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \
+if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
         and node_data.reasoning_mode == 'function_call':
     # use function call
     prompt_messages, prompt_message_tools = self._generate_function_call_prompt(

@@ -644,7 +644,7 @@ class ParameterExtractorNode(LLMNode):
 if not model_schema:
     raise ValueError("Model schema not found")

-if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]):
+if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
     prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
 else:
     prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)

@@ -246,10 +246,7 @@ class NotionOAuth(OAuthDataSource):
 }
 response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
 response_json = response.json()
-if 'results' in response_json:
-    results = response_json['results']
-else:
-    results = []
+results = response_json.get('results', [])
 return results

 def notion_block_parent_page_id(self, access_token: str, block_id: str):

@@ -293,8 +290,5 @@ class NotionOAuth(OAuthDataSource):
 }
 response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
 response_json = response.json()
-if 'results' in response_json:
-    results = response_json['results']
-else:
-    results = []
+results = response_json.get('results', [])
 return results

@@ -14,9 +14,11 @@ line-length = 120
 preview = true
 select = [
+    "B", # flake8-bugbear rules
+    "C4", # flake8-comprehensions
     "F", # pyflakes rules
     "I", # isort rules
     "UP", # pyupgrade rules
     "B035", # static-key-dict-comprehension
     "E101", # mixed-spaces-and-tabs
     "E111", # indentation-with-invalid-multiple
     "E112", # no-indented-block

@@ -28,8 +30,13 @@ select = [
     "RUF100", # unused-noqa
     "RUF101", # redirected-noqa
     "S506", # unsafe-yaml-load
+    "SIM116", # if-else-block-instead-of-dict-lookup
+    "SIM401", # if-else-block-instead-of-dict-get
+    "SIM910", # dict-get-with-none-default
     "W191", # tab-indentation
     "W605", # invalid-escape-sequence
+    "F601", # multi-value-repeated-key-literal
+    "F602", # multi-value-repeated-key-variable
 ]
 ignore = [
     "F403", # undefined-local-with-import-star

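The SIM rules added to `select` above drive the dict-based rewrites seen earlier in this diff (the Replicate `_get_parameter_type` lookup table and the many `.get(key, None)` → `.get(key)` changes). A small runnable sketch of each, with invented names:

    # Illustrative only; `param_type` and `credentials` are invented names.
    param_type = 'number'
    credentials = {'endpoint_url': 'http://localhost'}

    # SIM116: an if/elif chain keyed on one value becomes a dict lookup.
    type_mapping = {'integer': 'int', 'number': 'float'}
    assert type_mapping.get(param_type) == 'float'  # .get returns None for unknown keys,
                                                    # matching the old implicit fall-through

    # SIM401: `d[k] if k in d else default` becomes d.get(k, default).
    endpoint = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
    assert endpoint == credentials.get('endpoint_url')

    # SIM910: .get(k, None) spells out the default that .get already uses.
    assert credentials.get('server_url', None) == credentials.get('server_url')
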
@@ -82,8 +89,8 @@ HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = "b"
 HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c"
 MOCK_SWITCH = "true"
 CODE_MAX_STRING_LENGTH = "80000"
-CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194"
-CODE_EXECUTION_API_KEY="dify-sandbox"
+CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194"
+CODE_EXECUTION_API_KEY = "dify-sandbox"
 FIRECRAWL_API_KEY = "fc-"

 [tool.poetry]

@@ -114,11 +121,11 @@ cachetools = "~5.3.0"
 weaviate-client = "~3.21.0"
 mailchimp-transactional = "~1.0.50"
 scikit-learn = "1.2.2"
-sentry-sdk = {version = "~1.39.2", extras = ["flask"]}
+sentry-sdk = { version = "~1.39.2", extras = ["flask"] }
 sympy = "1.12"
 jieba = "0.42.1"
 celery = "~5.3.6"
-redis = {version = "~5.0.3", extras = ["hiredis"]}
+redis = { version = "~5.0.3", extras = ["hiredis"] }
 chardet = "~5.1.0"
 python-docx = "~1.1.0"
 pypdfium2 = "~4.17.0"

@@ -138,7 +145,7 @@ googleapis-common-protos = "1.63.0"
 google-cloud-storage = "2.16.0"
 replicate = "~0.22.0"
 websocket-client = "~1.7.0"
-dashscope = {version = "~1.17.0", extras = ["tokenizer"]}
+dashscope = { version = "~1.17.0", extras = ["tokenizer"] }
 huggingface-hub = "~0.16.4"
 transformers = "~4.35.0"
 tokenizers = "~0.15.0"

@@ -152,10 +159,10 @@ qdrant-client = "1.7.3"
 cohere = "~5.2.4"
 pyyaml = "~6.0.1"
 numpy = "~1.26.4"
-unstructured = {version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"]}
+unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
 bs4 = "~0.0.1"
 markdown = "~3.5.1"
-httpx = {version = "~0.27.0", extras = ["socks"]}
+httpx = { version = "~0.27.0", extras = ["socks"] }
 matplotlib = "~3.8.2"
 yfinance = "~0.2.40"
 pydub = "~0.25.1"

@@ -180,7 +187,7 @@ pgvector = "0.2.5"
 pymysql = "1.1.1"
 tidb-vector = "0.0.9"
 google-cloud-aiplatform = "1.49.0"
-vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
+vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] }
 kaleido = "0.2.1"
 tencentcloud-sdk-python-hunyuan = "~3.0.1158"
 tcvectordb = "1.3.2"

@@ -696,7 +696,7 @@ class DocumentService:
 elif document_data["data_source"]["type"] == "notion_import":
     notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
     exist_page_ids = []
-    exist_document = dict()
+    exist_document = {}
     documents = Document.query.filter_by(
         dataset_id=dataset.id,
         tenant_id=current_user.current_tenant_id,

@@ -95,7 +95,7 @@ class RecommendedAppService:

     categories.add(recommended_app.category)  # add category to categories

-return {'recommended_apps': recommended_apps_result, 'categories': sorted(list(categories))}
+return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)}

 @classmethod
 def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:

@@ -514,8 +514,8 @@ class WorkflowConverter:

 prompt_rules = prompt_template_config['prompt_rules']
 role_prefix = {
-    "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-    "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+    "user": prompt_rules.get('human_prefix', 'Human'),
+    "assistant": prompt_rules.get('assistant_prefix', 'Assistant')
 }
 else:
     advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template

@@ -112,7 +112,7 @@ def test_execute_llm(setup_openai_mock):
 # Mock db.session.close()
 db.session.close = MagicMock()

-node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))

 # execute node
 result = node.run(pool)

@@ -229,7 +229,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
 # Mock db.session.close()
 db.session.close = MagicMock()

-node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))

 # execute node
 result = node.run(pool)

@@ -59,7 +59,7 @@ def get_mocked_fetch_model_config(
     provider_model_bundle=provider_model_bundle
 )

-return MagicMock(return_value=tuple([model_instance, model_config]))
+return MagicMock(return_value=(model_instance, model_config))

 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_function_calling_parameter_extractor(setup_openai_mock):

@@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages():
 prompt_rules = prompt_template['prompt_rules']
 full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
     max_token_limit=2000,
-    human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-    ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+    human_prefix=prompt_rules.get("human_prefix", "Human"),
+    ai_prefix=prompt_rules.get("assistant_prefix", "Assistant")
 )}
 real_prompt = prompt_template['prompt_template'].format(full_inputs)

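With the rules enabled in pyproject.toml above, rewrites of this kind can be regenerated locally with Ruff's autofix, e.g. `ruff check --fix .` run from the directory containing that config (the exact invocation path and Ruff version are assumptions; the `select` list in the config determines which rewrites apply).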