commit 31d87f85b8
Author: Joel
Date: 2024-09-18 11:57:48 +08:00

426 changed files with 4961 additions and 2340 deletions

.gitignore vendored
View File

@ -153,6 +153,9 @@ docker-legacy/volumes/etcd/*
docker-legacy/volumes/minio/*
docker-legacy/volumes/milvus/*
docker-legacy/volumes/chroma/*
docker-legacy/volumes/opensearch/data/*
docker-legacy/volumes/pgvectors/data/*
docker-legacy/volumes/pgvector/data/*
docker/volumes/app/storage/*
docker/volumes/certbot/*
@ -164,6 +167,12 @@ docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
docker/volumes/chroma/*
docker/volumes/opensearch/data/*
docker/volumes/myscale/data/*
docker/volumes/myscale/log/*
docker/volumes/unstructured/*
docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/nginx/conf.d/default.conf
docker/middleware.env

View File

@ -164,7 +164,7 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in ["console", "inner_api"]:
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")
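
A recurring refactor in this commit replaces list literals with set literals in membership tests: `in {...}` gives average O(1) lookups, and for constant literals CPython folds the set into a frozenset constant. A minimal, self-contained sketch (names are illustrative, not from the codebase):

# Membership tests against a set are O(1) on average; against a
# list they scan every element. Semantics are otherwise identical
# for hashable items.
ALLOWED_BLUEPRINTS = {"console", "inner_api"}

def allows(blueprint: str) -> bool:
    return blueprint in ALLOWED_BLUEPRINTS

assert allows("console")
assert not allows("web")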

View File

@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
"Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair():
@ -131,7 +131,7 @@ def reset_encrypt_key_pair():
click.echo(
click.style(
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
"Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in ["knowledge", "all"]:
if scope in {"knowledge", "all"}:
migrate_knowledge_vector_database()
if scope in ["annotation", "all"]:
if scope in {"annotation", "all"}:
migrate_annotation_vector_database()
@ -275,8 +275,7 @@ def migrate_knowledge_vector_database():
for dataset in datasets:
total_count = total_count + 1
click.echo(
f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
)
try:
click.echo("Create dataset vdb index: {}".format(dataset.id))
@ -411,7 +410,8 @@ def migrate_knowledge_vector_database():
try:
click.echo(
click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
f"Start to created vector index with {len(documents)} documents of {segments_count}"
f" segments for dataset {dataset.id}.",
fg="green",
)
)
@ -593,7 +593,7 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
click.echo(
click.style(
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
"Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
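
Most of the string edits in this file remove Python's implicit literal concatenation, where adjacent literals are joined at compile time; the old form hid double spaces and stray lowercase letters at the seam. A standalone illustration:

# Adjacent string literals are concatenated at compile time.
broken = "Congratulations! " "the key pair has been reset."
assert broken == "Congratulations! the key pair has been reset."

# A single literal is clearer and survives re-wrapping by formatters.
fixed = "Congratulations! The key pair has been reset."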

View File

@ -129,12 +129,12 @@ class EndpointConfig(BaseSettings):
)
SERVICE_API_URL: str = Field(
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
description="Service API Url prefix. used to display Service API Base Url to the front-end.",
default="",
)
APP_WEB_URL: str = Field(
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
default="",
)
@ -272,7 +272,7 @@ class LoggingConfig(BaseSettings):
"""
LOG_LEVEL: str = Field(
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
default="INFO",
)
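
For context, these configuration classes follow the pydantic-settings pattern, where each Field's default can be overridden by an environment variable of the same name. A minimal sketch, assuming pydantic v2 with the pydantic-settings package (class name is illustrative):

from pydantic import Field
from pydantic_settings import BaseSettings

class DemoEndpointConfig(BaseSettings):
    SERVICE_API_URL: str = Field(
        description="Service API URL prefix shown to the front-end.",
        default="",
    )

config = DemoEndpointConfig()  # reads SERVICE_API_URL from the environment if set
print(config.SERVICE_API_URL)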

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.8.0",
default="0.8.2",
)
COMMIT_SHA: str = Field(

View File

@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
site = app.site
if not site:
desc = args["desc"] if args["desc"] else ""
copy_right = args["copyright"] if args["copyright"] else ""
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
else:
desc = site.description if site.description else args["desc"] if args["desc"] else ""
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
privacy_policy = (
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
)
custom_disclaimer = (
site.custom_disclaimer
if site.custom_disclaimer
else args["custom_disclaimer"]
if args["custom_disclaimer"]
else ""
)
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
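
These edits collapse nested conditional expressions into `or` chains: `a or b` yields `a` when it is truthy and `b` otherwise, so both None and "" fall through to the next candidate. A sketch of the equivalence with placeholder values:

site_description = None
args = {"desc": ""}

# Before: nested conditional expressions
desc_old = site_description if site_description else args["desc"] if args["desc"] else ""
# After: an equivalent, flatter fallback chain
desc_new = site_description or args["desc"] or ""
assert desc_old == desc_new == ""

One caveat: `or` treats every falsy value (0, "", None) the same way, which is the intended behavior here since empty strings should also trigger the fallback.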

View File

@ -94,19 +94,15 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

View File

@ -29,10 +29,13 @@ class DailyMessageStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
FROM messages where app_id = :app_id
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(*) AS message_count
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -45,7 +48,7 @@ class DailyMessageStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -55,10 +58,10 @@ class DailyMessageStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
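
Every statistics endpoint in this file follows the same shape: a base SELECT with named bind parameters, optional AND clauses appended for the date range, then GROUP BY date ORDER BY date. A condensed sketch of that pattern, assuming SQLAlchemy's text() API (the session object is hypothetical):

from sqlalchemy import text

sql_query = """SELECT
    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
    COUNT(*) AS message_count
FROM
    messages
WHERE
    app_id = :app_id"""
arg_dict = {"tz": "UTC", "app_id": "app-123"}

# Range filters are appended with the same bind-parameter style,
# never interpolated, so user input cannot inject SQL.
sql_query += " AND created_at >= :start"
arg_dict["start"] = "2024-09-01 00:00:00"

sql_query += " GROUP BY date ORDER BY date"
# rows = session.execute(text(sql_query), arg_dict)  # hypothetical session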
@ -83,10 +86,13 @@ class DailyConversationStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
FROM messages where app_id = :app_id
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT messages.conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -99,7 +105,7 @@ class DailyConversationStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -109,10 +115,10 @@ class DailyConversationStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -137,10 +143,13 @@ class DailyTerminalsStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
FROM messages where app_id = :app_id
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -153,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -163,10 +172,10 @@ class DailyTerminalsStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -191,12 +200,14 @@ class DailyTokenCostStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
sum(total_price) as total_price
FROM messages where app_id = :app_id
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -209,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -219,10 +230,10 @@ class DailyTokenCostStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -249,12 +260,22 @@ class AverageSessionInteractionStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(subquery.message_count) AS interactions
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
FROM conversations c
JOIN messages m ON c.id = m.conversation_id
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(subquery.message_count) AS interactions
FROM
(
SELECT
m.conversation_id,
COUNT(m.id) AS message_count
FROM
conversations c
JOIN
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -267,7 +288,7 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and c.created_at >= :start"
sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -277,14 +298,19 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and c.created_at < :end"
sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += """
GROUP BY m.conversation_id) subquery
LEFT JOIN conversations c on c.id=subquery.conversation_id
GROUP BY date
ORDER BY date"""
GROUP BY m.conversation_id
) subquery
LEFT JOIN
conversations c
ON c.id = subquery.conversation_id
GROUP BY
date
ORDER BY
date"""
response_data = []
@ -311,13 +337,17 @@ class UserSatisfactionRateStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
FROM messages m
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
WHERE m.app_id = :app_id
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count
FROM
messages m
LEFT JOIN
message_feedbacks mf
ON mf.message_id=m.id AND mf.rating='like'
WHERE
m.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -330,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and m.created_at >= :start"
sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -340,10 +370,10 @@ class UserSatisfactionRateStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and m.created_at < :end"
sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -373,12 +403,13 @@ class AverageResponseTimeStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) as latency
FROM messages
WHERE app_id = :app_id
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) AS latency
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -391,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -401,10 +432,10 @@ class AverageResponseTimeStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -429,13 +460,16 @@ class TokensPerSecondStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
END as tokens_per_second
FROM messages
WHERE app_id = :app_id"""
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -448,7 +482,7 @@ WHERE app_id = :app_id"""
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -458,10 +492,10 @@ WHERE app_id = :app_id"""
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []

View File

@ -465,6 +465,6 @@ api.add_resource(
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

View File

@ -30,12 +30,14 @@ class WorkflowDailyRunsStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(id) AS runs
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -52,7 +54,7 @@ class WorkflowDailyRunsStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -62,10 +64,10 @@ class WorkflowDailyRunsStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -90,12 +92,14 @@ class WorkflowDailyTerminalsStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -112,7 +116,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -122,10 +126,10 @@ class WorkflowDailyTerminalsStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -150,14 +154,14 @@ class WorkflowDailyTokenCostStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) as token_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) AS token_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -174,7 +178,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -184,10 +188,10 @@ class WorkflowDailyTokenCostStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += " GROUP BY date ORDER BY date"
response_data = []
@ -217,23 +221,27 @@ class WorkflowAverageAppInteractionStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT
AVG(sub.interactions) as interactions,
sub.date
FROM
(SELECT
date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM workflow_runs c
WHERE c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY date, c.created_by) sub
GROUP BY sub.date
"""
sql_query = """SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM
(
SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM
workflow_runs c
WHERE
c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -262,7 +270,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")
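
This endpoint differs from the others: its date filters live inside the subquery, so they cannot simply be appended at the end. The query instead carries {{start}} and {{end}} placeholders that are replaced with a bound filter, or with an empty string when no bound is supplied. A reduced sketch:

template = "SELECT ... WHERE c.app_id = :app_id {{start}} {{end}} GROUP BY date, c.created_by"

start_supplied, end_supplied = True, False
sql = template.replace("{{start}}", " AND c.created_at >= :start" if start_supplied else "")
sql = sql.replace("{{end}}", " AND c.created_at < :end" if end_supplied else "")
assert "{{" not in sql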

View File

@ -71,7 +71,7 @@ class OAuthCallback(Resource):
account = _generate_account(provider, user_info)
# Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value:
@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account:
# Create account
account_name = user_info.name if user_info.name else "Dify"
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)

View File

@ -399,7 +399,7 @@ class DatasetIndexingEstimateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {
"api_base_url": (
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
)
+ "/v1"
}
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
class DatasetRetrievalSettingApi(Resource):

View File

@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in ["completed", "error"]:
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
info_list = []
extract_settings = []
for document in documents:
if document.indexing_status in ["completed", "error"]:
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
# format document files info
@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
elif action == "resume":
if document.indexing_status not in ["paused", "error"]:
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None

View File

@ -18,9 +18,7 @@ class NotSetupError(BaseHTTPException):
class NotInitValidateError(BaseHTTPException):
error_code = "not_init_validated"
description = (
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
)
description = "Init validation has not been completed yet. Please proceed with the init validation process first."
code = 401

View File

@ -81,19 +81,15 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

View File

@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
def get(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"],
"editable": current_user.role in {"owner", "admin"},
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
}
for installed_app in installed_apps

View File

@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)

View File

@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters."""
app_model = installed_app.app
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
)
api.add_resource(

View File

@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args["provider_name"] if args["provider_name"] else "",
args["provider_name"] or "",
args["tool_name"],
args["credentials"],
args["parameters"],

View File

@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
raise TooManyFilesError()
extension = file.filename.split(".")[-1]
if extension.lower() not in ["svg", "png"]:
if extension.lower() not in {"svg", "png"}:
raise UnsupportedFileTypeError()
try:

View File

@ -64,7 +64,8 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
# The file upload API is used in multiple places,
# so we need to check whether the request comes from datasets
source = request.args.get("source")
if source == "datasets":
abort(403, "The number of documents has reached the limit of your subscription.")

View File

@ -42,7 +42,7 @@ class AppParameterApi(Resource):
@marshal_with(parameters_fields)
def get(self, app_model: App):
"""Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -79,19 +79,15 @@ class TextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(

View File

@ -96,7 +96,7 @@ class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationApi(Resource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
@marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
@marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -76,7 +76,7 @@ class MessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:

View File

@ -1,6 +1,7 @@
import logging
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@ -22,10 +23,12 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from services.app_generate_service import AppGenerateService
from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
@ -113,6 +116,30 @@ class WorkflowTaskStopApi(Resource):
return {"result": "success"}
class WorkflowAppLogApi(Resource):
@validate_app_token
@marshal_with(workflow_app_log_pagination_fields)
def get(self, app_model: App):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
return workflow_app_log_pagination
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
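
The new /workflows/logs endpoint accepts keyword, status, page and limit as query arguments, with status restricted to succeeded, failed or stopped. A hypothetical client call (base URL and token are placeholders):

import requests

resp = requests.get(
    "https://api.example.com/v1/workflows/logs",         # placeholder base URL
    headers={"Authorization": "Bearer app-your-token"},  # placeholder API token
    params={"status": "succeeded", "page": 1, "limit": 20},
)
print(resp.json())  # paginated workflow app logs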

View File

@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -78,19 +78,15 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None

View File

@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
def post(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotCompletionAppError()
message_id = str(message_id)

View File

@ -80,7 +80,8 @@ def _validate_web_sso_token(decoded, system_features, app_code):
if not source or source != "sso":
raise WebSSOAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
# Check if SSO is not enforced for web, and if the token source is SSO,
# raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
source = decoded.get("token_source")
if source and source == "sso":

View File

@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),

View File

@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),

View File

@ -90,7 +90,7 @@ class CotAgentOutputParser:
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ["\n", " ", ""]:
if last_character not in {"\n", " ", ""}:
index += steps
yield delta
continue
@ -117,7 +117,7 @@ class CotAgentOutputParser:
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ["\n", " ", ""]:
if last_character not in {"\n", " ", ""}:
index += steps
yield delta
continue

View File

@ -41,7 +41,8 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
{{historic_messages}}
Question: {{query}}
{{agent_scratchpad}}
Thought:"""
Thought:""" # noqa: E501
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""
@ -86,7 +87,8 @@ Action:
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
"""
""" # noqa: E501
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

View File

@ -29,7 +29,7 @@ class BaseAppConfigManager:
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
)
additional_features.opening_statement, additional_features.suggested_questions = (

View File

@ -18,7 +18,7 @@ class AgentConfigManager:
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == "cot" or agent_strategy == "react":
elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
@ -43,10 +43,10 @@ class AgentConfigManager:
agent_tools.append(AgentToolEntity(**agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",
"router",
]:
}:
agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode
model_mode = config.get("model", {}).get("mode", "completion")

View File

@ -167,7 +167,7 @@ class DatasetConfigManager:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0]
if key == "dataset":

View File

@ -86,7 +86,7 @@ class PromptTemplateConfigManager:
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
)
model_mode_vals = [mode.value for mode in ModelMode]

View File

@ -42,12 +42,12 @@ class BasicVariablesConfigManager:
variable=variable["variable"], type=variable["type"], config=variable["config"]
)
)
elif variable_type in [
elif variable_type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]:
}:
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
@ -97,7 +97,7 @@ class BasicVariablesConfigManager:
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key]
@ -115,7 +115,7 @@ class BasicVariablesConfigManager:
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number")
raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
variables.append(form_item["variable"])
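
For reference, the variable-name pattern above allows CJK characters, ASCII letters, digits, underscores and two emoji ranges, up to 100 characters, and the lookahead rejects names starting with a digit. A quick, self-contained check:

import re

pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")

assert pattern.match("user_name")
assert pattern.match("用户名")
assert pattern.match("1name") is None  # cannot start with a digit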

View File

@ -92,7 +92,7 @@ class VariableEntityType(str, Enum):
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"
EXTERNAL_DATA_TOOL = "external_data_tool"
class VariableEntity(BaseModel):
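
Note that this is a value change, not a rename: the member's string switches from "external-data-tool" to "external_data_tool", matching the user_input_form key checked elsewhere in this commit. Because the enum subclasses str, members compare equal to their raw values:

from enum import Enum

class VariableEntityType(str, Enum):
    EXTERNAL_DATA_TOOL = "external_data_tool"

assert VariableEntityType.EXTERNAL_DATA_TOOL == "external_data_tool"
assert VariableEntityType.EXTERNAL_DATA_TOOL != "external-data-tool"  # old value no longer matches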

View File

@ -54,14 +54,14 @@ class FileUploadConfigManager:
if is_vision:
detail = config["file_upload"]["image"]["detail"]
if detail not in ["high", "low"]:
if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in ["remote_url", "local_file"]:
if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"]

View File

@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
raise ValueError("Workflow not initialized")
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,

View File

@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:

View File

@ -22,11 +22,11 @@ class BaseAppGenerator:
return var.default or ""
if (
var.type
in (
in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
)
}
and user_input_value
and not isinstance(user_input_value, str)
):
@ -44,7 +44,7 @@ class BaseAppGenerator:
options = var.options or []
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")

View File

@ -32,7 +32,7 @@ class AppQueueManager:
self._user_id = user_id
self._invoke_from = invoke_from
user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
@ -118,7 +118,7 @@ class AppQueueManager:
if result is None:
return
user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return
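
For context, the user_prefix computed above is written into a short-lived Redis key when the task starts and checked again before honoring a stop request, so only the same account or end user can stop a task. A compact sketch of that handshake, assuming a redis-py client (the key name is illustrative; the real one comes from a private helper):

import redis

r = redis.Redis()  # assumed local instance
task_id, user_id = "task-1", "user-1"
user_prefix = "account"  # or "end-user", depending on invoke_from

# On task start: record the owner for 30 minutes.
r.setex(f"generate_task_belong:{task_id}", 1800, f"{user_prefix}-{user_id}")

# On stop request: honor it only if the caller matches the owner.
stored = r.get(f"generate_task_belong:{task_id}")
if stored is not None and stored.decode("utf-8") == f"{user_prefix}-{user_id}":
    pass  # caller owns the task and may set the stop flag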

View File

@ -161,7 +161,7 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else "",
query=query or "",
files=files,
context=context,
memory=memory,
@ -189,7 +189,7 @@ class AppRunner:
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=query if query else "",
query=query or "",
files=files,
context=context,
memory_config=memory_config,
@ -238,7 +238,7 @@ class AppRunner:
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage(),
usage=usage or LLMUsage.empty_usage(),
),
),
PublishFrom.APPLICATION_MANAGER,
@ -351,7 +351,7 @@ class AppRunner:
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=query if query else "",
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)
@ -379,7 +379,7 @@ class AppRunner:
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.",
text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream,
)
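Rewriting `query if query else ""` as `query or ""` is behavior-preserving here, though it is worth noting that `or` falls back for every falsy value (None, "", 0), not only None; that is exactly the intent at these call sites. A quick sketch with hypothetical inputs:

for query in (None, "", "hello"):
    print(repr(query or ""))  # -> '', '', 'hello'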

View File

@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
# get from source
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
from_source = "api"
end_user_id = application_generate_entity.user_id
else:
@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model
override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
]:
}:
override_model_configs = app_config.app_model_config_dict
# get conversation introduction

View File

@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,

View File

@ -84,10 +84,12 @@ class WorkflowLoggingCallback(WorkflowCallback):
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="green",
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="green",
)
self.print_text(
@ -114,14 +116,17 @@ class WorkflowLoggingCallback(WorkflowCallback):
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="red",
)
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:

View File

@ -63,7 +63,7 @@ class AnnotationReplyFeature:
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
from_source = "api"
else:
from_source = "console"

View File

@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)

View File

@ -383,7 +383,7 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
response = NodeStartStreamResponse(
@ -430,7 +430,7 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
return NodeFinishStreamResponse(

View File

@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler:
source="app",
source_app_id=self._app_id,
created_by_role=(
"account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
),
created_by=self._user_id,
)

View File

@ -65,7 +65,7 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except Exception as e:
logging.exception("Failed transform embedding: ", e)
logging.exception("Failed transform embedding: %s", e)
cache_embeddings = []
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
@ -85,7 +85,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.error("Failed to embed documents: ", ex)
logger.error("Failed to embed documents: %s", ex)
raise ex
return text_embeddings
@ -116,10 +116,7 @@ class CacheEmbedding(Embeddings):
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except IntegrityError:
db.session.rollback()
except:
logging.exception("Failed to add embedding to redis")
except Exception as ex:
logging.exception("Failed to add embedding to redis %s", ex)
return embedding_results
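The corrected calls pass the exception as a printf-style %s argument: the logging module fills the placeholder lazily, whereas the earlier form (a bare trailing argument with no placeholder) made the logging machinery report an internal formatting error instead of the exception. A minimal, self-contained sketch:

import logging

logger = logging.getLogger(__name__)

try:
    1 / 0
except ZeroDivisionError as e:
    # the placeholder is formatted only if the record is actually emitted
    logger.error("Failed to embed documents: %s", e)
    # logging.exception additionally appends the active traceback
    logging.exception("Failed transform embedding: %s", e)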

View File

@ -3,6 +3,7 @@ import importlib.util
import json
import logging
import os
from pathlib import Path
from typing import Any, Optional
from pydantic import BaseModel
@ -63,8 +64,7 @@ class Extensible:
builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding="utf-8") as f:
position = int(f.read().strip())
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position
if (extension_name + ".py") not in file_names:
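Path.read_text opens, reads, and closes in a single call, which is what lets the with/open pair collapse to one line. A sketch against a hypothetical fixture file:

from pathlib import Path

Path("__builtin__").write_text("3\n", encoding="utf-8")  # hypothetical fixture
position = int(Path("__builtin__").read_text(encoding="utf-8").strip())
assert position == 3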

View File

@ -188,7 +188,8 @@ class MessageFileParser:
def _check_image_remote_url(self, url):
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
def is_s3_presigned_url(url):

View File

@ -89,7 +89,8 @@ class CodeExecutor:
raise CodeExecutionError("Code execution service is unavailable")
elif response.status_code != 200:
raise Exception(
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
f"Failed to execute code, got status code {response.status_code},"
f" please check if the sandbox service is running"
)
except CodeExecutionError as e:
raise e

View File

@ -14,7 +14,10 @@ class ToolParameterCache:
def __init__(
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
self.cache_key = (
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f":identity_id:{identity_id}"
)
def get(self) -> Optional[dict]:
"""

View File

@ -292,7 +292,7 @@ class IndexingRunner:
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
) -> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
return []
data_source_info = dataset_document.data_source_info_dict

View File

@ -59,24 +59,27 @@ User Input: yo, 你今天咋样?
}
User Input:
"""
""" # noqa: E501
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response"
"(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
"The output must be an array in JSON format following the specified schema:\n"
'["question1","question2","question3"]\n'
)
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step."
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."
"Step 1: Understand and summarize the main content of this text.\n"
"Step 2: What key information or concepts are mentioned in this text?\n"
"Step 3: Decompose or combine multiple pieces of information and concepts.\n"
"Step 4: Generate questions and answers based on these key information and concepts.\n"
"<Constraints> The questions should be clear and detailed, and the answers should be detailed and complete. "
"You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n"
"You must answer in {language}, in a style that is clear and detailed in {language}."
" No language other than {language} should be used. \n"
"<Format> Use the following format: Q1:\nA1:\nQ2:\nA2:...\n"
"<QA Pairs>"
)
@ -94,7 +97,7 @@ Based on task description, please create a well-structured prompt template that
- Use the same language as task description.
- Output in ``` xml ``` and start with <instruction>
Please generate the full prompt template with at least 300 words and output only the prompt template.
"""
""" # noqa: E501
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
Here is a task description for which I would like you to create a high-quality prompt template for:
@ -109,7 +112,7 @@ Based on task description, please create a well-structured prompt template that
- Use the same language as task description.
- Output in ``` xml ``` and start with <instruction>
Please generate the full prompt template and output only the prompt template.
"""
""" # noqa: E501
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
@ -134,7 +137,7 @@ Inside <text></text> XML tags, there is a text that I should extract parameters
### Answer
I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text.
"""
""" # noqa: E501
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """
<instruction>
@ -150,4 +153,4 @@ Welcome! I'm here to assist you with any questions or issues you might have with
Here is the task description: {{INPUT_TEXT}}
You just need to generate the output
"""
""" # noqa: E501

View File

@ -39,7 +39,7 @@ class TokenBufferMemory:
)
if message_limit and message_limit > 0:
message_limit = message_limit if message_limit <= 500 else 500
message_limit = min(message_limit, 500)
else:
message_limit = 500
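min(message_limit, 500) caps the limit in a single expression and is equivalent to the conditional it replaces, as a quick check over hypothetical values shows:

for message_limit in (10, 500, 9999):
    assert min(message_limit, 500) == (message_limit if message_limit <= 500 else 500)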
@ -52,7 +52,7 @@ class TokenBufferMemory:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else:
if message.workflow_run_id:

View File

@ -8,8 +8,11 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "float",
"help": {
"en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
"en_US": "Controls randomness. Lower temperature results in less random completions."
" As the temperature approaches zero, the model will become deterministic and repetitive."
" Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。"
"较高的温度会导致更多的随机完成。",
},
"required": False,
"default": 0.0,
@ -24,7 +27,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "float",
"help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
" are considered.",
"zh_Hans": "通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。",
},
"required": False,
@ -88,7 +92,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "int",
"help": {
"en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
"en_US": "Specifies the upper limit on the length of generated results."
" If the generated results are truncated, you can increase this parameter.",
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
},
"required": False,
@ -104,7 +109,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "string",
"help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
" such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式确保llm的输出尽可能是有效的代码块如JSON、XML等",
},
"required": False,

View File

@ -27,17 +27,17 @@ class ModelType(Enum):
:return: model type
"""
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
if origin_model_type in {"text-generation", cls.LLM.value}:
return cls.LLM
elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
return cls.TEXT_EMBEDDING
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
elif origin_model_type in {"reranking", cls.RERANK.value}:
return cls.RERANK
elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
return cls.SPEECH2TEXT
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
elif origin_model_type in {"tts", cls.TTS.value}:
return cls.TTS
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION

View File

@ -72,7 +72,9 @@ class AIModel(ABC):
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return invoke_error(
description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. "
description=(
f"[{provider_name}] Incorrect model credentials provided, please check and try again."
)
)
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
@ -198,7 +200,7 @@ class AIModel(ABC):
except Exception as e:
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
raise Exception(
f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:" f" {str(e)}"
f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}: {str(e)}"
)
# cache model schema

View File

@ -187,7 +187,7 @@ if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
""" # noqa: E501
code_block = model_parameters.get("response_format", "")
if not code_block:
@ -449,7 +449,7 @@ if you are not sure about the structure.
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,
@ -830,7 +830,8 @@ if you are not sure about the structure.
else:
if parameter_value != round(parameter_value, parameter_rule.precision):
raise ValueError(
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places."
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}"
f" decimal places."
)
# validate parameter value range

View File

@ -51,7 +51,7 @@ if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
""" # noqa: E501
class AnthropicLargeLanguageModel(LargeLanguageModel):
@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
),
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ""
chunk_text = chunk.delta.text or ""
full_assistant_content += chunk_text
# transform assistant message to prompt message
@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"

View File

@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,

View File

@ -16,6 +16,15 @@ from core.model_runtime.entities.model_entities import (
AZURE_OPENAI_API_VERSION = "2024-02-15-preview"
AZURE_DEFAULT_PARAM_SEED_HELP = I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,"
"您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically,"
" such that repeated requests with the same seed and parameters should return the same result."
" Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter"
" to monitor changes in the backend.",
)
def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
rule = ParameterRule(
@ -229,10 +238,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -297,10 +303,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -365,10 +368,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -433,10 +433,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -502,10 +499,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -571,10 +565,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -650,10 +641,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -719,10 +707,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -788,10 +773,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -867,10 +849,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -936,10 +915,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -1000,10 +976,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,

View File

@ -53,6 +53,12 @@ model_credential_schema:
type: select
required: true
options:
- label:
en_US: 2024-08-01-preview
value: 2024-08-01-preview
- label:
en_US: 2024-07-01-preview
value: 2024-07-01-preview
- label:
en_US: 2024-05-01-preview
value: 2024-05-01-preview

View File

@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue
# transform assistant message to prompt message
text = delta.text if delta.text else ""
text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
)
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
full_assistant_content += delta.delta.content if delta.delta.content else ""
full_assistant_content += delta.delta.content or ""
real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content if delta.delta.content else ""
completion += delta.delta.content or ""
yield LLMResultChunk(
model=real_model,

View File

@ -84,15 +84,15 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
)
for i in range(len(sentences))
]
for index, future in enumerate(futures):
yield from future.result().__enter__().iter_bytes(1024)
for future in futures:
yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
else:
response = client.audio.speech.with_streaming_response.create(
model=model, voice=voice, response_format="mp3", input=content_text.strip()
)
yield from response.__enter__().iter_bytes(1024)
yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
except Exception as ex:
raise InvokeBadRequestError(str(ex))

View File

@ -33,7 +33,7 @@ parameter_rules:
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式

View File

@ -33,7 +33,7 @@ parameter_rules:
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式

View File

@ -33,7 +33,7 @@ parameter_rules:
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式

View File

@ -15,6 +15,7 @@ class BaichuanTokenizer:
@classmethod
def _get_num_tokens(cls, text: str) -> int:
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
# tokens = number of Chinese characters + number of English words * 1.3
# (for estimation only, subject to actual return)
# https://platform.baichuan-ai.com/docs/text-Embedding
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
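The comment describes a heuristic: each Chinese character counts as one token and each English word as roughly 1.3. A hypothetical re-implementation of the two counters, for illustration only (the real helpers live on the class):

import re

def estimate_tokens(text: str) -> int:
    chinese = len(re.findall(r"[\u4e00-\u9fff]", text))  # CJK chars, one token each
    english = len(re.findall(r"[a-zA-Z]+", text))        # word runs, weighted 1.3
    return int(chinese + english * 1.3)

print(estimate_tokens("你好 world"))  # 2 + 1 * 1.3 -> 3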

View File

@ -0,0 +1,59 @@
model: eu.anthropic.claude-3-haiku-20240307-v1:0
label:
en_US: Claude 3 Haiku(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00025'
output: '0.00125'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,58 @@
model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0
label:
en_US: Claude 3.5 Sonnet(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,58 @@
model: eu.anthropic.claude-3-sonnet-20240229-v1:0
label:
en_US: Claude 3 Sonnet(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -1,8 +1,8 @@
# standard import
import base64
import io
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
@ -17,7 +17,6 @@ from botocore.exceptions import (
ServiceNotInRegionError,
UnknownServiceError,
)
from PIL.Image import Image
# local import
from core.model_runtime.callbacks.base_callback import Callback
@ -52,7 +51,7 @@ if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
""" # noqa: E501
class BedrockLargeLanguageModel(LargeLanguageModel):
@ -61,6 +60,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
CONVERSE_API_ENABLED_MODEL_INFO = [
{"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
@ -331,10 +332,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif "contentBlockDelta" in chunk:
delta = chunk["contentBlockDelta"]["delta"]
if "text" in delta:
chunk_text = delta["text"] if delta["text"] else ""
chunk_text = delta["text"] or ""
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else "",
content=chunk_text or "",
)
index = chunk["contentBlockDelta"]["contentBlockIndex"]
yield LLMResultChunk(
@ -441,8 +442,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
try:
url = message_content.data
image_content = requests.get(url).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
@ -452,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"
@ -541,7 +543,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"max_tokens": 32,
}
elif "ai21" in model:
# ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again.
# ValidationException: Malformed input request: #/temperature: expected type: Number,
# found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null,
# please reformat your input and try again.
required_params = {
"temperature": 0.7,
"topP": 0.9,
@ -749,7 +753,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif model_prefix == "cohere":
output = response_body.get("generations")[0].get("text")
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
completion_tokens = self.get_num_tokens(model, credentials, output or "")
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@ -826,7 +830,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=content_delta if content_delta else "",
content=content_delta or "",
)
index += 1
@ -882,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
return InvokeRateLimitError(error_msg)
elif error_code in [
elif error_code in {
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
]:
}:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)

View File

@ -0,0 +1,59 @@
model: us.anthropic.claude-3-haiku-20240307-v1:0
label:
en_US: Claude 3 Haiku(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00025'
output: '0.00125'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,59 @@
model: us.anthropic.claude-3-opus-20240229-v1:0
label:
en_US: Claude 3 Opus(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.015'
output: '0.075'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,58 @@
model: us.anthropic.claude-3-5-sonnet-20240620-v1:0
label:
en_US: Claude 3.5 Sonnet(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,58 @@
model: us.anthropic.claude-3-sonnet-20240229-v1:0
label:
en_US: Claude 3 Sonnet(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
return InvokeRateLimitError(error_msg)
elif error_code in [
elif error_code in {
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
]:
}:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)

View File

@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
if delta.delta.function_call:
function_calls = [delta.delta.function_call]
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
)
if delta.finish_reason is not None:

View File

@ -621,7 +621,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
desc = p_val["description"]
if "enum" in p_val:
desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]"
desc += f"; Only accepts one of the following predefined options: [{', '.join(p_val['enum'])}]"
parameter_definitions[p_key] = ToolParameterDefinitionsValue(
description=desc, type=p_val["type"], required=required

View File

@ -62,7 +62,7 @@ parameter_rules:
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式

View File

@ -6,10 +6,10 @@ from collections.abc import Generator
from typing import Optional, Union, cast
import google.ai.generativelanguage as glm
import google.api_core.exceptions as exceptions
import google.generativeai as genai
import google.generativeai.client as client
import requests
from google.api_core import exceptions
from google.generativeai import client
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types.content_types import to_part
from PIL import Image
@ -45,7 +45,7 @@ if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
""" # noqa: E501
class GoogleLargeLanguageModel(LargeLanguageModel):

View File

@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
if "huggingfacehub_api_type" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
if "huggingfacehub_api_token" not in credentials:
@ -94,9 +94,9 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
credentials["huggingfacehub_api_token"], model
)
if credentials["task_type"] not in ("text2text-generation", "text-generation"):
if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
raise CredentialsValidateFailedError(
"Huggingface Hub Task Type must be one of text2text-generation, " "text-generation."
"Huggingface Hub Task Type must be one of text2text-generation, text-generation."
)
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
@ -282,7 +282,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
valid_tasks = ("text2text-generation", "text-generation")
if model_info.pipeline_tag not in valid_tasks:
raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.")
raise ValueError(f"Model {model_name} is not a valid task, must be one of {valid_tasks}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{str(e)}")

View File

@ -121,7 +121,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
valid_tasks = "feature-extraction"
if model_info.pipeline_tag not in valid_tasks:
raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.")
raise ValueError(f"Model {model_name} is not a valid task, must be one of {valid_tasks}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{str(e)}")

View File

@ -49,8 +49,7 @@ class HuggingfaceTeiRerankModel(RerankModel):
return RerankResult(model=model, docs=[])
server_url = credentials["server_url"]
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
try:
results = TeiHelper.invoke_rerank(server_url, query, docs)
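str.removesuffix (Python 3.9+) drops the trailing slash only when it is present, replacing the endswith/slice pair with one call; a sketch with hypothetical URLs:

for server_url in ("http://tei:8080/", "http://tei:8080"):
    print(server_url.removesuffix("/"))  # both print http://tei:8080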

View File

@ -54,7 +54,8 @@ class TeiHelper:
url = str(URL(server_url) / "info")
# this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
# this method is surrounded by a lock, and default requests may hang forever,
# so we just set an Adapter with max_retries=3
session = Session()
session.mount("http://", HTTPAdapter(max_retries=3))
session.mount("https://", HTTPAdapter(max_retries=3))
@ -74,7 +75,7 @@ class TeiHelper:
if len(model_type.keys()) < 1:
raise RuntimeError("model_type is empty")
model_type = list(model_type.keys())[0]
if model_type not in ["embedding", "reranker"]:
if model_type not in {"embedding", "reranker"}:
raise RuntimeError(f"invalid model_type: {model_type}")
max_input_length = response_json.get("max_input_length", 512)

View File

@ -42,8 +42,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
"""
server_url = credentials["server_url"]
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
# get model properties
context_size = self._get_context_size(model, credentials)
@ -119,8 +118,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
num_tokens = 0
server_url = credentials["server_url"]
if server_url.endswith("/"):
server_url = server_url[:-1]
server_url = server_url.removesuffix("/")
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
num_tokens = sum(len(tokens) for tokens in batch_tokens)

View File

@ -2,3 +2,4 @@
- hunyuan-standard
- hunyuan-standard-256k
- hunyuan-pro
- hunyuan-turbo

Some files were not shown because too many files have changed in this diff.