Merge branch 'main' of github.com:langgenius/dify into feat/plugins

This commit is contained in:
Yi 2024-09-18 20:57:52 +08:00
commit e8127756e0
245 changed files with 2882 additions and 1210 deletions

9
.gitignore vendored
View File

@ -153,6 +153,9 @@ docker-legacy/volumes/etcd/*
docker-legacy/volumes/minio/*
docker-legacy/volumes/milvus/*
docker-legacy/volumes/chroma/*
docker-legacy/volumes/opensearch/data/*
docker-legacy/volumes/pgvectors/data/*
docker-legacy/volumes/pgvector/data/*
docker/volumes/app/storage/*
docker/volumes/certbot/*
@ -164,6 +167,12 @@ docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
docker/volumes/chroma/*
docker/volumes/opensearch/data/*
docker/volumes/myscale/data/*
docker/volumes/myscale/log/*
docker/volumes/unstructured/*
docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/nginx/conf.d/default.conf
docker/middleware.env

View File

@ -164,7 +164,7 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in ["console", "inner_api"]:
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")

View File

@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in ["knowledge", "all"]:
if scope in {"knowledge", "all"}:
migrate_knowledge_vector_database()
if scope in ["annotation", "all"]:
if scope in {"annotation", "all"}:
migrate_annotation_vector_database()

View File

@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):

View File

@ -71,7 +71,7 @@ class OAuthCallback(Resource):
account = _generate_account(provider, user_info)
# Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value:

View File

@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in ["completed", "error"]:
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
info_list = []
extract_settings = []
for document in documents:
if document.indexing_status in ["completed", "error"]:
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
# format document files info
@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
elif action == "resume":
if document.indexing_status not in ["paused", "error"]:
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None

View File

@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):

View File

@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
def get(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"],
"editable": current_user.role in {"owner", "admin"},
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
}
for installed_app in installed_apps

View File

@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)

View File

@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters."""
app_model = installed_app.app
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
raise TooManyFilesError()
extension = file.filename.split(".")[-1]
if extension.lower() not in ["svg", "png"]:
if extension.lower() not in {"svg", "png"}:
raise UnsupportedFileTypeError()
try:

View File

@ -42,7 +42,7 @@ class AppParameterApi(Resource):
@marshal_with(parameters_fields)
def get(self, app_model: App):
"""Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -79,7 +79,7 @@ class TextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):

View File

@ -96,7 +96,7 @@ class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationApi(Resource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
@marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
@marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -76,7 +76,7 @@ class MessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:

View File

@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -78,7 +78,7 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):

View File

@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
def post(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotCompletionAppError()
message_id = str(message_id)

View File

@ -90,7 +90,7 @@ class CotAgentOutputParser:
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ["\n", " ", ""]:
if last_character not in {"\n", " ", ""}:
index += steps
yield delta
continue
@ -117,7 +117,7 @@ class CotAgentOutputParser:
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ["\n", " ", ""]:
if last_character not in {"\n", " ", ""}:
index += steps
yield delta
continue

View File

@ -29,7 +29,7 @@ class BaseAppConfigManager:
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
)
additional_features.opening_statement, additional_features.suggested_questions = (

View File

@ -18,7 +18,7 @@ class AgentConfigManager:
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == "cot" or agent_strategy == "react":
elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
@ -43,10 +43,10 @@ class AgentConfigManager:
agent_tools.append(AgentToolEntity(**agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",
"router",
]:
}:
agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode
model_mode = config.get("model", {}).get("mode", "completion")

View File

@ -167,7 +167,7 @@ class DatasetConfigManager:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0]
if key == "dataset":

View File

@ -42,12 +42,12 @@ class BasicVariablesConfigManager:
variable=variable["variable"], type=variable["type"], config=variable["config"]
)
)
elif variable_type in [
elif variable_type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]:
}:
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
@ -97,7 +97,7 @@ class BasicVariablesConfigManager:
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key]

View File

@ -54,14 +54,14 @@ class FileUploadConfigManager:
if is_vision:
detail = config["file_upload"]["image"]["detail"]
if detail not in ["high", "low"]:
if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in ["remote_url", "local_file"]:
if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"]

View File

@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
raise ValueError("Workflow not initialized")
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,

View File

@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:

View File

@ -22,11 +22,11 @@ class BaseAppGenerator:
return var.default or ""
if (
var.type
in (
in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
)
}
and user_input_value
and not isinstance(user_input_value, str)
):
@ -44,7 +44,7 @@ class BaseAppGenerator:
options = var.options or []
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")

View File

@ -32,7 +32,7 @@ class AppQueueManager:
self._user_id = user_id
self._invoke_from = invoke_from
user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
@ -118,7 +118,7 @@ class AppQueueManager:
if result is None:
return
user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return

View File

@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
# get from source
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
from_source = "api"
end_user_id = application_generate_entity.user_id
else:
@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model
override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
]:
}:
override_model_configs = app_config.app_model_config_dict
# get conversation introduction

View File

@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,

View File

@ -63,7 +63,7 @@ class AnnotationReplyFeature:
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
from_source = "api"
else:
from_source = "console"

View File

@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)

View File

@ -383,7 +383,7 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
response = NodeStartStreamResponse(
@ -430,7 +430,7 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
return NodeFinishStreamResponse(

View File

@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler:
source="app",
source_app_id=self._app_id,
created_by_role=(
"account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
),
created_by=self._user_id,
)

View File

@ -65,7 +65,7 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except Exception as e:
logging.exception("Failed transform embedding: ", e)
logging.exception("Failed transform embedding: %s", e)
cache_embeddings = []
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
@ -85,7 +85,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.error("Failed to embed documents: ", ex)
logger.error("Failed to embed documents: %s", ex)
raise ex
return text_embeddings
@ -116,10 +116,7 @@ class CacheEmbedding(Embeddings):
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except IntegrityError:
db.session.rollback()
except:
logging.exception("Failed to add embedding to redis")
except Exception as ex:
logging.exception("Failed to add embedding to redis %s", ex)
return embedding_results

View File

@ -292,7 +292,7 @@ class IndexingRunner:
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
) -> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
return []
data_source_info = dataset_document.data_source_info_dict

View File

@ -52,7 +52,7 @@ class TokenBufferMemory:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else:
if message.workflow_run_id:

View File

@ -27,17 +27,17 @@ class ModelType(Enum):
:return: model type
"""
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
if origin_model_type in {"text-generation", cls.LLM.value}:
return cls.LLM
elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
return cls.TEXT_EMBEDDING
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
elif origin_model_type in {"reranking", cls.RERANK.value}:
return cls.RERANK
elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
return cls.SPEECH2TEXT
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
elif origin_model_type in {"tts", cls.TTS.value}:
return cls.TTS
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION

View File

@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"

View File

@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
for i in range(len(sentences))
]
for future in futures:
yield from future.result().__enter__().iter_bytes(1024)
yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
else:
response = client.audio.speech.with_streaming_response.create(
model=model, voice=voice, response_format="mp3", input=content_text.strip()
)
yield from response.__enter__().iter_bytes(1024)
yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
except Exception as ex:
raise InvokeBadRequestError(str(ex))

View File

@ -1,6 +1,6 @@
model: eu.anthropic.claude-3-haiku-20240307-v1:0
label:
en_US: Claude 3 Haiku(Cross Region Inference)
en_US: Claude 3 Haiku(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought

View File

@ -1,6 +1,6 @@
model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0
label:
en_US: Claude 3.5 Sonnet(Cross Region Inference)
en_US: Claude 3.5 Sonnet(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought

View File

@ -1,6 +1,6 @@
model: eu.anthropic.claude-3-sonnet-20240229-v1:0
label:
en_US: Claude 3 Sonnet(Cross Region Inference)
en_US: Claude 3 Sonnet(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought

View File

@ -1,8 +1,8 @@
# standard import
import base64
import io
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
@ -17,7 +17,6 @@ from botocore.exceptions import (
ServiceNotInRegionError,
UnknownServiceError,
)
from PIL.Image import Image
# local import
from core.model_runtime.callbacks.base_callback import Callback
@ -443,8 +442,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
try:
url = message_content.data
image_content = requests.get(url).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"
@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
return InvokeRateLimitError(error_msg)
elif error_code in [
elif error_code in {
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
]:
}:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)

View File

@ -1,6 +1,6 @@
model: us.anthropic.claude-3-haiku-20240307-v1:0
label:
en_US: Claude 3 Haiku(Cross Region Inference)
en_US: Claude 3 Haiku(US.Cross Region Inference)
model_type: llm
features:
- agent-thought

View File

@ -1,6 +1,6 @@
model: us.anthropic.claude-3-opus-20240229-v1:0
label:
en_US: Claude 3 Opus(Cross Region Inference)
en_US: Claude 3 Opus(US.Cross Region Inference)
model_type: llm
features:
- agent-thought

View File

@ -1,6 +1,6 @@
model: us.anthropic.claude-3-5-sonnet-20240620-v1:0
label:
en_US: Claude 3.5 Sonnet(Cross Region Inference)
en_US: Claude 3.5 Sonnet(US.Cross Region Inference)
model_type: llm
features:
- agent-thought

View File

@ -1,6 +1,6 @@
model: us.anthropic.claude-3-sonnet-20240229-v1:0
label:
en_US: Claude 3 Sonnet(Cross Region Inference)
en_US: Claude 3 Sonnet(US.Cross Region Inference)
model_type: llm
features:
- agent-thought

View File

@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
return InvokeRateLimitError(error_msg)
elif error_code in [
elif error_code in {
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
]:
}:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)

View File

@ -6,10 +6,10 @@ from collections.abc import Generator
from typing import Optional, Union, cast
import google.ai.generativelanguage as glm
import google.api_core.exceptions as exceptions
import google.generativeai as genai
import google.generativeai.client as client
import requests
from google.api_core import exceptions
from google.generativeai import client
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types.content_types import to_part
from PIL import Image

View File

@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
if "huggingfacehub_api_type" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
if "huggingfacehub_api_token" not in credentials:
@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
credentials["huggingfacehub_api_token"], model
)
if credentials["task_type"] not in ("text2text-generation", "text-generation"):
if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
raise CredentialsValidateFailedError(
"Huggingface Hub Task Type must be one of text2text-generation, text-generation."
)

View File

@ -49,8 +49,7 @@ class HuggingfaceTeiRerankModel(RerankModel):
return RerankResult(model=model, docs=[])
server_url = credentials["server_url"]
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
try:
results = TeiHelper.invoke_rerank(server_url, query, docs)

View File

@ -75,7 +75,7 @@ class TeiHelper:
if len(model_type.keys()) < 1:
raise RuntimeError("model_type is empty")
model_type = list(model_type.keys())[0]
if model_type not in ["embedding", "reranker"]:
if model_type not in {"embedding", "reranker"}:
raise RuntimeError(f"invalid model_type: {model_type}")
max_input_length = response_json.get("max_input_length", 512)

View File

@ -42,8 +42,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
"""
server_url = credentials["server_url"]
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
# get model properties
context_size = self._get_context_size(model, credentials)
@ -119,8 +118,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
num_tokens = 0
server_url = credentials["server_url"]
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
num_tokens = sum(len(tokens) for tokens in batch_tokens)

View File

@ -2,3 +2,4 @@
- hunyuan-standard
- hunyuan-standard-256k
- hunyuan-pro
- hunyuan-turbo

View File

@ -0,0 +1,38 @@
model: hunyuan-turbo
label:
zh_Hans: hunyuan-turbo
en_US: hunyuan-turbo
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.015'
output: '0.05'
unit: '0.001'
currency: RMB

View File

@ -18,9 +18,9 @@ class JinaProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
# Use `jina-embeddings-v2-base-en` model for validate,
# Use `jina-embeddings-v3` model for validate,
# no matter what model you pass in, text completion model or chat model
model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials)
model_instance.validate_credentials(model="jina-embeddings-v3", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

View File

@ -48,8 +48,7 @@ class JinaRerankModel(RerankModel):
return RerankResult(model=model, docs=[])
base_url = credentials.get("base_url", "https://api.jina.ai/v1")
if base_url.endswith("/"):
base_url = base_url[:-1]
base_url = base_url.removesuffix("/")
try:
response = httpx.post(

View File

@ -0,0 +1,9 @@
model: jina-embeddings-v3
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 2048
pricing:
input: '0.001'
unit: '0.001'
currency: USD

View File

@ -44,8 +44,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("api_key is required")
base_url = credentials.get("base_url", self.api_base)
if base_url.endswith("/"):
base_url = base_url[:-1]
base_url = base_url.removesuffix("/")
url = base_url + "/embeddings"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
@ -57,6 +56,9 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]}
if model == "jina-embeddings-v3":
data["task"] = "text-matching"
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:

View File

@ -100,9 +100,9 @@ class MinimaxChatCompletion:
return self._handle_chat_generate_response(response)
def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
if code in {1000, 1001, 1013, 1027}:
raise InternalServerError(msg)
elif code == 1002 or code == 1039:
elif code in {1002, 1039}:
raise RateLimitReachedError(msg)
elif code == 1004:
raise InvalidAuthenticationError(msg)

View File

@ -105,9 +105,9 @@ class MinimaxChatCompletionPro:
return self._handle_chat_generate_response(response)
def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
if code in {1000, 1001, 1013, 1027}:
raise InternalServerError(msg)
elif code == 1002 or code == 1039:
elif code in {1002, 1039}:
raise RateLimitReachedError(msg)
elif code == 1004:
raise InvalidAuthenticationError(msg)

View File

@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid api key")
def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001:
if code in {1000, 1001}:
raise InternalServerError(msg)
elif code == 1002:
raise RateLimitReachedError(msg)

View File

@ -31,3 +31,4 @@ pricing:
output: '0.002'
unit: '0.001'
currency: USD
deprecated: true

View File

@ -31,3 +31,4 @@ pricing:
output: '0.004'
unit: '0.001'
currency: USD
deprecated: true

View File

@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model_mode = self.get_model_mode(base_model, credentials)
# transform response format
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
stop = stop or []
if model_mode == LLMMode.CHAT:
# chat model
@ -615,8 +615,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
block_as_stream = False
if model.startswith("o1"):
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]

View File

@ -11,9 +11,9 @@ model_properties:
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 65563
default: 65536
min: 1
max: 65563
max: 65536
- name: response_format
label:
zh_Hans: 回复格式

View File

@ -11,9 +11,9 @@ model_properties:
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 65563
default: 65536
min: 1
max: 65563
max: 65536
- name: response_format
label:
zh_Hans: 回复格式

View File

@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
for i in range(len(sentences))
]
for future in futures:
yield from future.result().__enter__().iter_bytes(1024)
yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
else:
response = client.audio.speech.with_streaming_response.create(
model=model, voice=voice, response_format="mp3", input=content_text.strip()
)
yield from response.__enter__().iter_bytes(1024)
yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
except Exception as ex:
raise InvokeBadRequestError(str(ex))

View File

@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
credentials["mode"] = self.get_model_mode(model).value
credentials["function_calling_type"] = "tool_call"
return
def _invoke(
self,

View File

@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
)
for key, value in input_properties:
if key not in ["system_prompt", "prompt"] and "stop" not in key:
if key not in {"system_prompt", "prompt"} and "stop" not in key:
value_type = value.get("type")
if not value_type:

View File

@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
)
for input_property in input_properties:
if input_property[0] in ("text", "texts", "inputs"):
if input_property[0] in {"text", "texts", "inputs"}:
text_input_key = input_property[0]
return text_input_key
@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
def _generate_embeddings_by_text_input_key(
client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
) -> list[list[float]]:
if text_input_key in ("text", "inputs"):
if text_input_key in {"text", "inputs"}:
embeddings = []
for text in texts:
result = client.run(replicate_model_version, input={text_input_key: text})

View File

@ -30,8 +30,7 @@ class SiliconflowRerankModel(RerankModel):
return RerankResult(model=model, docs=[])
base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1")
if base_url.endswith("/"):
base_url = base_url[:-1]
base_url = base_url.removesuffix("/")
try:
response = httpx.post(
base_url + "/rerank",

View File

@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling
:return:
"""
if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
model = model.replace("-chat", "")
if model == "farui-plus":
model = "qwen-farui-plus"
@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
mode = self.get_model_mode(model, credentials)
if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
model = model.replace("-chat", "")
extra_model_kwargs = {}
@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response
"""
if response.status_code != 200 and response.status_code != HTTPStatus.OK:
if response.status_code not in {200, HTTPStatus.OK}:
raise ServiceUnavailableError(response.message)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
full_text = ""
tool_calls = []
for index, response in enumerate(responses):
if response.status_code != 200 and response.status_code != HTTPStatus.OK:
if response.status_code not in {200, HTTPStatus.OK}:
raise ServiceUnavailableError(
f"Failed to invoke model {model}, status code: {response.status_code}, "
f"message: {response.message}"

View File

@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
"""
Code block mode wrapper for invoking large language model
"""
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
stop = stop or []
self._transform_chat_json_prompts(
model=model,

View File

@ -5,7 +5,6 @@ import logging
from collections.abc import Generator
from typing import Optional, Union, cast
import google.api_core.exceptions as exceptions
import google.auth.transport.requests
import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream
@ -17,6 +16,7 @@ from anthropic.types import (
MessageStopEvent,
MessageStreamEvent,
)
from google.api_core import exceptions
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image
@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"

View File

@ -96,7 +96,6 @@ class Signer:
signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
return
@staticmethod
def hashed_canonical_request_v4(request, meta):
@ -105,7 +104,7 @@ class Signer:
signed_headers = {}
for key in request.headers:
if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"):
if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"):
signed_headers[key.lower()] = request.headers[key]
if "host" in signed_headers:

View File

@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
"""
Code block mode wrapper for invoking large language model
"""
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
response_format = model_parameters["response_format"]
stop = stop or []
self._transform_json_prompts(

View File

@ -459,8 +459,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if "server_url" not in credentials:
raise CredentialsValidateFailedError("server_url is required in credentials")
if credentials["server_url"].endswith("/"):
credentials["server_url"] = credentials["server_url"][:-1]
credentials["server_url"] = credentials["server_url"].removesuffix("/")
api_key = credentials.get("api_key") or "abc"

View File

@ -50,8 +50,7 @@ class XinferenceRerankModel(RerankModel):
server_url = credentials["server_url"]
model_uid = credentials["model_uid"]
api_key = credentials.get("api_key")
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True}
@ -98,8 +97,7 @@ class XinferenceRerankModel(RerankModel):
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
if credentials["server_url"].endswith("/"):
credentials["server_url"] = credentials["server_url"][:-1]
credentials["server_url"] = credentials["server_url"].removesuffix("/")
# initialize client
client = Client(

View File

@ -45,8 +45,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
if credentials["server_url"].endswith("/"):
credentials["server_url"] = credentials["server_url"][:-1]
credentials["server_url"] = credentials["server_url"].removesuffix("/")
# initialize client
client = Client(
@ -116,8 +115,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
server_url = credentials["server_url"]
model_uid = credentials["model_uid"]
api_key = credentials.get("api_key")
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:

View File

@ -45,8 +45,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
server_url = credentials["server_url"]
model_uid = credentials["model_uid"]
api_key = credentials.get("api_key")
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
@ -118,8 +117,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
if extra_args.max_tokens:
credentials["max_tokens"] = extra_args.max_tokens
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
client = Client(
base_url=server_url,

View File

@ -73,8 +73,7 @@ class XinferenceText2SpeechModel(TTSModel):
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
if credentials["server_url"].endswith("/"):
credentials["server_url"] = credentials["server_url"][:-1]
credentials["server_url"] = credentials["server_url"].removesuffix("/")
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials["server_url"],
@ -189,8 +188,7 @@ class XinferenceText2SpeechModel(TTSModel):
:param voice: model timbre
:return: text translated to audio file
"""
if credentials["server_url"].endswith("/"):
credentials["server_url"] = credentials["server_url"][:-1]
credentials["server_url"] = credentials["server_url"].removesuffix("/")
try:
api_key = credentials.get("api_key")

View File

@ -103,7 +103,7 @@ class XinferenceHelper:
model_handle_type = "embedding"
elif response_json.get("model_type") == "audio":
model_handle_type = "audio"
if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]:
if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}:
model_ability.append("text-to-audio")
else:
model_ability.append("audio-to-text")

View File

@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
new_prompt_messages: list[PromptMessage] = []
for prompt_message in prompt_messages:
copy_prompt_message = prompt_message.copy()
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
if isinstance(copy_prompt_message.content, list):
# check if model is 'glm-4v'
if model not in ("glm-4v", "glm-4v-plus"):
if model not in {"glm-4v", "glm-4v-plus"}:
# not support list message
continue
# get image and
@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
):
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else:
if (
copy_prompt_message.role == PromptMessageRole.USER
or copy_prompt_message.role == PromptMessageRole.TOOL
):
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}:
new_prompt_messages.append(copy_prompt_message)
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else:
new_prompt_messages.append(copy_prompt_message)
if model == "glm-4v" or model == "glm-4v-plus":
if model in {"glm-4v", "glm-4v-plus"}:
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
else:
params = {"model": model, "messages": [], **model_parameters}
@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
# chatglm model
for prompt_message in new_prompt_messages:
# merge system message to user message
if (
prompt_message.role == PromptMessageRole.SYSTEM
or prompt_message.role == PromptMessageRole.TOOL
or prompt_message.role == PromptMessageRole.USER
):
if prompt_message.role in {
PromptMessageRole.SYSTEM,
PromptMessageRole.TOOL,
PromptMessageRole.USER,
}:
if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
params["messages"][-1]["content"] += "\n\n" + prompt_message.content
else:

View File

@ -127,8 +127,7 @@ class SSELineParser:
field, _p, value = line.partition(":")
if value.startswith(" "):
value = value[1:]
value = value.removeprefix(" ")
if field == "data":
self._data.append(value)
elif field == "event":

View File

@ -1,5 +1,4 @@
from __future__ import annotations
from .fine_tuning_job import FineTuningJob as FineTuningJob
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob
from .fine_tuning_job_event import FineTuningJobEvent

View File

@ -75,7 +75,7 @@ class CommonValidator:
if not isinstance(value, str):
raise ValueError(f"Variable {credential_form_schema.variable} should be string")
if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]:
if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
# If the value is in options, no validation is performed
if credential_form_schema.options:
if value not in [option.value for option in credential_form_schema.options]:
@ -83,7 +83,7 @@ class CommonValidator:
if credential_form_schema.type == FormType.SWITCH:
# If the value is not in ['true', 'false'], an exception is thrown
if value.lower() not in ["true", "false"]:
if value.lower() not in {"true", "false"}:
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
value = True if value.lower() == "true" else False

View File

@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector):
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ["http", "https"]:
if parsed_url.scheme in {"http", "https"}:
hosts = f"{config.host}:{config.port}"
else:
hosts = f"http://{config.host}:{config.port}"
@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector):
return uuids
def text_exists(self, id: str) -> bool:
return self._client.exists(index=self._collection_name, id=id).__bool__()
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None:
for id in ids:

View File

@ -35,7 +35,7 @@ class MyScaleVector(BaseVector):
super().__init__(collection_name)
self._config = config
self._metric = metric
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
self._client = get_client(
host=config.host,
port=config.port,
@ -92,7 +92,7 @@ class MyScaleVector(BaseVector):
@staticmethod
def escape_str(value: Any) -> str:
return "".join(" " if c in ("\\", "'") else c for c in str(value))
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")

View File

@ -223,15 +223,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if (
pos == "nr"
or pos == "Ng"
or pos == "eng"
or pos == "nz"
or pos == "n"
or pos == "ORG"
or pos == "v"
): # nr: 人名, ns: 地名, nt: 机构名
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
current_entity += word
else:
if current_entity:

View File

@ -98,17 +98,17 @@ class ExtractProcessor:
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
if etl_type == "Unstructured":
if file_extension == ".xlsx" or file_extension == ".xls":
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in [".md", ".markdown"]:
elif file_extension in {".md", ".markdown"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
if is_automatic
else MarkdownExtractor(file_path, autodetect_encoding=True)
)
elif file_extension in [".htm", ".html"]:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
@ -134,13 +134,13 @@ class ExtractProcessor:
else TextExtractor(file_path, autodetect_encoding=True)
)
else:
if file_extension == ".xlsx" or file_extension == ".xls":
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in [".md", ".markdown"]:
elif file_extension in {".md", ".markdown"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in [".htm", ".html"]:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)

View File

@ -32,7 +32,7 @@ class FirecrawlApp:
else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
elif response.status_code in [402, 409, 500]:
elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
else:

View File

@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor):
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select["name"])
elif type == "rich_text" or type == "title":
elif type in {"rich_text", "title"}:
if len(property_value[type]) > 0:
value = property_value[type][0]["plain_text"]
else:
value = ""
elif type == "select" or type == "status":
elif type in {"select", "status"}:
if property_value[type]:
value = property_value[type]["name"]
else:

View File

@ -115,7 +115,7 @@ class DatasetRetrieval:
available_datasets.append(dataset)
all_documents = []
user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve(
app_id,
@ -426,7 +426,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=top_k,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

Some files were not shown because too many files have changed in this diff Show More