diff --git a/api/app.py b/api/app.py index ad219ca0d6..91a49337fc 100644 --- a/api/app.py +++ b/api/app.py @@ -164,7 +164,7 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ["console", "inner_api"]: + if request.blueprint not in {"console", "inner_api"}: return None # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get("Authorization", "") diff --git a/api/commands.py b/api/commands.py index 3bf8bc0ecc..3a6b4963cf 100644 --- a/api/commands.py +++ b/api/commands.py @@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm): ) @click.confirmation_option( prompt=click.style( - "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red" + "Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red" ) ) def reset_encrypt_key_pair(): @@ -131,7 +131,7 @@ def reset_encrypt_key_pair(): click.echo( click.style( - "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), + "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id), fg="green", ) ) @@ -140,9 +140,9 @@ def reset_encrypt_key_pair(): @click.command("vdb-migrate", help="migrate vector db.") @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ["knowledge", "all"]: + if scope in {"knowledge", "all"}: migrate_knowledge_vector_database() - if scope in ["annotation", "all"]: + if scope in {"annotation", "all"}: migrate_annotation_vector_database() @@ -275,8 +275,7 @@ def migrate_knowledge_vector_database(): for dataset in datasets: total_count = total_count + 1 click.echo( - f"Processing the {total_count} dataset {dataset.id}. " - + f"{create_count} created, {skipped_count} skipped." + f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." ) try: click.echo("Create dataset vdb index: {}".format(dataset.id)) @@ -411,7 +410,8 @@ def migrate_knowledge_vector_database(): try: click.echo( click.style( - f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", + f"Start to created vector index with {len(documents)} documents of {segments_count}" + f" segments for dataset {dataset.id}.", fg="green", ) ) @@ -593,7 +593,7 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str click.echo( click.style( - "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password), + "Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password), fg="green", ) ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 903f66a53d..4e1dfe73ad 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -110,6 +110,7 @@ class CodeExecutionSandboxConfig(BaseSettings): default=1000, ) + class PluginConfig(BaseSettings): """ Plugin configs @@ -124,6 +125,7 @@ class PluginConfig(BaseSettings): default='dify-inner-api-key', ) + class EndpointConfig(BaseSettings): """ Module URL configs @@ -142,12 +144,12 @@ class EndpointConfig(BaseSettings): ) SERVICE_API_URL: str = Field( - description="Service API Url prefix." "used to display Service API Base Url to the front-end.", + description="Service API Url prefix. used to display Service API Base Url to the front-end.", default="", ) APP_WEB_URL: str = Field( - description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.", + description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.", default="", ) @@ -285,7 +287,7 @@ class LoggingConfig(BaseSettings): """ LOG_LEVEL: str = Field( - description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.", + description="Log output level, default to INFO. It is recommended to set it to ERROR for production.", default="INFO", ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index e03dfeb27c..3815a6fca2 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.8.0", + default="0.8.2", ) COMMIT_SHA: str = Field( diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a4ceec2662..f78ea9b288 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource): site = app.site if not site: - desc = args["desc"] if args["desc"] else "" - copy_right = args["copyright"] if args["copyright"] else "" - privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" - custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" + desc = args["desc"] or "" + copy_right = args["copyright"] or "" + privacy_policy = args["privacy_policy"] or "" + custom_disclaimer = args["custom_disclaimer"] or "" else: - desc = site.description if site.description else args["desc"] if args["desc"] else "" - copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" - privacy_policy = ( - site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" - ) - custom_disclaimer = ( - site.custom_disclaimer - if site.custom_disclaimer - else args["custom_disclaimer"] - if args["custom_disclaimer"] - else "" - ) + desc = site.description or args["desc"] or "" + copy_right = site.copyright or args["copyright"] or "" + privacy_policy = site.privacy_policy or args["privacy_policy"] or "" + custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 3f5e1adca2..35ac42a14c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource): def post(self, resource_id): resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() current_key_count = ( diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 437a6a7b38..c1ef05a488 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -94,19 +94,15 @@ class ChatMessageTextApi(Resource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): text_to_speech = app_model.workflow.features_dict.get("text_to_speech") - voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") + voice = args.get("voice") or text_to_speech.get("voice") else: try: - voice = ( - args.get("voice") - if args.get("voice") - else app_model.app_model_config.text_to_speech_dict.get("voice") - ) + voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 46c0b22993..df7bd352af 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -20,7 +20,7 @@ from fields.conversation_fields import ( conversation_pagination_fields, conversation_with_summary_pagination_fields, ) -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation @@ -36,8 +36,8 @@ class CompletionConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) @@ -143,8 +143,8 @@ class ChatConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 81826a20d0..3ef442812d 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode @@ -25,14 +25,17 @@ class DailyMessageStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count - FROM messages where app_id = :app_id - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(*) AS message_count +FROM + messages +WHERE + app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -45,7 +48,7 @@ class DailyMessageStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -55,10 +58,10 @@ class DailyMessageStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -79,14 +82,17 @@ class DailyConversationStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count - FROM messages where app_id = :app_id - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(DISTINCT messages.conversation_id) AS conversation_count +FROM + messages +WHERE + app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -99,7 +105,7 @@ class DailyConversationStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -109,10 +115,10 @@ class DailyConversationStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -133,14 +139,17 @@ class DailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count - FROM messages where app_id = :app_id - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(DISTINCT messages.from_end_user_id) AS terminal_count +FROM + messages +WHERE + app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -153,7 +162,7 @@ class DailyTerminalsStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -163,10 +172,10 @@ class DailyTerminalsStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -187,16 +196,18 @@ class DailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, - sum(total_price) as total_price - FROM messages where app_id = :app_id - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count, + SUM(total_price) AS total_price +FROM + messages +WHERE + app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -209,7 +220,7 @@ class DailyTokenCostStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -219,10 +230,10 @@ class DailyTokenCostStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -245,16 +256,26 @@ class AverageSessionInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, -AVG(subquery.message_count) AS interactions -FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count - FROM conversations c - JOIN messages m ON c.id = m.conversation_id - WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" + sql_query = """SELECT + DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + AVG(subquery.message_count) AS interactions +FROM + ( + SELECT + m.conversation_id, + COUNT(m.id) AS message_count + FROM + conversations c + JOIN + messages m + ON c.id = m.conversation_id + WHERE + c.override_model_configs IS NULL + AND c.app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -267,7 +288,7 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and c.created_at >= :start" + sql_query += " AND c.created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -277,14 +298,19 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and c.created_at < :end" + sql_query += " AND c.created_at < :end" arg_dict["end"] = end_datetime_utc sql_query += """ - GROUP BY m.conversation_id) subquery -LEFT JOIN conversations c on c.id=subquery.conversation_id -GROUP BY date -ORDER BY date""" + GROUP BY m.conversation_id + ) subquery +LEFT JOIN + conversations c + ON c.id = subquery.conversation_id +GROUP BY + date +ORDER BY + date""" response_data = [] @@ -307,17 +333,21 @@ class UserSatisfactionRateStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count - FROM messages m - LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' - WHERE m.app_id = :app_id - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(m.id) AS message_count, + COUNT(mf.id) AS feedback_count +FROM + messages m +LEFT JOIN + message_feedbacks mf + ON mf.message_id=m.id AND mf.rating='like' +WHERE + m.app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -330,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and m.created_at >= :start" + sql_query += " AND m.created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -340,10 +370,10 @@ class UserSatisfactionRateStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and m.created_at < :end" + sql_query += " AND m.created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -369,16 +399,17 @@ class AverageResponseTimeStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - AVG(provider_response_latency) as latency - FROM messages - WHERE app_id = :app_id - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + AVG(provider_response_latency) AS latency +FROM + messages +WHERE + app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -391,7 +422,7 @@ class AverageResponseTimeStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -401,10 +432,10 @@ class AverageResponseTimeStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -425,17 +456,20 @@ class TokensPerSecondStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - CASE + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + CASE WHEN SUM(provider_response_latency) = 0 THEN 0 ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) END as tokens_per_second -FROM messages -WHERE app_id = :app_id""" +FROM + messages +WHERE + app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) @@ -448,7 +482,7 @@ WHERE app_id = :app_id""" start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -458,10 +492,10 @@ WHERE app_id = :app_id""" end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 20ec3ef021..42f97c1699 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -502,6 +502,6 @@ api.add_resource( api.add_resource(PublishedWorkflowApi, "/apps//workflows/publish") api.add_resource(DefaultBlockConfigsApi, "/apps//workflows/default-workflow-block-configs") api.add_resource( - DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs" "/" + DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs/" ) api.add_resource(ConvertToWorkflowApi, "/apps//convert-to-workflow") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index db2f683589..c7e54f2be0 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode from models.workflow import WorkflowRunTriggeredFrom @@ -26,16 +26,18 @@ class WorkflowDailyRunsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs - FROM workflow_runs - WHERE app_id = :app_id - AND triggered_from = :triggered_from - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(id) AS runs +FROM + workflow_runs +WHERE + app_id = :app_id + AND triggered_from = :triggered_from""" arg_dict = { "tz": account.timezone, "app_id": app_model.id, @@ -52,7 +54,7 @@ class WorkflowDailyRunsStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -62,10 +64,10 @@ class WorkflowDailyRunsStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -86,16 +88,18 @@ class WorkflowDailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count - FROM workflow_runs - WHERE app_id = :app_id - AND triggered_from = :triggered_from - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(DISTINCT workflow_runs.created_by) AS terminal_count +FROM + workflow_runs +WHERE + app_id = :app_id + AND triggered_from = :triggered_from""" arg_dict = { "tz": account.timezone, "app_id": app_model.id, @@ -112,7 +116,7 @@ class WorkflowDailyTerminalsStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -122,10 +126,10 @@ class WorkflowDailyTerminalsStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -146,18 +150,18 @@ class WorkflowDailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT - date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - SUM(workflow_runs.total_tokens) as token_count - FROM workflow_runs - WHERE app_id = :app_id - AND triggered_from = :triggered_from - """ + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + SUM(workflow_runs.total_tokens) AS token_count +FROM + workflow_runs +WHERE + app_id = :app_id + AND triggered_from = :triggered_from""" arg_dict = { "tz": account.timezone, "app_id": app_model.id, @@ -174,7 +178,7 @@ class WorkflowDailyTokenCostStatistic(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at >= :start" + sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc if args["end"]: @@ -184,10 +188,10 @@ class WorkflowDailyTokenCostStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " and created_at < :end" + sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date order by date" + sql_query += " GROUP BY date ORDER BY date" response_data = [] @@ -213,27 +217,31 @@ class WorkflowAverageAppInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT - AVG(sub.interactions) as interactions, - sub.date - FROM - (SELECT - date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - c.created_by, - COUNT(c.id) AS interactions - FROM workflow_runs c - WHERE c.app_id = :app_id - AND c.triggered_from = :triggered_from - {{start}} - {{end}} - GROUP BY date, c.created_by) sub - GROUP BY sub.date - """ + sql_query = """SELECT + AVG(sub.interactions) AS interactions, + sub.date +FROM + ( + SELECT + DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + c.created_by, + COUNT(c.id) AS interactions + FROM + workflow_runs c + WHERE + c.app_id = :app_id + AND c.triggered_from = :triggered_from + {{start}} + {{end}} + GROUP BY + date, c.created_by + ) sub +GROUP BY + sub.date""" arg_dict = { "tz": account.timezone, "app_id": app_model.id, @@ -262,7 +270,7 @@ class WorkflowAverageAppInteractionStatistic(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace("{{end}}", " and c.created_at < :end") + sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end") arg_dict["end"] = end_datetime_utc else: sql_query = sql_query.replace("{{end}}", "") diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8ba6b53e7e..f3198dfc1d 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -8,7 +8,7 @@ from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from libs.helper import email, str_len, timezone +from libs.helper import StrLen, email, timezone from libs.password import hash_password, valid_password from models.account import AccountStatus from services.account_service import RegisterService @@ -37,7 +37,7 @@ class ActivateApi(Resource): parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument( "interface_language", type=supported_language, required=True, nullable=False, location="json" diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index ae1b49f3ec..ad0c0580ae 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -71,7 +71,7 @@ class OAuthCallback(Resource): account = _generate_account(provider, user_info) # Check account status - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: @@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: # Create account - account_name = user_info.name if user_info.name else "Dify" + account_name = user_info.name or "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 6ccacc78ee..2c4e5ac607 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -399,7 +399,7 @@ class DatasetIndexingEstimateApi(Resource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource): @login_required @account_initialization_required def get(self): - return { - "api_base_url": ( - dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/") - ) - + "/v1" - } + return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"} class DatasetRetrievalSettingApi(Resource): diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 076f3cd44d..829ef11e52 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule @@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info @@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource): db.session.commit() elif action == "resume": - if document.indexing_status not in ["paused", "error"]: + if document.indexing_status not in {"paused", "error"}: raise InvalidActionError("Document not in paused or error state.") document.paused_by = None diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 1c70ea6c59..870e547728 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -18,9 +18,7 @@ class NotSetupError(BaseHTTPException): class NotInitValidateError(BaseHTTPException): error_code = "not_init_validated" - description = ( - "Init validation has not been completed yet. " "Please proceed with the init validation process first." - ) + description = "Init validation has not been completed yet. Please proceed with the init validation process first." code = 401 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 71cb060ecc..9690677f61 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -81,19 +81,15 @@ class ChatTextApi(InstalledAppResource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): text_to_speech = app_model.workflow.features_dict.get("text_to_speech") - voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") + voice = args.get("voice") or text_to_speech.get("voice") else: try: - voice = ( - args.get("voice") - if args.get("voice") - else app_model.app_model_config.text_to_speech_dict.get("voice") - ) + voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index c039e8bca5..f464692098 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 2918024b64..6f9d7769b9 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource): def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource): def delete(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource): def post(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3f1e64a247..408afc33a0 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource): "app_owner_tenant_id": installed_app.app_owner_tenant_id, "is_pinned": installed_app.is_pinned, "last_used_at": installed_app.last_used_at, - "editable": current_user.role in ["owner", "admin"], + "editable": current_user.role in {"owner", "admin"}, "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index f5eb185172..0e0238556c 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() message_id = str(message_id) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index ad55b04043..aab7dd7888 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource): """Retrieve app parameters.""" app_model = installed_app.app - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 7d3ae677ee..ae759bb752 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -4,7 +4,7 @@ from flask import session from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import str_len +from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -28,7 +28,7 @@ class InitValidateAPI(Resource): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument("password", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=StrLen(30), required=True, location="json") input_password = parser.parse_args()["password"] if input_password != os.environ.get("INIT_PASSWORD"): diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 827695e00f..46b4ef5d87 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -4,7 +4,7 @@ from flask import request from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import email, get_remote_ip, str_len +from libs.helper import StrLen, email, get_remote_ip from libs.password import valid_password from models.model import DifySetup from services.account_service import RegisterService, TenantService @@ -40,7 +40,7 @@ class SetupApi(Resource): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("name", type=StrLen(30), required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 8c38420226..fe0bcf7338 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") api.add_resource( - ModelProviderIconApi, "/workspaces/current/model-providers//" "/" + ModelProviderIconApi, "/workspaces/current/model-providers///" ) api.add_resource( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c41a898fdc..d2a17b133b 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource): return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, - args["provider_name"] if args["provider_name"] else "", + args["provider_name"] or "", args["tool_name"], args["credentials"], args["parameters"], diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 623f0b8b74..af3ebc099b 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource): raise TooManyFilesError() extension = file.filename.split(".")[-1] - if extension.lower() not in ["svg", "png"]: + if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() try: diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 7667b30e34..46223d104f 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -64,7 +64,8 @@ def cloud_edition_billing_resource_check(resource: str): elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: abort(403, "The capacity of the vector space has reached the limit of your subscription.") elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: - # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets + # The api of file upload is used in the multiple places, + # so we need to check the source of the request from datasets source = request.args.get("source") if source == "datasets": abort(403, "The number of documents has reached the limit of your subscription.") diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 70a74e73e7..4c28e6acb3 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -38,6 +38,7 @@ class PluginInvokeLLMApi(Resource): return compact_generate_response(generator()) + class PluginInvokeTextEmbeddingApi(Resource): @setup_required @plugin_inner_api_only @@ -113,6 +114,7 @@ class PluginInvokeNodeApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode): pass + class PluginInvokeAppApi(Resource): @setup_required @plugin_inner_api_only @@ -134,6 +136,7 @@ class PluginInvokeAppApi(Resource): PluginAppBackwardsInvocation.convert_to_event_stream(response) ) + class PluginInvokeEncryptApi(Resource): @setup_required @plugin_inner_api_only @@ -145,6 +148,7 @@ class PluginInvokeEncryptApi(Resource): """ return PluginEncrypter.invoke_encrypt(tenant_model, payload) + api.add_resource(PluginInvokeLLMApi, '/invoke/llm') api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding') api.add_resource(PluginInvokeRerankApi, '/invoke/rerank') diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index eb753218ff..aa1f36d33c 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -48,6 +48,7 @@ def get_tenant(view: Optional[Callable] = None): else: return decorator(view) + def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): def decorator(view_func): def decorated_view(*args, **kwargs): diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 235359a9bb..3e184515d3 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -63,6 +63,7 @@ def enterprise_inner_api_user_auth(view): return decorated + def plugin_inner_api_only(view): @wraps(view) def decorated(*args, **kwargs): diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecc2d73deb..f7c091217b 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -42,7 +42,7 @@ class AppParameterApi(Resource): @marshal_with(parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 85aab047a7..5db4163647 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -79,19 +79,15 @@ class TextApi(Resource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): text_to_speech = app_model.workflow.features_dict.get("text_to_speech") - voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") + voice = args.get("voice") or text_to_speech.get("voice") else: try: - voice = ( - args.get("voice") - if args.get("voice") - else app_model.app_model_config.text_to_speech_dict.get("voice") - ) + voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None response = AudioService.transcript_tts( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index f1771baf31..8d8e356c4c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -96,7 +96,7 @@ class ChatApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -144,7 +144,7 @@ class ChatStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 734027a1c5..527ef4ecd3 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -18,7 +18,7 @@ class ConversationApi(Resource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -52,7 +52,7 @@ class ConversationDetailApi(Resource): @marshal_with(simple_conversation_fields) def delete(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -69,7 +69,7 @@ class ConversationRenameApi(Resource): @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index b39aaf7dd8..e54e6f4903 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -76,7 +76,7 @@ class MessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource): def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() try: diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 5822e0921b..96d1337632 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,6 +1,7 @@ import logging from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restful.inputs import int_range from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -22,10 +23,12 @@ from core.errors.error import ( ) from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs import helper from models.model import App, AppMode, EndUser from models.workflow import WorkflowRun from services.app_generate_service import AppGenerateService +from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) @@ -113,6 +116,30 @@ class WorkflowTaskStopApi(Resource): return {"result": "success"} +class WorkflowAppLogApi(Resource): + @validate_app_token + @marshal_with(workflow_app_log_pagination_fields) + def get(self, app_model: App): + """ + Get workflow app logs + """ + parser = reqparse.RequestParser() + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") + args = parser.parse_args() + + # get paginate workflow app logs + workflow_app_service = WorkflowAppService() + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( + app_model=app_model, args=args + ) + + return workflow_app_log_pagination + + api.add_resource(WorkflowRunApi, "/workflows/run") api.add_resource(WorkflowRunDetailApi, "/workflows/run/") api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") +api.add_resource(WorkflowAppLogApi, "/workflows/logs") diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index aabca93338..20b4e4674c 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource): @marshal_with(parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index d062d2893b..23550efe2e 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -78,19 +78,15 @@ class TextApi(WebApiResource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): text_to_speech = app_model.workflow.features_dict.get("text_to_speech") - voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") + voice = args.get("voice") or text_to_speech.get("voice") else: try: - voice = ( - args.get("voice") - if args.get("voice") - else app_model.app_model_config.text_to_speech_dict.get("voice") - ) + voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0837eedfb0..115492b796 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): def post(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -136,7 +136,7 @@ class ChatApi(WebApiResource): class ChatStopApi(WebApiResource): def post(self, app_model, end_user, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 6bbfa94c27..c3b0cd4f44 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource): class ConversationApi(WebApiResource): def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource): @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource): class ConversationPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource): class ConversationUnPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 56aaaa930a..0d4047f4ef 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -78,7 +78,7 @@ class MessageListApi(WebApiResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource): class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotCompletionAppError() message_id = str(message_id) diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 93dc691d62..c327c3df18 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -80,7 +80,8 @@ def _validate_web_sso_token(decoded, system_features, app_code): if not source or source != "sso": raise WebSSOAuthRequiredError() - # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login + # Check if SSO is not enforced for web, and if the token source is SSO, + # raise an error and redirect to normal passport login if not system_features.sso_enforced_for_web or not app_web_sso_enabled: source = decoded.get("token_source") if source and source == "sso": diff --git a/api/core/__init__.py b/api/core/__init__.py index 8c986fc8bd..6eaea7b1c8 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +1 @@ -import core.moderation.base \ No newline at end of file +import core.moderation.base diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 89c948d2e2..ebe04bf260 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,17 +25,19 @@ from models.model import Message class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True - _ignore_observation_providers = ['wenxin'] + _ignore_observation_providers = ["wenxin"] _historic_prompt_messages: list[PromptMessage] = None _agent_scratchpad: list[AgentScratchpadUnit] = None _instruction: str = None _query: str = None _prompt_messages_tools: list[PromptMessage] = None - def run(self, message: Message, - query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + def run( + self, + message: Message, + query: str, + inputs: dict[str, str], + ) -> Union[Generator, LLMResult]: """ Run Cot agent application """ @@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC): trace_manager = app_generate_entity.trace_manager # check model mode - if 'Observation' not in app_generate_entity.model_conf.stop: + if "Observation" not in app_generate_entity.model_conf.stop: if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append('Observation') + app_generate_entity.model_conf.stop.append("Observation") app_config = self.app_config # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools( - instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 @@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) if iteration_step > 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # recalc llm max tokens prompt_messages = self._organize_prompt_messages() @@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output( - chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( - agent_response='', - thought='', - action_str='', - observation='', + agent_response="", + thought="", + action_str="", + observation="", action=None, ) # publish agent thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) for chunk in react_chunks: if isinstance(chunk, AgentScratchpadUnit.Action): @@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC): yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, - system_fingerprint='', - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=chunk - ), - usage=None - ) + system_fingerprint="", + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - scratchpad.thought = scratchpad.thought.strip( - ) or 'I am thinking about how to help you' + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) # get llm usage - if 'usage' in usage_dict: - increase_usage(llm_usage, usage_dict['usage']) + if "usage" in usage_dict: + increase_usage(llm_usage, usage_dict["usage"]) else: - usage_dict['usage'] = LLMUsage.empty_usage() + usage_dict["usage"] = LLMUsage.empty_usage() self.save_agent_thought( agent_thought=agent_thought, - tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input - } if scratchpad.action else {}, + tool_name=scratchpad.action.action_name if scratchpad.action else "", + tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, thought=scratchpad.thought, - observation='', + observation="", answer=scratchpad.agent_response, messages_ids=[], - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) if not scratchpad.is_final(): - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) if not scratchpad.action: # failed to extract action, return final answer directly - final_answer = '' + final_answer = "" else: if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps( - scratchpad.action.action_input) + final_answer = json.dumps(scratchpad.action.action_input) elif isinstance(scratchpad.action.action_input, str): final_answer = scratchpad.action.action_input else: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" except json.JSONDecodeError: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" else: function_call_state = True # action is tool call, invoke tool @@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC): self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name, - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input}, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought, - observation={ - scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={ - scratchpad.action.action_name: tool_invoke_meta.to_dict()}, + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, messages_ids=message_file_ids, - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # update prompt tool message for prompt_tool in self._prompt_messages_tools: @@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC): model=model_instance.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=final_answer - ), - usage=llm_usage['usage'] + index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] ), - system_fingerprint='' + system_fingerprint="", ) # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name='', + tool_name="", tool_input={}, tool_invoke_meta={}, thought=final_answer, observation={}, answer=final_answer, - messages_ids=[] + messages_ids=[], ) self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] or LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) + PublishFrom.APPLICATION_MANAGER, + ) - def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, - tool_instances: dict[str, Tool], - message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None - ) -> tuple[str, ToolInvokeMeta]: + def _handle_invoke_action( + self, + action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str], + trace_manager: Optional[TraceQueueManager] = None, + ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action :param action: action @@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file( - tool_name=tool_call_name, value=message_file_id, name=save_as) + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) @@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ convert dict to action """ - return AgentScratchpadUnit.Action( - action_name=action['action'], - action_input=action['action_input'] - ) + return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ @@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ for key, value in inputs.items(): try: - instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) + instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) except Exception as e: continue @@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): @abstractmethod def _organize_prompt_messages(self) -> list[PromptMessage]: """ - organize prompt messages + organize prompt messages """ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: """ - format assistant message + format assistant message """ - message = '' + message = "" for scratchpad in agent_scratchpad: if scratchpad.is_final(): message += f"Final Answer: {scratchpad.agent_response}" @@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC): return message - def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_historic_prompt_messages( + self, current_session_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: """ - organize historic prompt messages + organize historic prompt messages """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] @@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): if not current_scratchpad: current_scratchpad = AgentScratchpadUnit( agent_response=message.content, - thought=message.content or 'I am thinking about how to help you', - action_str='', + thought=message.content or "I am thinking about how to help you", + action_str="", action=None, observation=None, ) @@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC): try: current_scratchpad.action = AgentScratchpadUnit.Action( action_name=message.tool_calls[0].function.name, - action_input=json.loads( - message.tool_calls[0].function.arguments) - ) - current_scratchpad.action_str = json.dumps( - current_scratchpad.action.to_dict() + action_input=json.loads(message.tool_calls[0].function.arguments), ) + current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) except: pass elif isinstance(message, ToolPromptMessage): @@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC): current_scratchpad.observation = message.content elif isinstance(message, UserPromptMessage): if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) scratchpads = [] current_scratchpad = None result.append(message) if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) historic_prompts = AgentHistoryPromptTransform( model_config=self.model_config, prompt_messages=current_session_messages or [], history_messages=result, - memory=self.memory + memory=self.memory, ).get_prompt() return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 8debbe5c5d..bdec6b7ed1 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner): prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt \ - .replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) return SystemPromptMessage(content=system_prompt) - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ Organize user query """ @@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner): def _organize_prompt_messages(self) -> list[PromptMessage]: """ - Organize + Organize """ # organize system prompt system_message = self._organize_system_prompt() @@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner): if not agent_scratchpad: assistant_messages = [] else: - assistant_message = AssistantPromptMessage(content='') + assistant_message = AssistantPromptMessage(content="") for unit in agent_scratchpad: if unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" @@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner): if assistant_messages: # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([ - system_message, - *query_messages, - *assistant_messages, - UserPromptMessage(content='continue') - ]) + historic_messages = self._organize_historic_prompt_messages( + [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")] + ) messages = [ system_message, *historic_messages, *query_messages, *assistant_messages, - UserPromptMessage(content='continue') + UserPromptMessage(content="continue"), ] else: # organize historic prompt messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 9e6eb54f4f..9dab956f9a 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner): prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) - + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) + return system_prompt def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str: @@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner): # organize current assistant messages agent_scratchpad = self._agent_scratchpad - assistant_prompt = '' + assistant_prompt = "" for unit in agent_scratchpad: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" @@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner): query_prompt = f"Question: {self._query}" # join all messages - prompt = system_prompt \ - .replace("{{historic_messages}}", historic_prompt) \ - .replace("{{agent_scratchpad}}", assistant_prompt) \ + prompt = ( + system_prompt.replace("{{historic_messages}}", historic_prompt) + .replace("{{agent_scratchpad}}", assistant_prompt) .replace("{{query}}", query_prompt) + ) - return [UserPromptMessage(content=prompt)] \ No newline at end of file + return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 776e40e9fe..b51a163549 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -20,6 +20,7 @@ class AgentPromptEntity(BaseModel): """ Agent Prompt Entity. """ + first_prompt: str next_iteration: str @@ -33,6 +34,7 @@ class AgentScratchpadUnit(BaseModel): """ Action Entity. """ + action_name: str action_input: Union[dict, str] @@ -41,8 +43,8 @@ class AgentScratchpadUnit(BaseModel): Convert to dictionary. """ return { - 'action': self.action_name, - 'action_input': self.action_input, + "action": self.action_name, + "action_input": self.action_input, } agent_response: Optional[str] = None @@ -56,10 +58,10 @@ class AgentScratchpadUnit(BaseModel): Check if the scratchpad unit is final. """ return self.action is None or ( - 'final' in self.action.action_name.lower() and - 'answer' in self.action.action_name.lower() + "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower() ) + class AgentEntity(BaseModel): """ Agent Entity. @@ -69,8 +71,9 @@ class AgentEntity(BaseModel): """ Agent Strategy. """ - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' + + CHAIN_OF_THOUGHT = "chain-of-thought" + FUNCTION_CALLING = "function-calling" provider: str model: str diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3ee6e47742..13164e0bfc 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -24,11 +24,9 @@ from models.model import Message logger = logging.getLogger(__name__) -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, - message: Message, query: str, **kwargs: Any - ) -> Generator[LLMResultChunk, None, None]: +class FunctionCallAgentRunner(BaseAgentRunner): + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application """ @@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner): # continue to run until there is not any tool call function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" # get tracing instance trace_manager = app_generate_entity.trace_manager - + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) # recalc llm max tokens @@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_calls: list[tuple[str, str, dict[str, Any]]] = [] # save full response - response = '' + response = "" # save tool call names and inputs - tool_call_names = '' - tool_call_inputs = '' + tool_call_names = "" + tool_call_inputs = "" current_llm_usage = None @@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True tool_calls.extend(self.extract_tool_calls(chunk)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if chunk.delta.message and chunk.delta.message.content: if isinstance(chunk.delta.message.content, list): @@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): if self.check_blocking_tool_calls(result): function_call_state = True tool_calls.extend(self.extract_blocking_tool_calls(result)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if result.usage: increase_usage(llm_usage, result.usage) @@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): response += result.message.content if not result.message.content: - result.message.content = '' + result.message.content = "" + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - yield LLMResultChunk( model=model_instance.model, prompt_messages=result.prompt_messages, @@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner): index=0, message=result.message, usage=result.usage, - ) + ), ) - assistant_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: - assistant_message.tool_calls=[ + assistant_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=tool_call[0], - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call[1], - arguments=json.dumps(tool_call[2], ensure_ascii=False) - ) - ) for tool_call in tool_calls + name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False) + ), + ) + for tool_call in tool_calls ] else: assistant_message.content = response - + self._current_thoughts.append(assistant_message) # save thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, @@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): observation=None, answer=response, messages_ids=[], - llm_usage=current_llm_usage + llm_usage=current_llm_usage, ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - - final_answer += response + '\n' + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + final_answer += response + "\n" # call tools tool_responses = [] @@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": f"there is not a tool named {tool_call_name}", - "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict() + "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), } else: # invoke tool @@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) - + tool_response = { "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": tool_invoke_response, - "meta": tool_invoke_meta.to_dict() + "meta": tool_invoke_meta.to_dict(), } - + tool_responses.append(tool_response) - if tool_response['tool_response'] is not None: + if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response['tool_response'], + content=tool_response["tool_response"], tool_call_id=tool_call_id, name=tool_call_name, ) - ) + ) if len(tool_responses) > 0: # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=None, tool_input=None, - thought=None, + thought=None, tool_invoke_meta={ - tool_response['tool_call_name']: tool_response['meta'] - for tool_response in tool_responses + tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, observation={ - tool_response['tool_call_name']: tool_response['tool_response'] + tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, answer=None, - messages_ids=message_file_ids + messages_ids=message_file_ids, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) # update prompt tool for prompt_tool in prompt_messages_tools: @@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] or LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) + PublishFrom.APPLICATION_MANAGER, + ) def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: """ @@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): if llm_result_chunk.delta.message.tool_calls: return True return False - + def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: """ Check if there is any blocking tool call in llm result @@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): return True return False - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: + def extract_tool_calls( + self, llm_result_chunk: LLMResultChunk + ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract tool calls from llm result chunk @@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_calls = [] for prompt_message in llm_result_chunk.delta.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract blocking tool calls from llm result @@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_calls = [] for prompt_message in llm_result.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _init_system_message( + self, prompt_template: str, prompt_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: """ Initialize system message """ @@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): return [ SystemPromptMessage(content=prompt_template), ] - + if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) return prompt_messages - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ Organize user query """ @@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): prompt_messages.append(UserPromptMessage(content=query)) return prompt_messages - + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ As for now, gpt supports both fc and vision at the first iteration. @@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner): for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - content.data if content.type == PromptMessageContentType.TEXT else - '[image]' if content.type == PromptMessageContentType.IMAGE else - '[file]' - for content in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) return prompt_messages def _organize_prompt_messages(self): - prompt_template = self.app_config.prompt_template.simple_prompt_template or '' + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) query_prompt_messages = self._organize_user_query(self.query, []) @@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner): model_config=self.model_config, prompt_messages=[*query_prompt_messages, *self._current_thoughts], history_messages=self.history_prompt_messages, - memory=self.memory + memory=self.memory, ).get_prompt() - prompt_messages = [ - *self.history_prompt_messages, - *query_prompt_messages, - *self._current_thoughts - ] + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] if len(self._current_thoughts) != 0: # clear messages after the first iteration prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index c53fa5000e..d04e38777a 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: @classmethod - def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \ - Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + def handle_react_stream_output( + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict + ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(json_str): try: action = json.loads(json_str) @@ -22,7 +23,7 @@ class CotAgentOutputParser: action = action[0] for key, value in action.items(): - if 'input' in key.lower(): + if "input" in key.lower(): action_input = value else: action_name = value @@ -33,37 +34,37 @@ class CotAgentOutputParser: action_input=action_input, ) else: - return json_str or '' + return json_str or "" except: - return json_str or '' - + return json_str or "" + def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: - code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) + code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return for block in code_blocks: - json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) + json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) yield parse_action(json_text) - - code_block_cache = '' + + code_block_cache = "" code_block_delimiter_count = 0 in_code_block = False - json_cache = '' + json_cache = "" json_quote_count = 0 in_json = False got_json = False - action_cache = '' - action_str = 'action:' + action_cache = "" + action_str = "action:" action_idx = 0 - thought_cache = '' - thought_str = 'thought:' + thought_cache = "" + thought_str = "thought:" thought_idx = 0 for response in llm_response: if response.delta.usage: - usage_dict['usage'] = response.delta.usage + usage_dict["usage"] = response.delta.usage response = response.delta.message.content if not isinstance(response, str): continue @@ -72,24 +73,24 @@ class CotAgentOutputParser: index = 0 while index < len(response): steps = 1 - delta = response[index:index+steps] - last_character = response[index-1] if index > 0 else '' + delta = response[index : index + steps] + last_character = response[index - 1] if index > 0 else "" - if delta == '`': + if delta == "`": code_block_cache += delta code_block_delimiter_count += 1 else: if not in_code_block: if code_block_delimiter_count > 0: yield code_block_cache - code_block_cache = '' + code_block_cache = "" else: code_block_cache += delta code_block_delimiter_count = 0 if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue @@ -97,7 +98,7 @@ class CotAgentOutputParser: action_cache += delta action_idx += 1 if action_idx == len(action_str): - action_cache = '' + action_cache = "" action_idx = 0 index += steps continue @@ -105,18 +106,18 @@ class CotAgentOutputParser: action_cache += delta action_idx += 1 if action_idx == len(action_str): - action_cache = '' + action_cache = "" action_idx = 0 index += steps continue else: if action_cache: yield action_cache - action_cache = '' + action_cache = "" action_idx = 0 - + if delta.lower() == thought_str[thought_idx] and thought_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue @@ -124,7 +125,7 @@ class CotAgentOutputParser: thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): - thought_cache = '' + thought_cache = "" thought_idx = 0 index += steps continue @@ -132,31 +133,31 @@ class CotAgentOutputParser: thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): - thought_cache = '' + thought_cache = "" thought_idx = 0 index += steps continue else: if thought_cache: yield thought_cache - thought_cache = '' + thought_cache = "" thought_idx = 0 if code_block_delimiter_count == 3: if in_code_block: yield from extra_json_from_code_block(code_block_cache) - code_block_cache = '' - + code_block_cache = "" + in_code_block = not in_code_block code_block_delimiter_count = 0 if not in_code_block: # handle single json - if delta == '{': + if delta == "{": json_quote_count += 1 in_json = True json_cache += delta - elif delta == '}': + elif delta == "}": json_cache += delta if json_quote_count > 0: json_quote_count -= 1 @@ -172,12 +173,12 @@ class CotAgentOutputParser: if got_json: got_json = False yield parse_action(json_cache) - json_cache = '' + json_cache = "" json_quote_count = 0 in_json = False - + if not in_code_block and not in_json: - yield delta.replace('`', '') + yield delta.replace("`", "") index += steps @@ -186,4 +187,3 @@ class CotAgentOutputParser: if json_cache: yield parse_action(json_cache) - diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py index b0cf1a77fb..ef64fd29fc 100644 --- a/api/core/agent/prompt/template.py +++ b/api/core/agent/prompt/template.py @@ -41,7 +41,8 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use {{historic_messages}} Question: {{query}} {{agent_scratchpad}} -Thought:""" +Thought:""" # noqa: E501 + ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} Thought:""" @@ -86,19 +87,20 @@ Action: ``` Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -""" +""" # noqa: E501 + ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" REACT_PROMPT_TEMPLATES = { - 'english': { - 'chat': { - 'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES + "english": { + "chat": { + "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES, + }, + "completion": { + "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES, }, - 'completion': { - 'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES - } } -} \ No newline at end of file +} diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 3dea305e98..24d80f9cdd 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -26,34 +26,24 @@ class BaseAppConfigManager: config_dict = dict(config_dict.items()) additional_features = AppAdditionalFeatures() - additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( - config=config_dict - ) + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.file_upload = FileUploadConfigManager.convert( - config=config_dict, - is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] + config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT} ) - additional_features.opening_statement, additional_features.suggested_questions = \ - OpeningStatementConfigManager.convert( - config=config_dict - ) + additional_features.opening_statement, additional_features.suggested_questions = ( + OpeningStatementConfigManager.convert(config=config_dict) + ) additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( config=config_dict ) - additional_features.more_like_this = MoreLikeThisConfigManager.convert( - config=config_dict - ) + additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict) - additional_features.speech_to_text = SpeechToTextConfigManager.convert( - config=config_dict - ) + additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict) - additional_features.text_to_speech = TextToSpeechConfigManager.convert( - config=config_dict - ) + additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict) return additional_features diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 1ca8b1e3b8..037037e6ca 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: - sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance') + sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None - if sensitive_word_avoidance_dict.get('enabled'): + if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( - type=sensitive_word_avoidance_dict.get('type'), - config=sensitive_word_avoidance_dict.get('config'), + type=sensitive_word_avoidance_dict.get("type"), + config=sensitive_word_avoidance_dict.get("config"), ) else: return None @classmethod - def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ - -> tuple[dict, list[str]]: + def validate_and_set_defaults( + cls, tenant_id, config: dict, only_structure_validate: bool = False + ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): - config["sensitive_word_avoidance"] = { - "enabled": False - } + config["sensitive_word_avoidance"] = {"enabled": False} if not isinstance(config["sensitive_word_avoidance"], dict): raise ValueError("sensitive_word_avoidance must be of dict type") @@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager: typ = config["sensitive_word_avoidance"]["type"] sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] - ModerationFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config - ) + ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) return config, ["sensitive_word_avoidance"] diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index dc65d4439b..f503543d7b 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -12,67 +12,70 @@ class AgentConfigManager: :param config: model config args """ - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode']: + if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]: + agent_dict = config.get("agent_mode", {}) + agent_strategy = agent_dict.get("strategy", "cot") - agent_dict = config.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': + if agent_strategy == "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': + elif agent_strategy in {"cot", "react"}: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT else: # old configs, try to detect default strategy - if config['model']['provider'] == 'openai': + if config["model"]["provider"] == "openai": strategy = AgentEntity.Strategy.FUNCTION_CALLING else: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) >= 4: if "enabled" not in tool or not tool["enabled"]: continue agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool.get('tool_parameters', {}) + "provider_type": tool["provider_type"], + "provider_id": tool["provider_id"], + "tool_name": tool["tool_name"], + "tool_parameters": tool.get("tool_parameters", {}), } agent_tools.append(AgentToolEntity(**agent_tool_properties)) - if 'strategy' in config['agent_mode'] and \ - config['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} + if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { + "react_router", + "router", + }: + agent_prompt = agent_dict.get("prompt", None) or {} # check model mode - model_mode = config.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': + model_mode = config.get("model", {}).get("mode", "completion") + if model_mode == "completion": agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['completion'][ - 'agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"] + ), ) else: agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"] + ), ) return AgentEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) + max_iteration=agent_dict.get("max_iteration", 5), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 1a621d2090..a22395b8e3 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -15,39 +15,38 @@ class DatasetConfigManager: :param config: model config args """ dataset_ids = [] - if 'datasets' in config.get('dataset_configs', {}): - datasets = config.get('dataset_configs', {}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) + if "datasets" in config.get("dataset_configs", {}): + datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) - for dataset in datasets.get('datasets', []): + for dataset in datasets.get("datasets", []): keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': + if len(keys) == 0 or keys[0] != "dataset": continue - dataset = dataset['dataset'] + dataset = dataset["dataset"] - if 'enabled' not in dataset or not dataset['enabled']: + if "enabled" not in dataset or not dataset["enabled"]: continue - dataset_id = dataset.get('id', None) + dataset_id = dataset.get("id", None) if dataset_id: dataset_ids.append(dataset_id) - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode'] \ - and config['agent_mode']['enabled']: + if ( + "agent_mode" in config + and config["agent_mode"] + and "enabled" in config["agent_mode"] + and config["agent_mode"]["enabled"] + ): + agent_dict = config.get("agent_mode", {}) - agent_dict = config.get('agent_mode', {}) - - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) == 1: # old standard key = list(tool.keys())[0] - if key != 'dataset': + if key != "dataset": continue tool_item = tool[key] @@ -55,30 +54,28 @@ class DatasetConfigManager: if "enabled" not in tool_item or not tool_item["enabled"]: continue - dataset_id = tool_item['id'] + dataset_id = tool_item["id"] dataset_ids.append(dataset_id) if len(dataset_ids) == 0: return None # dataset configs - if 'dataset_configs' in config and config.get('dataset_configs'): - dataset_configs = config.get('dataset_configs') + if "dataset_configs" in config and config.get("dataset_configs"): + dataset_configs = config.get("dataset_configs") else: - dataset_configs = { - 'retrieval_model': 'multiple' - } - query_variable = config.get('dataset_query_variable') + dataset_configs = {"retrieval_model": "multiple"} + query_variable = config.get("dataset_query_variable") - if dataset_configs['retrieval_model'] == 'single': + if dataset_configs["retrieval_model"] == "single": return DatasetEntity( dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ) - ) + dataset_configs["retrieval_model"] + ), + ), ) else: return DatasetEntity( @@ -86,15 +83,15 @@ class DatasetConfigManager: retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] + dataset_configs["retrieval_model"] ), - top_k=dataset_configs.get('top_k', 4), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model'), - weights=dataset_configs.get('weights'), - reranking_enabled=dataset_configs.get('reranking_enabled', True), - rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'), - ) + top_k=dataset_configs.get("top_k", 4), + score_threshold=dataset_configs.get("score_threshold"), + reranking_model=dataset_configs.get("reranking_model"), + weights=dataset_configs.get("weights"), + reranking_enabled=dataset_configs.get("reranking_enabled", True), + rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), + ), ) @classmethod @@ -111,13 +108,10 @@ class DatasetConfigManager: # dataset_configs if not config.get("dataset_configs"): - config["dataset_configs"] = {'retrieval_model': 'single'} + config["dataset_configs"] = {"retrieval_model": "single"} if not config["dataset_configs"].get("datasets"): - config["dataset_configs"]["datasets"] = { - "strategy": "router", - "datasets": [] - } + config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") @@ -125,8 +119,9 @@ class DatasetConfigManager: if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = (config.get("dataset_configs") - and config["dataset_configs"].get("datasets", {}).get("datasets")) + need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( + "datasets", {} + ).get("datasets") if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion @@ -148,10 +143,7 @@ class DatasetConfigManager: """ # Extract dataset config for legacy compatibility if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -175,7 +167,7 @@ class DatasetConfigManager: config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value has_datasets = False - if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: + if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: for tool in config["agent_mode"]["tools"]: key = list(tool.keys())[0] if key == "dataset": @@ -188,7 +180,7 @@ class DatasetConfigManager: if not isinstance(tool_item["enabled"], bool): raise ValueError("enabled in agent_mode.tools must be of boolean type") - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 5c9b2cfec7..a91b9f0f02 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager class ModelConfigConverter: @classmethod - def convert(cls, app_config: EasyUIBasedAppConfig, - skip_check: bool = False) \ - -> ModelConfigWithCredentialsEntity: + def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. :param app_config: app config @@ -25,9 +23,7 @@ class ModelConfigConverter: provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=app_config.tenant_id, - provider=model_config.provider, - model_type=ModelType.LLM + tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) provider_name = provider_model_bundle.configuration.provider.provider @@ -38,8 +34,7 @@ class ModelConfigConverter: # check model credentials model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model_config.model + model_type=ModelType.LLM, model=model_config.model ) if model_credentials is None: @@ -51,8 +46,7 @@ class ModelConfigConverter: if not skip_check: # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_config.model, - model_type=ModelType.LLM + model=model_config.model, model_type=ModelType.LLM ) if provider_model is None: @@ -69,24 +63,18 @@ class ModelConfigConverter: # model config completion_params = model_config.parameters stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = model_config.mode if not model_mode: - mode_enum = model_type_instance.get_model_mode( - model=model_config.model, - credentials=model_credentials - ) + mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials) model_mode = mode_enum.value - model_schema = model_type_instance.get_model_schema( - model_config.model, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 730a9527cf..b5e4554181 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -13,23 +13,23 @@ class ModelConfigManager: :param config: model config args """ # model config - model_config = config.get('model') + model_config = config.get("model") if not model_config: raise ValueError("model is required") - completion_params = model_config.get('completion_params') + completion_params = model_config.get("completion_params") stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode - model_mode = model_config.get('mode') + model_mode = model_config.get("mode") return ModelConfigEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], mode=model_mode, parameters=completion_params, stop=stop, @@ -43,7 +43,7 @@ class ModelConfigManager: :param tenant_id: tenant id :param config: app model config args """ - if 'model' not in config: + if "model" not in config: raise ValueError("model is required") if not isinstance(config["model"], dict): @@ -52,17 +52,16 @@ class ModelConfigManager: # model.provider provider_entities = model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] - if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: + if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") # model.name - if 'name' not in config["model"]: + if "name" not in config["model"]: raise ValueError("model.name is required") provider_manager = ProviderManager() models = provider_manager.get_configurations(tenant_id).get_models( - provider=config["model"]["provider"], - model_type=ModelType.LLM + provider=config["model"]["provider"], model_type=ModelType.LLM ) if not models: @@ -80,12 +79,12 @@ class ModelConfigManager: # model.mode if model_mode: - config['model']["mode"] = model_mode + config["model"]["mode"] = model_mode else: - config['model']["mode"] = "completion" + config["model"]["mode"] = "completion" # model.completion_params - if 'completion_params' not in config["model"]: + if "completion_params" not in config["model"]: raise ValueError("model.completion_params is required") config["model"]["completion_params"] = cls.validate_model_completion_params( @@ -101,7 +100,7 @@ class ModelConfigManager: raise ValueError("model.completion_params must be of object type") # stop - if 'stop' not in cp: + if "stop" not in cp: cp["stop"] = [] elif not isinstance(cp["stop"], list): raise ValueError("stop in model.completion_params must be of list type") diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 1f410758aa..82a0e56ce8 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -14,39 +14,33 @@ class PromptTemplateConfigManager: if not config.get("prompt_type"): raise ValueError("prompt_type is required") - prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type']) + prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"]) if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: simple_prompt_template = config.get("pre_prompt", "") - return PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) + return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template) else: advanced_chat_prompt_template = None chat_prompt_config = config.get("chat_prompt_config", {}) if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): - chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) + chat_prompt_messages.append( + {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} + ) - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) + advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_completion_prompt_template = None completion_prompt_config = config.get("completion_prompt_config", {}) if completion_prompt_config: completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], + "prompt": completion_prompt_config["prompt"]["text"], } - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] + if "conversation_histories_role" in completion_prompt_config: + completion_prompt_template_params["role_prefix"] = { + "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], + "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], } advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( @@ -56,7 +50,7 @@ class PromptTemplateConfigManager: return PromptTemplateEntity( prompt_type=prompt_type, advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template + advanced_completion_prompt_template=advanced_completion_prompt_template, ) @classmethod @@ -72,7 +66,7 @@ class PromptTemplateConfigManager: config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] - if config['prompt_type'] not in prompt_type_vals: + if config["prompt_type"] not in prompt_type_vals: raise ValueError(f"prompt_type must be in {prompt_type_vals}") # chat_prompt_config @@ -89,27 +83,28 @@ class PromptTemplateConfigManager: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: - if not config['chat_prompt_config'] and not config['completion_prompt_config']: - raise ValueError("chat_prompt_config or completion_prompt_config is required " - "when prompt_type is advanced") + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config["chat_prompt_config"] and not config["completion_prompt_config"]: + raise ValueError( + "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced" + ) model_mode_vals = [mode.value for mode in ModelMode] - if config['model']["mode"] not in model_mode_vals: + if config["model"]["mode"] not in model_mode_vals: raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") - if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: - user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] - assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: + user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] + assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] if not user_prefix: - config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human" if not assistant_prefix: - config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" - if config['model']["mode"] == ModelMode.CHAT.value: - prompt_list = config['chat_prompt_config']['prompt'] + if config["model"]["mode"] == ModelMode.CHAT.value: + prompt_list = config["chat_prompt_config"]["prompt"] if len(prompt_list) > 10: raise ValueError("prompt messages must be less than 10") diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 15fa4d99fd..a1bfde3208 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -16,51 +16,49 @@ class BasicVariablesConfigManager: variable_entities = [] # old external_data_tools - external_data_tools = config.get('external_data_tools', []) + external_data_tools = config.get("external_data_tools", []) for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: + if "enabled" not in external_data_tool or not external_data_tool["enabled"]: continue external_data_variables.append( ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] + variable=external_data_tool["variable"], + type=external_data_tool["type"], + config=external_data_tool["config"], ) ) # variables and external_data_tools - for variables in config.get('user_input_form', []): + for variables in config.get("user_input_form", []): variable_type = list(variables.keys())[0] if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: variable = variables[variable_type] - if 'config' not in variable: + if "config" not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=variable['variable'], - type=variable['type'], - config=variable['config'] + variable=variable["variable"], type=variable["type"], config=variable["config"] ) ) - elif variable_type in [ + elif variable_type in { VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH, VariableEntityType.NUMBER, VariableEntityType.SELECT, - ]: + }: variable = variables[variable_type] variable_entities.append( VariableEntity( type=variable_type, - variable=variable.get('variable'), - description=variable.get('description'), - label=variable.get('label'), - required=variable.get('required', False), - max_length=variable.get('max_length'), - options=variable.get('options'), - default=variable.get('default'), + variable=variable.get("variable"), + description=variable.get("description"), + label=variable.get("label"), + required=variable.get("required", False), + max_length=variable.get("max_length"), + options=variable.get("options"), + default=variable.get("default"), ) ) @@ -99,17 +97,17 @@ class BasicVariablesConfigManager: variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: + if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] - if 'label' not in form_item: + if "label" not in form_item: raise ValueError("label is required in user_input_form") if not isinstance(form_item["label"], str): raise ValueError("label in user_input_form must be of string type") - if 'variable' not in form_item: + if "variable" not in form_item: raise ValueError("variable is required in user_input_form") if not isinstance(form_item["variable"], str): @@ -117,26 +115,24 @@ class BasicVariablesConfigManager: pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") + raise ValueError("variable in user_input_form must be a string, and cannot start with a number") variables.append(form_item["variable"]) - if 'required' not in form_item or not form_item["required"]: + if "required" not in form_item or not form_item["required"]: form_item["required"] = False if not isinstance(form_item["required"], bool): raise ValueError("required in user_input_form must be of boolean type") if key == "select": - if 'options' not in form_item or not form_item["options"]: + if "options" not in form_item or not form_item["options"]: form_item["options"] = [] if not isinstance(form_item["options"], list): raise ValueError("options in user_input_form must be a list of strings") - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: + if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]: raise ValueError("default value in user_input_form must be in the options list") return config, ["user_input_form"] @@ -168,10 +164,6 @@ class BasicVariablesConfigManager: typ = tool["type"] config = tool["config"] - ExternalDataToolFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=config - ) + ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config) return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index bbb10d3d76..7e5899bafa 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel): """ Model Config Entity. """ + provider: str model: str mode: Optional[str] = None @@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel): """ Advanced Chat Message Entity. """ + text: str role: PromptMessageRole @@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel): """ Advanced Chat Prompt Template Entity. """ + messages: list[AdvancedChatMessageEntity] @@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): """ Role Prefix Entity. """ + user: str assistant: str @@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel): Prompt Type. 'simple', 'advanced' """ - SIMPLE = 'simple' - ADVANCED = 'advanced' + + SIMPLE = "simple" + ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> 'PromptType': + def value_of(cls, value: str) -> "PromptType": """ Get value of given mode. @@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt type value {value}') + raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType simple_prompt_template: Optional[str] = None @@ -87,7 +92,7 @@ class VariableEntityType(str, Enum): SELECT = "select" PARAGRAPH = "paragraph" NUMBER = "number" - EXTERNAL_DATA_TOOL = "external-data-tool" + EXTERNAL_DATA_TOOL = "external_data_tool" class VariableEntity(BaseModel): @@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. """ + variable: str type: str config: dict[str, Any] = {} @@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel): Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = 'single' - MULTIPLE = 'multiple' + + SINGLE = "single" + MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> 'RetrieveStrategy': + def value_of(cls, value: str) -> "RetrieveStrategy": """ Get value of given mode. @@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid retrieve strategy value {value}') + raise ValueError(f"invalid retrieve strategy value {value}") query_variable: Optional[str] = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None - score_threshold: Optional[float] = .0 - rerank_mode: Optional[str] = 'reranking_model' + score_threshold: Optional[float] = 0.0 + rerank_mode: Optional[str] = "reranking_model" reranking_model: Optional[dict] = None weights: Optional[dict] = None reranking_enabled: Optional[bool] = True - - class DatasetEntity(BaseModel): """ Dataset Config Entity. """ + dataset_ids: list[str] retrieve_config: DatasetRetrieveConfigEntity @@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + type: str config: dict[str, Any] = {} @@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + enabled: bool voice: Optional[str] = None language: Optional[str] = None @@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel): """ Tracing Config Entity. """ + enabled: bool tracing_provider: str - - class AppAdditionalFeatures(BaseModel): file_upload: Optional[FileExtraConfig] = None opening_statement: Optional[str] = None @@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel): text_to_speech: Optional[TextToSpeechEntity] = None trace_config: Optional[TracingConfigEntity] = None + class AppConfig(BaseModel): """ Application Config Entity. """ + tenant_id: str app_id: str app_mode: AppMode @@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum): """ App Model Config From. """ - ARGS = 'args' - APP_LATEST_CONFIG = 'app-latest-config' - CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' + + ARGS = "args" + APP_LATEST_CONFIG = "app-latest-config" + CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" class EasyUIBasedAppConfig(AppConfig): """ Easy UI Based App Config Entity. """ + app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str app_model_config_dict: dict @@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. """ + workflow_id: str diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 3da3c2eddb..7a275cb532 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -13,21 +13,19 @@ class FileUploadConfigManager: :param config: model config args :param is_vision: if True, the feature is vision feature """ - file_upload_dict = config.get('file_upload') + file_upload_dict = config.get("file_upload") if file_upload_dict: - if file_upload_dict.get('image'): - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: + if file_upload_dict.get("image"): + if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: image_config = { - 'number_limits': file_upload_dict['image']['number_limits'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] + "number_limits": file_upload_dict["image"]["number_limits"], + "transfer_methods": file_upload_dict["image"]["transfer_methods"], } if is_vision: - image_config['detail'] = file_upload_dict['image']['detail'] + image_config["detail"] = file_upload_dict["image"]["detail"] - return FileExtraConfig( - image_config=image_config - ) + return FileExtraConfig(image_config=image_config) return None @@ -49,21 +47,21 @@ class FileUploadConfigManager: if not config["file_upload"].get("image"): config["file_upload"]["image"] = {"enabled": False} - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] + if config["file_upload"]["image"]["enabled"]: + number_limits = config["file_upload"]["image"]["number_limits"] if number_limits < 1 or number_limits > 6: raise ValueError("number_limits must be in [1, 6]") if is_vision: - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: + detail = config["file_upload"]["image"]["detail"] + if detail not in {"high", "low"}: raise ValueError("detail must be in ['high', 'low']") - transfer_methods = config['file_upload']['image']['transfer_methods'] + transfer_methods = config["file_upload"]["image"]["transfer_methods"] if not isinstance(transfer_methods, list): raise ValueError("transfer_methods must be of list type") for method in transfer_methods: - if method not in ['remote_url', 'local_file']: + if method not in {"remote_url", "local_file"}: raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") return config, ["file_upload"] diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py index 2ba99a5c40..496e1beeec 100644 --- a/api/core/app/app_config/features/more_like_this/manager.py +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -7,9 +7,9 @@ class MoreLikeThisConfigManager: :param config: model config args """ more_like_this = False - more_like_this_dict = config.get('more_like_this') + more_like_this_dict = config.get("more_like_this") if more_like_this_dict: - if more_like_this_dict.get('enabled'): + if more_like_this_dict.get("enabled"): more_like_this = True return more_like_this @@ -22,9 +22,7 @@ class MoreLikeThisConfigManager: :param config: app model config args """ if not config.get("more_like_this"): - config["more_like_this"] = { - "enabled": False - } + config["more_like_this"] = {"enabled": False} if not isinstance(config["more_like_this"], dict): raise ValueError("more_like_this must be of dict type") diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 0d8a71bfcf..b4dacbc409 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,5 +1,3 @@ - - class OpeningStatementConfigManager: @classmethod def convert(cls, config: dict) -> tuple[str, list]: @@ -9,10 +7,10 @@ class OpeningStatementConfigManager: :param config: model config args """ # opening statement - opening_statement = config.get('opening_statement') + opening_statement = config.get("opening_statement") # suggested questions - suggested_questions_list = config.get('suggested_questions') + suggested_questions_list = config.get("suggested_questions") return opening_statement, suggested_questions_list diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py index fca58e12e8..d098abac2f 100644 --- a/api/core/app/app_config/features/retrieval_resource/manager.py +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -2,9 +2,9 @@ class RetrievalResourceConfigManager: @classmethod def convert(cls, config: dict) -> bool: show_retrieve_source = False - retriever_resource_dict = config.get('retriever_resource') + retriever_resource_dict = config.get("retriever_resource") if retriever_resource_dict: - if retriever_resource_dict.get('enabled'): + if retriever_resource_dict.get("enabled"): show_retrieve_source = True return show_retrieve_source @@ -17,9 +17,7 @@ class RetrievalResourceConfigManager: :param config: app model config args """ if not config.get("retriever_resource"): - config["retriever_resource"] = { - "enabled": False - } + config["retriever_resource"] = {"enabled": False} if not isinstance(config["retriever_resource"], dict): raise ValueError("retriever_resource must be of dict type") diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py index 88b4be25d3..e10ae03e04 100644 --- a/api/core/app/app_config/features/speech_to_text/manager.py +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -7,9 +7,9 @@ class SpeechToTextConfigManager: :param config: model config args """ speech_to_text = False - speech_to_text_dict = config.get('speech_to_text') + speech_to_text_dict = config.get("speech_to_text") if speech_to_text_dict: - if speech_to_text_dict.get('enabled'): + if speech_to_text_dict.get("enabled"): speech_to_text = True return speech_to_text @@ -22,9 +22,7 @@ class SpeechToTextConfigManager: :param config: app model config args """ if not config.get("speech_to_text"): - config["speech_to_text"] = { - "enabled": False - } + config["speech_to_text"] = {"enabled": False} if not isinstance(config["speech_to_text"], dict): raise ValueError("speech_to_text must be of dict type") diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py index c6cab01220..9ac5114d12 100644 --- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager: :param config: model config args """ suggested_questions_after_answer = False - suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer') + suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer") if suggested_questions_after_answer_dict: - if suggested_questions_after_answer_dict.get('enabled'): + if suggested_questions_after_answer_dict.get("enabled"): suggested_questions_after_answer = True return suggested_questions_after_answer @@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager: :param config: app model config args """ if not config.get("suggested_questions_after_answer"): - config["suggested_questions_after_answer"] = { - "enabled": False - } + config["suggested_questions_after_answer"] = {"enabled": False} if not isinstance(config["suggested_questions_after_answer"], dict): raise ValueError("suggested_questions_after_answer must be of dict type") - if "enabled" not in config["suggested_questions_after_answer"] or not \ - config["suggested_questions_after_answer"]["enabled"]: + if ( + "enabled" not in config["suggested_questions_after_answer"] + or not config["suggested_questions_after_answer"]["enabled"] + ): config["suggested_questions_after_answer"]["enabled"] = False if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index f11e268e73..1c75981785 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -10,13 +10,13 @@ class TextToSpeechConfigManager: :param config: model config args """ text_to_speech = None - text_to_speech_dict = config.get('text_to_speech') + text_to_speech_dict = config.get("text_to_speech") if text_to_speech_dict: - if text_to_speech_dict.get('enabled'): + if text_to_speech_dict.get("enabled"): text_to_speech = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), + enabled=text_to_speech_dict.get("enabled"), + voice=text_to_speech_dict.get("voice"), + language=text_to_speech_dict.get("language"), ) return text_to_speech @@ -29,11 +29,7 @@ class TextToSpeechConfigManager: :param config: app model config args """ if not config.get("text_to_speech"): - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } + config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""} if not isinstance(config["text_to_speech"], dict): raise ValueError("text_to_speech must be of dict type") diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index c3d0e8ba03..b52f235849 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,4 +1,3 @@ - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): """ Advanced Chatbot App Config Entity. """ + pass class AdvancedChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - workflow: Workflow) -> AdvancedChatAppConfig: + def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: features_dict = workflow.features_dict app_mode = AppMode.value_of(app_model.mode) @@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) @@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): filtered_config = {key: config.get(key) for key in related_config_keys} return filtered_config - diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 9ae86a806f..f421ec2211 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom @@ -34,7 +34,8 @@ logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -54,7 +56,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -63,14 +66,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) -> Union[dict[str, Any], Generator[dict | str, None, None]]: ... def generate( - self, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: bool = True, - ) -> dict[str, Any] | Generator[str | dict, None, None]: + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True, + ) -> dict[str, Any] | Generator[str | dict, None, None]: """ Generate App response. @@ -81,44 +84,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', False) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} # get conversation conversation = None - conversation_id = args.get('conversation_id') + conversation_id = args.get("conversation_id") if conversation_id: - conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) + conversation = self._get_conversation_by_user( + app_model=app_model, conversation_id=conversation_id, user=user + ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -140,7 +136,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -150,16 +146,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, application_generate_entity=application_generate_entity, conversation=conversation, - stream=stream + stream=stream, ) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account | EndUser, - args: dict, - stream: bool = True) \ - -> dict[str, Any] | Generator[str, Any, None]: + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account | EndUser, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -171,16 +163,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( @@ -188,18 +177,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_config=app_config, conversation_id=None, inputs={}, - query='', + query="", files=[], user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras={ - "auto_generate_conversation_name": False - }, + extras={"auto_generate_conversation_name": False}, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -209,17 +195,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, conversation=None, - stream=stream + stream=stream, ) - def _generate(self, *, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Optional[Conversation] = None, - stream: bool = True) \ - -> dict[str, Any] | Generator[str, Any, None]: + def _generate( + self, + *, + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -235,10 +223,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): is_first_conversation = True # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features @@ -253,18 +238,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - 'context': contextvars.copy_context(), - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": contextvars.copy_context(), + }, + ) worker_thread.start() @@ -278,18 +266,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, stream=stream, ) - - return AdvancedChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str, - context: contextvars.Context) -> None: + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + context: contextvars.Context, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -312,22 +300,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) runner.run() - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG", "false").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -373,7 +360,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 0caff4a2e3..18b115dfe4 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -21,14 +21,11 @@ class AudioTrunk: self.status = status -def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( - content_text=text_content.strip(), - user="responding_tts", - tenant_id=tenant_id, - voice=voice + content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice ) @@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue): except Exception as e: logging.getLogger(__name__).warning(e) break - audio_queue.put(AudioTrunk("finish", b'')) + audio_queue.put(AudioTrunk("finish", b"")) class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id - self.msg_text = '' + self.msg_text = "" self._audio_queue = queue.Queue() self._msg_queue = queue.Queue() - self.match = re.compile(r'[。.!?]') + self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.TTS + tenant_id=self.tenant_id, model_type=ModelType.TTS ) self.voices = self.model_instance.get_tts_voices() - values = [voice.get('value') for voice in self.voices] + values = [voice.get("value") for voice in self.voices] self.voice = voice if not voice or voice not in values: - self.voice = self.voices[0].get('value') + self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 self._last_audio_event = None self._runtime_thread = threading.Thread(target=self._runtime).start() @@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher: message = self._msg_queue.get() if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: - futures_result = self.executor.submit(_invoiceTTS, self.msg_text, - self.model_instance, self.tenant_id, self.voice) + futures_result = self.executor.submit( + _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): @@ -94,28 +90,27 @@ class AppGeneratorTTSPublisher: elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): - self.msg_text += message.event.outputs.get('output', '') + self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): self.MAX_SENTENCE += 1 - text_content = ''.join(sentence_arr) - futures_result = self.executor.submit(_invoiceTTS, text_content, - self.model_instance, - self.tenant_id, - self.voice) + text_content = "".join(sentence_arr) + futures_result = self.executor.submit( + _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) if text_tmp: self.msg_text = text_tmp else: - self.msg_text = '' + self.msg_text = "" except Exception as e: self.logger.warning(e) break future_queue.put(None) - def checkAndGetAudio(self) -> AudioTrunk | None: + def check_and_get_audio(self) -> AudioTrunk | None: try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 4da3d093d2..1bca1e1b71 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -19,7 +19,7 @@ from core.app.entities.queue_entities import ( QueueStopEvent, QueueTextChunkEvent, ) -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool @@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ def __init__( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, ) -> None: """ :param application_generate_entity: application generate entity @@ -66,14 +66,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id = self.application_generate_entity.user_id workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): workflow_callbacks.append(WorkflowLoggingCallback()) if self.application_generate_entity.single_iteration_run: @@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs + user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) else: inputs = self.application_generate_entity.inputs @@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # moderation if self.handle_input_moderation( - app_record=app_record, - app_generate_entity=self.application_generate_entity, - inputs=inputs, - query=query, - message_id=self.message.id + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id, ): return # annotation reply if self.handle_annotation_reply( - app_record=app_record, - message=self.message, - query=query, - app_generate_entity=self.application_generate_entity + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity, ): return # Init conversation variables stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: conversation_variables = session.scalars(stmt).all() @@ -174,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, @@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._handle_event(workflow_entry, event) def handle_input_moderation( - self, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str + self, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> bool: """ Handle input moderation @@ -216,19 +217,15 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query=query, message_id=message_id, ) - except ModerationException as e: - self._complete_with_stream_output( - text=str(e), - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION - ) + except ModerationError as e: + self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) return True return False - def handle_annotation_reply(self, app_record: App, - message: Message, - query: str, - app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + def handle_annotation_reply( + self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity + ) -> bool: """ Handle annotation reply :param app_record: app record @@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) if annotation_reply: - self._publish_event( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id) - ) + self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)) self._complete_with_stream_output( - text=annotation_reply.content, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _complete_with_stream_output(self, - text: str, - stopped_by: QueueStopEvent.StopBy) -> None: + def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output :param text: text :return: """ - self._publish_event( - QueueTextChunkEvent( - text=text - ) - ) + self._publish_event(QueueTextChunkEvent(text=text)) - self._publish_event( - QueueStopEvent(stopped_by=stopped_by) - ) + self._publish_event(QueueStopEvent(stopped_by=stopped_by)) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 2ddbd816e2..b2bff43208 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -27,15 +27,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -49,13 +49,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[dict | str, Any, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -66,14 +68,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -84,7 +86,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield response_chunk @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[dict | str, Any, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, Any, None]: """ Convert stream simple response. :param stream_response: stream response @@ -95,20 +99,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index fb013cd1b1..94206a1b1c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow @@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow_system_variables: dict[SystemVariableKey, Any] def __init__( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) @@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc elif isinstance(stream_response, MessageEndStreamResponse): extras = {} if stream_response.metadata: - extras['metadata'] = stream_response.metadata + extras["metadata"] = stream_response.metadata return ChatbotAppBlockingResponse( task_id=stream_response.task_id, @@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc message_id=self._message.id, answer=self._task_state.answer, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[ChatbotAppStreamResponse, Any, None]: """ To stream response. :return: @@ -176,32 +176,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -214,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc try: if not tts_publisher: break - audio_trunk = tts_publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc db.session.close() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start( - workflow_run=workflow_run, - event=event - ) + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, @@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE - ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, @@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) yield self._error_to_stream_response(self._handle_error(err_event, self._message)) break elif isinstance(event, QueueStopEvent): @@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) # Save message @@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._refetch_message() - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() db.session.refresh(self._message) @@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._refetch_message() - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() db.session.refresh(self._message) @@ -466,13 +451,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response( + answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) if output_moderation_answer: @@ -502,8 +489,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage @@ -523,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -533,15 +521,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata.copy() + extras["metadata"] = self._task_state.metadata.copy() - if 'annotation_reply' in extras['metadata']: - del extras['metadata']['annotation_reply'] + if "annotation_reply" in extras["metadata"]: + del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _handle_output_moderation_chunk(self, text: str) -> bool: @@ -555,14 +541,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # stop subscribe new token when output moderation should direct output self._task_state.answer = self._output_moderation_handler.get_final_output() self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE + QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index f495ebbf35..9040f18bfd 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): """ Agent Chatbot App Config Entity. """ + agent: Optional[AgentEntity] = None class AgentChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> AgentChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config :param app_model: app model @@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - agent=AgentConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + agent=AgentConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # dataset configs # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) @@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): :param config: app model config args """ if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): if not config["agent_mode"].get("strategy"): config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [member.value for member in - list(PlanningStrategy.__members__.values())]: + if config["agent_mode"]["strategy"] not in [ + member.value for member in list(PlanningStrategy.__members__.values()) + ]: raise ValueError("strategy in agent_mode must be in the specified strategy list") if not config["agent_mode"].get("tools"): @@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): raise ValueError("enabled in agent_mode.tools must be of boolean type") if key == "dataset": - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 726e7ca65c..578eff5053 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom @@ -30,7 +30,8 @@ logger = logging.getLogger(__name__) class AgentChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -39,7 +40,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -48,19 +50,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: bool = False, ) -> dict | Generator[dict | str, None, None]: ... - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict | str, None, None]]: + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[dict | str, None, None]]: """ Generate App response. @@ -71,60 +71,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): :param stream: is stream """ if not stream: - raise ValueError('Agent Chat App does not support blocking mode') + raise ValueError("Agent Chat App does not support blocking mode") - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = AgentChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -133,7 +121,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -154,14 +142,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, extras=extras, call_depth=0, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -170,17 +155,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -194,13 +182,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return AgentChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( - self, flask_app: Flask, + self, + flask_app: Flask, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, @@ -229,18 +215,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index d1bbf679c5..45b1bf0093 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -15,7 +15,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.tools.entities.tool_entities import ToolRuntimeVariablePool from extensions.ext_database import db from models.model import App, Conversation, Message, MessageAgentThought @@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner): """ def run( - self, application_generate_entity: AgentChatAppGenerateEntity, + self, + application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner): # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -103,15 +101,15 @@ class AgentChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner): message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # reorganize all inputs and template to prompt messages @@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: @@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner): agent_entity = app_config.agent # load tool variables - tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, - user_id=application_generate_entity.user_id, - tenant_id=app_config.tenant_id) + tool_conversation_variables = self._load_tool_variables( + conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id + ) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) @@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner): # init model instance model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, @@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner): prompt_messages=prompt_message, variables_pool=tool_variables, db_variables=tool_conversation_variables, - model_instance=model_instance + model_instance=model_instance, ) invoke_result = runner.run( @@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner): invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream, - agent=True + agent=True, ) def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: """ load tool variables from database """ - tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == conversation_id, - ToolConversationVariables.tenant_id == tenant_id - ).first() + tool_variables: ToolConversationVariables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == conversation_id, + ToolConversationVariables.tenant_id == tenant_id, + ) + .first() + ) if tool_variables: # save tool variables to session, so that we can update it later @@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner): conversation_id=conversation_id, user_id=user_id, tenant_id=tenant_id, - variables_str='[]', + variables_str="[]", ) db.session.add(tool_variables) db.session.commit() return tool_variables - - def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool: + + def _convert_db_variables_to_tool_variables( + self, db_variables: ToolConversationVariables + ) -> ToolRuntimeVariablePool: """ convert db variables to tool variables """ - return ToolRuntimeVariablePool(**{ - 'conversation_id': db_variables.conversation_id, - 'user_id': db_variables.user_id, - 'tenant_id': db_variables.tenant_id, - 'pool': db_variables.variables - }) + return ToolRuntimeVariablePool( + **{ + "conversation_id": db_variables.conversation_id, + "user_id": db_variables.user_id, + "tenant_id": db_variables.tenant_id, + "pool": db_variables.variables, + } + ) - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, - message: Message) -> LLMUsage: + def _get_usage_of_all_agent_thoughts( + self, model_config: ModelConfigWithCredentialsEntity, message: Message + ) -> LLMUsage: """ Get usage of all agent thoughts :param model_config: model config :param message: message :return: """ - agent_thoughts = (db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == message.id).all()) + agent_thoughts = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() + ) all_message_tokens = 0 all_answer_tokens = 0 @@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner): model_type_instance = cast(LargeLanguageModel, model_type_instance) return model_type_instance._calc_response_usage( - model_config.model, - model_config.credentials, - all_message_tokens, - all_answer_tokens + model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens ) diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 02aec27e39..5f294432c9 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -22,15 +22,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -44,8 +44,8 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @@ -62,14 +62,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -92,20 +92,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 77f2ed2eaf..1167b46c83 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -13,11 +13,10 @@ class AppGenerateResponseConverter(ABC): _blocking_response_type: type[AppBlockingResponse] @classmethod - def convert(cls, response: Union[ - AppBlockingResponse, - Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: - if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + def convert( + cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom + ) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]: + if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: @@ -52,8 +51,9 @@ class AppGenerateResponseConverter(ABC): @classmethod @abstractmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @@ -64,24 +64,26 @@ class AppGenerateResponseConverter(ABC): :return: """ # show_retrieve_source - if 'retriever_resources' in metadata: - metadata['retriever_resources'] = [] - for resource in metadata['retriever_resources']: - metadata['retriever_resources'].append({ - 'segment_id': resource['segment_id'], - 'position': resource['position'], - 'document_name': resource['document_name'], - 'score': resource['score'], - 'content': resource['content'], - }) + if "retriever_resources" in metadata: + metadata["retriever_resources"] = [] + for resource in metadata["retriever_resources"]: + metadata["retriever_resources"].append( + { + "segment_id": resource["segment_id"], + "position": resource["position"], + "document_name": resource["document_name"], + "score": resource["score"], + "content": resource["content"], + } + ) # show annotation reply - if 'annotation_reply' in metadata: - del metadata['annotation_reply'] + if "annotation_reply" in metadata: + del metadata["annotation_reply"] # show usage - if 'usage' in metadata: - del metadata['usage'] + if "usage" in metadata: + del metadata["usage"] return metadata @@ -93,16 +95,16 @@ class AppGenerateResponseConverter(ABC): :return: """ error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + ValueError: {"code": "invalid_param", "status": 400}, + ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 + "code": "provider_quota_exceeded", + "message": "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + "status": 400, }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} + ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400}, + InvokeError: {"code": "completion_request_error", "status": 400}, } # Determine the response based on the type of exception @@ -112,13 +114,13 @@ class AppGenerateResponseConverter(ABC): data = v if data: - data.setdefault('message', getattr(e, 'description', str(e))) + data.setdefault("message", getattr(e, "description", str(e))) else: logging.error(e) data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 + "code": "internal_server_error", + "message": "Internal Server Error, please contact support.", + "status": 500, } return data diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 4159670de4..2da0fd0584 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -17,17 +17,17 @@ class BaseAppGenerator: def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.variable} is required in input form') + raise ValueError(f"{var.variable} is required in input form") if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? - return var.default or '' + return var.default or "" if ( var.type - in ( + in { VariableEntityType.TEXT_INPUT, VariableEntityType.SELECT, VariableEntityType.PARAGRAPH, - ) + } and user_input_value and not isinstance(user_input_value, str) ): @@ -35,7 +35,7 @@ class BaseAppGenerator: if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: - if '.' in user_input_value: + if "." in user_input_value: return float(user_input_value) else: return int(user_input_value) @@ -44,20 +44,20 @@ class BaseAppGenerator: if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.variable} in input form must be one of the following: {options}') - elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + raise ValueError(f"{var.variable} in input form must be one of the following: {options}") + elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') + raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") return user_input_value def _sanitize_value(self, value: Any) -> Any: if isinstance(value, str): - return value.replace('\x00', '') + return value.replace("\x00", "") return value @classmethod - def convert_to_event_stream(cls, generator: Union[dict, Generator[dict| str, None, None]]): + def convert_to_event_stream(cls, generator: Union[dict, Generator[dict | str, None, None]]): """ Convert messages into event stream """ diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index b45f57e9b6..7cb3387876 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -24,9 +24,7 @@ class PublishFrom(Enum): class AppQueueManager: - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: if not user_id: raise ValueError("user is required") @@ -34,9 +32,10 @@ class AppQueueManager: self._user_id = user_id self._invoke_from = invoke_from - user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, - f"{user_prefix}-{self._user_id}") + user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" + redis_client.setex( + AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" + ) q = queue.Queue() @@ -66,8 +65,7 @@ class AppQueueManager: # publish two messages to make sure the client can receive the stop signal # and stop listening after the stop signal processed self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE ) if elapsed_time // 10 > last_ping_time: @@ -88,9 +86,7 @@ class AppQueueManager: :param pub_from: publish from :return: """ - self.publish(QueueErrorEvent( - error=e - ), pub_from) + self.publish(QueueErrorEvent(error=e), pub_from) def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ @@ -122,8 +118,8 @@ class AppQueueManager: if result is None: return - user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - if result.decode('utf-8') != f"{user_prefix}-{user_id}": + user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" + if result.decode("utf-8") != f"{user_prefix}-{user_id}": return stopped_cache_key = cls._generate_stopped_cache_key(task_id) @@ -168,10 +164,12 @@ class AppQueueManager: for item in data: self._check_for_sqlalchemy_models(item) else: - if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): - raise TypeError("Critical Error: Passing SQLAlchemy Model instances " - "that cause thread safety issues is not allowed.") + if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"): + raise TypeError( + "Critical Error: Passing SQLAlchemy Model instances " + "that cause thread safety issues is not allowed." + ) -class GenerateTaskStoppedException(Exception): +class GenerateTaskStoppedError(Exception): pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 60216959a8..1b412b8639 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -31,12 +31,15 @@ if TYPE_CHECKING: class AppRunner: - def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None) -> int: + def get_pre_calculate_rest_tokens( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + ) -> int: """ Get pre calculate rest tokens :param app_record: app record @@ -49,18 +52,20 @@ class AppRunner: """ # Invoke model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -75,36 +80,39 @@ class AppRunner: prompt_template_entity=prompt_template_entity, inputs=inputs, files=files, - query=query + query=query, ) - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) rest_tokens = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: - raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size.") + raise InvokeBadRequestError( + "Query or prefix prompt is too long, you can reduce the prefix prompt, " + "or shrink the max token, or switch to a llm with a larger token limit size." + ) return rest_tokens - def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage]): + def recalc_llm_max_tokens( + self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] + ): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -112,27 +120,28 @@ class AppRunner: if max_tokens is None: max_tokens = 0 - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) if prompt_tokens + max_tokens > model_context_tokens: max_tokens = max(model_context_tokens - prompt_tokens, 16) for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): model_config.parameters[parameter_rule.name] = max_tokens - def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def organize_prompt_messages( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + context: Optional[str] = None, + memory: Optional[TokenBufferMemory] = None, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages :param context: @@ -152,60 +161,54 @@ class AppRunner: app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, - query=query if query else '', + query=query or "", files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template - prompt_template = CompletionModelPromptTemplate( - text=advanced_completion_prompt_template.prompt - ) + prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: memory_config.role_prefix = MemoryConfig.RolePrefix( user=advanced_completion_prompt_template.role_prefix.user, - assistant=advanced_completion_prompt_template.role_prefix.assistant + assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: - prompt_template.append(ChatModelMessage( - text=message.text, - role=message.role - )) + prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs=inputs, - query=query if query else '', + query=query or "", files=files, context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def direct_output(self, queue_manager: AppQueueManager, - app_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list, - text: str, - stream: bool, - usage: Optional[LLMUsage] = None) -> None: + def direct_output( + self, + queue_manager: AppQueueManager, + app_generate_entity: EasyUIBasedAppGenerateEntity, + prompt_messages: list, + text: str, + stream: bool, + usage: Optional[LLMUsage] = None, + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -222,17 +225,10 @@ class AppRunner: chunk = LLMResultChunk( model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=token) - ) + delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)), ) - queue_manager.publish( - QueueLLMChunkEvent( - chunk=chunk - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER) index += 1 time.sleep(0.01) @@ -242,15 +238,19 @@ class AppRunner: model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() + usage=usage or LLMUsage.empty_usage(), ), - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: AppQueueManager, - stream: bool, - agent: bool = False) -> None: + def _handle_invoke_result( + self, + invoke_result: Union[LLMResult, Generator], + queue_manager: AppQueueManager, + stream: bool, + agent: bool = False, + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -260,21 +260,13 @@ class AppRunner: :return: """ if not stream: - self._handle_invoke_result_direct( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) else: - self._handle_invoke_result_stream( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_direct( + self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result direct :param invoke_result: invoke result @@ -285,12 +277,13 @@ class AppRunner: queue_manager.publish( QueueMessageEndEvent( llm_result=invoke_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_stream( + self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -300,21 +293,13 @@ class AppRunner: """ model = None prompt_messages = [] - text = '' + text = "" usage = None for result in invoke_result: if not agent: - queue_manager.publish( - QueueLLMChunkEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) else: - queue_manager.publish( - QueueAgentMessageEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) text += result.delta.message.content @@ -331,25 +316,24 @@ class AppRunner: usage = LLMUsage.empty_usage() llm_result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) queue_manager.publish( QueueMessageEndEvent( llm_result=llm_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) def moderation_for_inputs( - self, app_id: str, - tenant_id: str, - app_generate_entity: AppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str, + self, + app_id: str, + tenant_id: str, + app_generate_entity: AppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. @@ -367,14 +351,17 @@ class AppRunner: tenant_id=tenant_id, app_config=app_generate_entity.app_config, inputs=inputs, - query=query if query else '', + query=query or "", message_id=message_id, - trace_manager=app_generate_entity.trace_manager + trace_manager=app_generate_entity.trace_manager, ) - def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - prompt_messages: list[PromptMessage]) -> bool: + def check_hosting_moderation( + self, + application_generate_entity: EasyUIBasedAppGenerateEntity, + queue_manager: AppQueueManager, + prompt_messages: list[PromptMessage], + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -384,8 +371,7 @@ class AppRunner: """ hosting_moderation_feature = HostingModerationFeature() moderation_result = hosting_moderation_feature.check( - application_generate_entity=application_generate_entity, - prompt_messages=prompt_messages + application_generate_entity=application_generate_entity, prompt_messages=prompt_messages ) if moderation_result: @@ -393,18 +379,20 @@ class AppRunner: queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, - text="I apologize for any confusion, " \ - "but I'm an AI assistant to be helpful, harmless, and honest.", - stream=application_generate_entity.stream + text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", + stream=application_generate_entity.stream, ) return moderation_result - def fill_in_inputs_from_external_data_tools(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fill_in_inputs_from_external_data_tools( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -417,18 +405,12 @@ class AppRunner: """ external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( - tenant_id=tenant_id, - app_id=app_id, - external_data_tools=external_data_tools, - inputs=inputs, - query=query + tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query ) - def query_app_annotations_to_reply(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query_app_annotations_to_reply( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -440,9 +422,5 @@ class AppRunner: """ annotation_reply_feature = AnnotationReplyFeature() return annotation_reply_feature.query( - app_record=app_record, - message=message, - query=query, - user_id=user_id, - invoke_from=invoke_from + app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from ) diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index a286c349b2..96dc7dda79 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig): """ Chatbot App Config Entity. """ + pass class ChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> ChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> ChatAppConfig: """ Convert app model config to chat app config :param app_model: app model @@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager): config_dict = app_model_config_dict.copy() else: if not override_config_dict: - raise Exception('override_config_dict is required when config_from is ARGS') + raise Exception("override_config_dict is required when config_from is ARGS") config_dict = override_config_dict @@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # opening_statement @@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index b784f42e7e..66bc673ab1 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter @@ -30,7 +30,8 @@ logger = logging.getLogger(__name__) class ChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -56,7 +58,8 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) -> Union[dict, Generator[dict | str, None, None]]: ... def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -71,58 +74,46 @@ class ChatAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -131,7 +122,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -150,14 +141,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -166,17 +154,20 @@ class ChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -190,16 +181,16 @@ class ChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return ChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -221,20 +212,19 @@ class ChatAppGenerator(MessageBasedAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 89a498eb36..425f1ab7ef 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Conversation, Message @@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner): # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -96,15 +96,15 @@ class ChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner): message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner): app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_retrieval = DatasetRetrieval(application_generate_entity) @@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner): files=files, query=query, context=context, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 0ae9926bb8..c7e29686e9 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -22,15 +22,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -44,8 +44,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @@ -62,14 +62,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -92,20 +92,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index a771198324..1193c4b7a4 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig): """ Completion App Config Entity. """ + pass class CompletionAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - override_config_dict: Optional[dict] = None) -> CompletionAppConfig: + def get_app_config( + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + ) -> CompletionAppConfig: """ Convert app model config to completion app config :param app_model: app model @@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # text_to_speech @@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 3ce4d3ccaa..2e99e4ef70 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter @@ -32,7 +32,8 @@ logger = logging.getLogger(__name__) class CompletionAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -41,7 +42,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -72,12 +74,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] extras = {} @@ -85,41 +87,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator): conversation = None # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # get tracing instance @@ -137,14 +129,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -153,16 +142,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -176,15 +168,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -203,20 +195,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - message=message + message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -225,12 +216,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator): finally: db.session.close() - def generate_more_like_this(self, app_model: App, - message_id: str, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[str, None, None]]: + def generate_more_like_this( + self, + app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True, + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -240,13 +233,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -259,29 +256,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator): app_model_config = message.app_model_config override_model_config_dict = app_model_config.to_dict() - model_dict = override_model_config_dict['model'] - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - override_model_config_dict['model'] = model_dict + model_dict = override_model_config_dict["model"] + completion_params = model_dict.get("completion_params") + completion_params["temperature"] = 0.9 + model_dict["completion_params"] = completion_params + override_model_config_dict["model"] = model_dict # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - message.files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # init application generate entity @@ -295,14 +286,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=stream, invoke_from=invoke_from, - extras={} + extras={}, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -311,16 +299,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -334,7 +325,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index f0e5f9ae17..908d74ff53 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Message @@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message: Message) -> None: + def run( + self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # organize all inputs and template to prompt messages @@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # moderation @@ -77,15 +77,15 @@ class CompletionAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner): app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_config = app_config.dataset @@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner): invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, - message_id=message.id + message_id=message.id, ) # reorganize all inputs and template to prompt messages @@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner): inputs=inputs, files=files, query=query, - context=context + context=context, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) - \ No newline at end of file diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 61bb03952f..77aa1e37a2 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -22,14 +22,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -43,8 +43,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @@ -61,13 +61,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -90,19 +90,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index fceed95b91..c4db95cbd0 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -8,7 +8,7 @@ from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -35,23 +35,23 @@ logger = logging.getLogger(__name__) class MessageBasedAppGenerator(BaseAppGenerator): - def _handle_response( - self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False, + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Handle response. @@ -70,24 +70,25 @@ class MessageBasedAppGenerator(BaseAppGenerator): conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e - def _get_conversation_by_user(self, app_model: App, conversation_id: str, - user: Union[Account, EndUser]) -> Conversation: + def _get_conversation_by_user( + self, app_model: App, conversation_id: str, user: Union[Account, EndUser] + ) -> Conversation: conversation_filter = [ Conversation.id == conversation_id, Conversation.app_id == app_model.id, - Conversation.status == 'normal' + Conversation.status == "normal", ] if isinstance(user, Account): @@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator): if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() return conversation - def _get_app_model_config(self, app_model: App, - conversation: Optional[Conversation] = None) \ - -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) + .first() + ) if not app_model_config: raise AppModelConfigBrokenError() @@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator): return app_model_config - def _init_generate_records(self, - application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - conversation: Optional[Conversation] = None) \ - -> tuple[Conversation, Message]: + def _init_generate_records( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + conversation: Optional[Conversation] = None, + ) -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity @@ -147,11 +148,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): # get from source end_user_id = None account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' + if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + from_source = "api" end_user_id = application_generate_entity.user_id else: - from_source = 'console' + from_source = "console" account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ - and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in { + AppMode.AGENT_CHAT, + AppMode.CHAT, + AppMode.COMPLETION, + }: override_model_configs = app_config.app_model_config_dict # get conversation introduction @@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_config.app_mode.value, - name='New conversation', + name="New conversation", inputs=application_generate_entity.inputs, introduction=introduction, system_instruction="", system_instruction_tokens=0, - status='normal', + status="normal", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, @@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, - from_account_id=account_id + from_account_id=account_id, ) db.session.add(message) @@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): message_id=message.id, type=file.type.value, transfer_method=file.transfer_method.value, - belongs_to='user', + belongs_to="user", url=file.url, upload_file_id=file.related_id, - created_by_role=('account' if account_id else 'end_user'), + created_by_role=("account" if account_id else "end_user"), created_by=account_id or end_user_id, ) db.session.add(message_file) @@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param conversation_id: conversation id :return: conversation """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: raise ConversationNotExistsError() @@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param message_id: message id :return: message """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) + message = db.session.query(Message).filter(Message.id == message_id).first() return message diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index f4ff44ddda..363c3c82bb 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -12,12 +12,9 @@ from core.app.entities.queue_entities import ( class MessageBasedAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: + def __init__( + self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str + ) -> None: super().__init__(task_id, user_id, invoke_from) self._conversation_id = str(conversation_id) @@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager): message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: @@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager): message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueAdvancedChatMessageEndEvent): + if isinstance( + event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() - + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 36d3696d60..8b98e74b85 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): """ Workflow App Config Entity. """ + pass @@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager): app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 7e441b386d..f11fedebc8 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -12,7 +12,7 @@ from pydantic import ValidationError import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -34,7 +34,8 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -46,14 +47,15 @@ class WorkflowAppGenerator(BaseAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ) -> dict: ... @overload @@ -76,7 +78,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ): """ Generate App response. @@ -90,26 +92,19 @@ class WorkflowAppGenerator(BaseAppGenerator): :param call_depth: call depth :param workflow_thread_pool_id: workflow thread pool id """ - inputs = args['inputs'] + inputs = args["inputs"] # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -125,7 +120,7 @@ class WorkflowAppGenerator(BaseAppGenerator): stream=stream, invoke_from=invoke_from, call_depth=call_depth, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -136,11 +131,12 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, - workflow_thread_pool_id=workflow_thread_pool_id + workflow_thread_pool_id=workflow_thread_pool_id, ) def _generate( - self, *, + self, + *, app_model: App, workflow: Workflow, user: Union[Account, EndUser], @@ -165,17 +161,20 @@ class WorkflowAppGenerator(BaseAppGenerator): task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode + app_mode=app_model.mode, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'context': contextvars.copy_context(), - 'workflow_thread_pool_id': workflow_thread_pool_id - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) worker_thread.start() @@ -188,10 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator): stream=stream, ) - return WorkflowAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate(self, app_model: App, workflow: Workflow, @@ -210,16 +206,13 @@ class WorkflowAppGenerator(BaseAppGenerator): :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( @@ -230,13 +223,10 @@ class WorkflowAppGenerator(BaseAppGenerator): user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras={ - "auto_generate_conversation_name": False - }, + extras={"auto_generate_conversation_name": False}, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -246,14 +236,17 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - stream=stream + stream=stream, ) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - context: contextvars.Context, - workflow_thread_pool_id: Optional[str] = None) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -270,22 +263,21 @@ class WorkflowAppGenerator(BaseAppGenerator): runner = WorkflowAppRunner( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - workflow_thread_pool_id=workflow_thread_pool_id + workflow_thread_pool_id=workflow_thread_pool_id, ) runner.run() - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -294,14 +286,14 @@ class WorkflowAppGenerator(BaseAppGenerator): finally: db.session.close() - def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool = False) -> Union[ - WorkflowAppBlockingResponse, - Generator[WorkflowAppStreamResponse, None, None] - ]: + def _handle_response( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Handle response. :param application_generate_entity: application generate entity @@ -317,14 +309,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream + stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index f448138b53..76371f800b 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -12,10 +12,7 @@ from core.app.entities.queue_entities import ( class WorkflowAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - app_mode: str) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: super().__init__(task_id, user_id, invoke_from) self._app_mode = app_mode @@ -27,20 +24,19 @@ class WorkflowAppQueueManager(AppQueueManager): :param pub_from: :return: """ - message = WorkflowQueueMessage( - task_id=self._task_id, - app_mode=self._app_mode, - event=event - ) + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent): + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent, + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 9d48db7546..22ec228fa7 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): """ def __init__( - self, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - workflow_thread_pool_id: Optional[str] = None + self, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, ) -> None: """ :param application_generate_entity: application generate entity @@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = cast(WorkflowAppConfig, app_config) user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") db.session.close() workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): workflow_callbacks.append(WorkflowLoggingCallback()) # if only single iteration run is requested @@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs + user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) else: - inputs = self.application_generate_entity.inputs files = self.application_generate_entity.files @@ -114,18 +113,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, - thread_pool_id=self.workflow_thread_pool_id + thread_pool_id=self.workflow_thread_pool_id, ) - generator = workflow_entry.run( - callbacks=workflow_callbacks - ) + generator = workflow_entry.run(callbacks=workflow_callbacks) for event in generator: self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 48f20d8dc1..989834ef3b 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -46,12 +46,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -74,12 +74,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 00b3b9f57e..93edf8e0e8 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _workflow: Workflow _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] - def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id + SystemVariableKey.USER_ID: user_id, } self._task_state = WorkflowTaskState() @@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa db.session.refresh(self._user) db.session.close() - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ - -> WorkflowAppBlockingResponse: + def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: """ To blocking response. :return: @@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa total_tokens=stream_response.data.total_tokens, total_steps=stream_response.data.total_steps, created_at=int(stream_response.data.created_at), - finished_at=int(stream_response.data.finished_at) - ) + finished_at=int(stream_response.data.finished_at), + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[WorkflowAppStreamResponse, None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[WorkflowAppStreamResponse, None, None]: """ To stream response. :return: @@ -158,34 +160,34 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(stream_response, WorkflowStartStreamResponse): workflow_run_id = stream_response.workflow_run_id - yield WorkflowAppStreamResponse( - workflow_run_id=workflow_run_id, - stream_response=stream_response - ) + yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -197,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa try: if not tts_publisher: break - audio_trunk = tts_publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) - + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # init workflow run workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start( - workflow_run=workflow_run, - event=event - ) + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, + outputs=json.dumps(event.outputs) + if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs + else None, conversation_id=None, trace_manager=trace_manager, ) @@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), conversation_id=None, trace_manager=trace_manager, @@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueTextChunkEvent): delta_text = event.text @@ -387,14 +376,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text - yield self._text_chunk_to_stream_response(delta_text) + yield self._text_chunk_to_stream_response( + delta_text, from_variable_selector=event.from_variable_selector + ) else: continue if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ Save workflow app log. @@ -417,14 +407,16 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' + workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" workflow_app_log.created_by = self._user.id db.session.add(workflow_app_log) db.session.commit() db.session.close() - def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: + def _text_chunk_to_stream_response( + self, text: str, from_variable_selector: Optional[list[str]] = None + ) -> TextChunkStreamResponse: """ Handle completed event. :param text: text @@ -432,7 +424,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ response = TextChunkStreamResponse( task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text) + data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), ) return response diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 1709726887..ce266116a7 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner): """ Init graph """ - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") # init graph - graph = Graph.init( - graph_config=graph_config - ) + graph = Graph.init(graph_config=graph_config) if not graph: - raise ValueError('graph not found in workflow') - + raise ValueError("graph not found in workflow") + return graph def _get_graph_and_variable_pool_of_single_iteration( - self, - workflow: Workflow, - node_id: str, - user_inputs: dict, - ) -> tuple[Graph, VariablePool]: + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: """ Get variable pool of single iteration """ # fetch workflow graph graph_config = workflow.graph_dict if not graph_config: - raise ValueError('workflow graph not found') - + raise ValueError("workflow graph not found") + graph_config = cast(dict[str, Any], graph_config) - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") # filter nodes only in iteration node_configs = [ - node for node in graph_config.get('nodes', []) - if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id + node + for node in graph_config.get("nodes", []) + if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id ] - graph_config['nodes'] = node_configs + graph_config["nodes"] = node_configs - node_ids = [node.get('id') for node in node_configs] + node_ids = [node.get("id") for node in node_configs] # filter edges only in iteration edge_configs = [ - edge for edge in graph_config.get('edges', []) - if (edge.get('source') is None or edge.get('source') in node_ids) - and (edge.get('target') is None or edge.get('target') in node_ids) + edge + for edge in graph_config.get("edges", []) + if (edge.get("source") is None or edge.get("source") in node_ids) + and (edge.get("target") is None or edge.get("target") in node_ids) ] - graph_config['edges'] = edge_configs + graph_config["edges"] = edge_configs # init graph - graph = Graph.init( - graph_config=graph_config, - root_node_id=node_id - ) + graph = Graph.init(graph_config=graph_config, root_node_id=node_id) if not graph: - raise ValueError('graph not found in workflow') - + raise ValueError("graph not found in workflow") + # fetch node config from node id iteration_node_config = None for node in node_configs: - if node.get('id') == node_id: + if node.get("id") == node_id: iteration_node_config = node break if not iteration_node_config: - raise ValueError('iteration node id not found in workflow graph') - + raise ValueError("iteration node id not found in workflow graph") + # Get node class - node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) node_cls = cast(type[BaseNode], node_cls) @@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner): try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=iteration_node_config + graph_config=workflow.graph_dict, config=iteration_node_config ) except NotImplementedError: variable_mapping = {} @@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner): variable_pool=variable_pool, tenant_id=workflow.tenant_id, node_type=node_type, - node_data=IterationNodeData(**iteration_node_config.get('data', {})) + node_data=IterationNodeData(**iteration_node_config.get("data", {})), ) return graph, variable_pool @@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner): """ if isinstance(event, GraphRunStartedEvent): self._publish_event( - QueueWorkflowStartedEvent( - graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state - ) + QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state) ) elif isinstance(event, GraphRunSucceededEvent): - self._publish_event( - QueueWorkflowSucceededEvent(outputs=event.outputs) - ) + self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) elif isinstance(event, GraphRunFailedEvent): - self._publish_event( - QueueWorkflowFailedEvent(error=event.error) - ) + self._publish_event(QueueWorkflowFailedEvent(error=event.error)) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( @@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner): start_at=event.route_node_state.start_at, node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunSucceededEvent): @@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner): parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result else {}, - in_iteration_id=event.in_iteration_id + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunFailedEvent): @@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner): parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, - error=event.route_node_state.node_run_result.error if event.route_node_state.node_run_result - and event.route_node_state.node_run_result.error + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error else "Unknown error", - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunStreamChunkEvent): @@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner): QueueTextChunkEvent( text=event.chunk_content, from_variable_selector=event.from_variable_selector, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): self._publish_event( QueueRetrieverResourcesEvent( - retriever_resources=event.retriever_resources, - in_iteration_id=event.in_iteration_id + retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id ) ) elif isinstance(event, ParallelBranchRunStartedEvent): @@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner): parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, ParallelBranchRunSucceededEvent): @@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner): parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, ParallelBranchRunFailedEvent): @@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner): parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, in_iteration_id=event.in_iteration_id, - error=event.error + error=event.error, ) ) elif isinstance(event, IterationRunStartedEvent): @@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner): node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, inputs=event.inputs, predecessor_node_id=event.predecessor_node_id, - metadata=event.metadata + metadata=event.metadata, ) ) elif isinstance(event, IterationRunNextEvent): @@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner): outputs=event.outputs, metadata=event.metadata, steps=event.steps, - error=event.error if isinstance(event, IterationRunFailedEvent) else None + error=event.error if isinstance(event, IterationRunFailedEvent) else None, ) ) @@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner): # return workflow return workflow - + def _publish_event(self, event: AppQueueEvent) -> None: - self.queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) + self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 4e8f3644b1..60683b0f21 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -30,169 +30,150 @@ _TEXT_COLOR_MAPPING = { class WorkflowLoggingCallback(WorkflowCallback): - def __init__(self) -> None: self.current_node_id = None - def on_event( - self, - event: GraphEngineEvent - ) -> None: + def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): - self.print_text("\n[GraphRunStartedEvent]", color='pink') + self.print_text("\n[GraphRunStartedEvent]", color="pink") elif isinstance(event, GraphRunSucceededEvent): - self.print_text("\n[GraphRunSucceededEvent]", color='green') + self.print_text("\n[GraphRunSucceededEvent]", color="green") elif isinstance(event, GraphRunFailedEvent): - self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") elif isinstance(event, NodeRunStartedEvent): - self.on_workflow_node_execute_started( - event=event - ) + self.on_workflow_node_execute_started(event=event) elif isinstance(event, NodeRunSucceededEvent): - self.on_workflow_node_execute_succeeded( - event=event - ) + self.on_workflow_node_execute_succeeded(event=event) elif isinstance(event, NodeRunFailedEvent): - self.on_workflow_node_execute_failed( - event=event - ) + self.on_workflow_node_execute_failed(event=event) elif isinstance(event, NodeRunStreamChunkEvent): - self.on_node_text_chunk( - event=event - ) + self.on_node_text_chunk(event=event) elif isinstance(event, ParallelBranchRunStartedEvent): - self.on_workflow_parallel_started( - event=event - ) + self.on_workflow_parallel_started(event=event) elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): - self.on_workflow_parallel_completed( - event=event - ) + self.on_workflow_parallel_completed(event=event) elif isinstance(event, IterationRunStartedEvent): - self.on_workflow_iteration_started( - event=event - ) + self.on_workflow_iteration_started(event=event) elif isinstance(event, IterationRunNextEvent): - self.on_workflow_iteration_next( - event=event - ) + self.on_workflow_iteration_next(event=event) elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): - self.on_workflow_iteration_completed( - event=event - ) + self.on_workflow_iteration_completed(event=event) else: - self.print_text(f"\n[{event.__class__.__name__}]", color='blue') + self.print_text(f"\n[{event.__class__.__name__}]", color="blue") - def on_workflow_node_execute_started( - self, - event: NodeRunStartedEvent - ) -> None: + def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None: """ Workflow node execute started """ - self.print_text("\n[NodeRunStartedEvent]", color='yellow') - self.print_text(f"Node ID: {event.node_id}", color='yellow') - self.print_text(f"Node Title: {event.node_data.title}", color='yellow') - self.print_text(f"Type: {event.node_type.value}", color='yellow') + self.print_text("\n[NodeRunStartedEvent]", color="yellow") + self.print_text(f"Node ID: {event.node_id}", color="yellow") + self.print_text(f"Node Title: {event.node_data.title}", color="yellow") + self.print_text(f"Type: {event.node_type.value}", color="yellow") - def on_workflow_node_execute_succeeded( - self, - event: NodeRunSucceededEvent - ) -> None: + def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None: """ Workflow node execute succeeded """ route_node_state = event.route_node_state - self.print_text("\n[NodeRunSucceededEvent]", color='green') - self.print_text(f"Node ID: {event.node_id}", color='green') - self.print_text(f"Node Title: {event.node_data.title}", color='green') - self.print_text(f"Type: {event.node_type.value}", color='green') + self.print_text("\n[NodeRunSucceededEvent]", color="green") + self.print_text(f"Node ID: {event.node_id}", color="green") + self.print_text(f"Node Title: {event.node_data.title}", color="green") + self.print_text(f"Type: {event.node_type.value}", color="green") if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result - self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", - color='green') self.print_text( - f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", - color='green') - self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", - color='green') + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color="green", + ) + self.print_text( + f"Process Data: " + f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="green", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="green", + ) self.print_text( f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", - color='green') + color="green", + ) - def on_workflow_node_execute_failed( - self, - event: NodeRunFailedEvent - ) -> None: + def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None: """ Workflow node execute failed """ route_node_state = event.route_node_state - self.print_text("\n[NodeRunFailedEvent]", color='red') - self.print_text(f"Node ID: {event.node_id}", color='red') - self.print_text(f"Node Title: {event.node_data.title}", color='red') - self.print_text(f"Type: {event.node_type.value}", color='red') + self.print_text("\n[NodeRunFailedEvent]", color="red") + self.print_text(f"Node ID: {event.node_id}", color="red") + self.print_text(f"Node Title: {event.node_data.title}", color="red") + self.print_text(f"Type: {event.node_type.value}", color="red") if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result - self.print_text(f"Error: {node_run_result.error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", - color='red') + self.print_text(f"Error: {node_run_result.error}", color="red") self.print_text( - f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", - color='red') - self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", - color='red') + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color="red", + ) + self.print_text( + f"Process Data: " + f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="red", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="red", + ) - def on_node_text_chunk( - self, - event: NodeRunStreamChunkEvent - ) -> None: + def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None: """ Publish text chunk """ route_node_state = event.route_node_state if not self.current_node_id or self.current_node_id != route_node_state.node_id: self.current_node_id = route_node_state.node_id - self.print_text('\n[NodeRunStreamChunkEvent]') + self.print_text("\n[NodeRunStreamChunkEvent]") self.print_text(f"Node ID: {route_node_state.node_id}") node_run_result = route_node_state.node_run_result if node_run_result: self.print_text( - f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}" + ) self.print_text(event.chunk_content, color="pink", end="") - def on_workflow_parallel_started( - self, - event: ParallelBranchRunStartedEvent - ) -> None: + def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None: """ Publish parallel started """ - self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue') - self.print_text(f"Parallel ID: {event.parallel_id}", color='blue') - self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue') + self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue") + self.print_text(f"Parallel ID: {event.parallel_id}", color="blue") + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue") if event.in_iteration_id: - self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue') + self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue") def on_workflow_parallel_completed( - self, - event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent ) -> None: """ Publish parallel completed """ if isinstance(event, ParallelBranchRunSucceededEvent): - color = 'blue' + color = "blue" elif isinstance(event, ParallelBranchRunFailedEvent): - color = 'red' + color = "red" - self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) + self.print_text( + "\n[ParallelBranchRunSucceededEvent]" + if isinstance(event, ParallelBranchRunSucceededEvent) + else "\n[ParallelBranchRunFailedEvent]", + color=color, + ) self.print_text(f"Parallel ID: {event.parallel_id}", color=color) self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) if event.in_iteration_id: @@ -201,43 +182,37 @@ class WorkflowLoggingCallback(WorkflowCallback): if isinstance(event, ParallelBranchRunFailedEvent): self.print_text(f"Error: {event.error}", color=color) - def on_workflow_iteration_started( - self, - event: IterationRunStartedEvent - ) -> None: + def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None: """ Publish iteration started """ - self.print_text("\n[IterationRunStartedEvent]", color='blue') - self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') + self.print_text("\n[IterationRunStartedEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") - def on_workflow_iteration_next( - self, - event: IterationRunNextEvent - ) -> None: + def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None: """ Publish iteration next """ - self.print_text("\n[IterationRunNextEvent]", color='blue') - self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') - self.print_text(f"Iteration Index: {event.index}", color='blue') + self.print_text("\n[IterationRunNextEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + self.print_text(f"Iteration Index: {event.index}", color="blue") - def on_workflow_iteration_completed( - self, - event: IterationRunSucceededEvent | IterationRunFailedEvent - ) -> None: + def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None: """ Publish iteration completed """ - self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') - self.print_text(f"Node ID: {event.iteration_id}", color='blue') + self.print_text( + "\n[IterationRunSucceededEvent]" + if isinstance(event, IterationRunSucceededEvent) + else "\n[IterationRunFailedEvent]", + color="blue", + ) + self.print_text(f"Node ID: {event.iteration_id}", color="blue") - def print_text( - self, text: str, color: Optional[str] = None, end: str = "\n" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text - print(f'{text_to_print}', end=end) + print(f"{text_to_print}", end=end) def _get_colored_text(self, text: str, color: str) -> str: """Get colored text.""" diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6a1ab23041..ab8d4e374e 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -15,13 +15,14 @@ class InvokeFrom(Enum): """ Invoke From. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': + def value_of(cls, value: str) -> "InvokeFrom": """ Get value of given mode. @@ -31,7 +32,7 @@ class InvokeFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid invoke from value {value}') + raise ValueError(f"invalid invoke from value {value}") def to_source(self) -> str: """ @@ -40,21 +41,22 @@ class InvokeFrom(Enum): :return: source """ if self == InvokeFrom.WEB_APP: - return 'web_app' + return "web_app" elif self == InvokeFrom.DEBUGGER: - return 'dev' + return "dev" elif self == InvokeFrom.EXPLORE: - return 'explore_app' + return "explore_app" elif self == InvokeFrom.SERVICE_API: - return 'api' + return "api" - return 'dev' + return "dev" class ModelConfigWithCredentialsEntity(BaseModel): """ Model Config With Credentials Entity. """ + provider: str model: str model_schema: AIModelEntity @@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel): """ App Generate Entity. """ + task_id: str # app config @@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ Chat Application Generate Entity. """ + # app config app_config: EasyUIBasedAppConfig model_conf: ModelConfigWithCredentialsEntity @@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Completion Application Generate Entity. """ + pass @@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Agent Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): """ Advanced Chat Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): """ Single Iteration Run Entity. """ + node_id: str inputs: dict single_iteration_run: Optional[SingleIterationRunEntity] = None + class WorkflowAppGenerateEntity(AppGenerateEntity): """ Workflow Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ Single Iteration Run Entity. """ + node_id: str inputs: dict diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 4c86b7eee1..4577e28535 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -14,6 +14,7 @@ class QueueEvent(str, Enum): """ QueueEvent enum """ + LLM_CHUNK = "llm_chunk" TEXT_CHUNK = "text_chunk" AGENT_MESSAGE = "agent_message" @@ -45,6 +46,7 @@ class AppQueueEvent(BaseModel): """ QueueEvent abstract entity """ + event: QueueEvent @@ -53,13 +55,16 @@ class QueueLLMChunkEvent(AppQueueEvent): QueueLLMChunkEvent entity Only for basic mode apps """ + event: QueueEvent = QueueEvent.LLM_CHUNK chunk: LLMResultChunk + class QueueIterationStartEvent(AppQueueEvent): """ QueueIterationStartEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_START node_execution_id: str node_id: str @@ -80,10 +85,12 @@ class QueueIterationStartEvent(AppQueueEvent): predecessor_node_id: Optional[str] = None metadata: Optional[dict[str, Any]] = None + class QueueIterationNextEvent(AppQueueEvent): """ QueueIterationNextEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_NEXT index: int @@ -101,9 +108,9 @@ class QueueIterationNextEvent(AppQueueEvent): """parent parallel start node id if node is in parallel""" node_run_index: int - output: Optional[Any] = None # output for the current iteration + output: Optional[Any] = None # output for the current iteration - @field_validator('output', mode='before') + @field_validator("output", mode="before") @classmethod def set_output(cls, v): """ @@ -113,12 +120,14 @@ class QueueIterationNextEvent(AppQueueEvent): return None if isinstance(v, int | float | str | bool | dict | list): return v - raise ValueError('output must be a valid type') + raise ValueError("output must be a valid type") + class QueueIterationCompletedEvent(AppQueueEvent): """ QueueIterationCompletedEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_COMPLETED node_execution_id: str @@ -134,7 +143,7 @@ class QueueIterationCompletedEvent(AppQueueEvent): parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" start_at: datetime - + node_run_index: int inputs: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None @@ -148,6 +157,7 @@ class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity """ + event: QueueEvent = QueueEvent.TEXT_CHUNK text: str from_variable_selector: Optional[list[str]] = None @@ -160,14 +170,16 @@ class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity """ + event: QueueEvent = QueueEvent.AGENT_MESSAGE chunk: LLMResultChunk - + class QueueMessageReplaceEvent(AppQueueEvent): """ QueueMessageReplaceEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_REPLACE text: str @@ -176,6 +188,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ QueueRetrieverResourcesEvent entity """ + event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] in_iteration_id: Optional[str] = None @@ -186,6 +199,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent): """ QueueAnnotationReplyEvent entity """ + event: QueueEvent = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -194,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ QueueMessageEndEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_END llm_result: Optional[LLMResult] = None @@ -202,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent): """ QueueAdvancedChatMessageEndEvent entity """ + event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END @@ -209,6 +225,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent): """ QueueWorkflowStartedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_STARTED graph_runtime_state: GraphRuntimeState @@ -217,6 +234,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): """ QueueWorkflowSucceededEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED outputs: Optional[dict[str, Any]] = None @@ -225,6 +243,7 @@ class QueueWorkflowFailedEvent(AppQueueEvent): """ QueueWorkflowFailedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str @@ -233,6 +252,7 @@ class QueueNodeStartedEvent(AppQueueEvent): """ QueueNodeStartedEvent entity """ + event: QueueEvent = QueueEvent.NODE_STARTED node_execution_id: str @@ -258,6 +278,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): """ QueueNodeSucceededEvent entity """ + event: QueueEvent = QueueEvent.NODE_SUCCEEDED node_execution_id: str @@ -288,6 +309,7 @@ class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity """ + event: QueueEvent = QueueEvent.NODE_FAILED node_execution_id: str @@ -317,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.AGENT_THOUGHT agent_thought_id: str @@ -325,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_FILE message_file_id: str @@ -333,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity """ + event: QueueEvent = QueueEvent.ERROR error: Any = None @@ -341,6 +366,7 @@ class QueuePingEvent(AppQueueEvent): """ QueuePingEvent entity """ + event: QueueEvent = QueueEvent.PING @@ -348,10 +374,12 @@ class QueueStopEvent(AppQueueEvent): """ QueueStopEvent entity """ + class StopBy(Enum): """ Stop by enum """ + USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" @@ -365,19 +393,20 @@ class QueueStopEvent(AppQueueEvent): To stop reason """ reason_mapping = { - QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', - QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', - QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', - QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' + QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.", + QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.", + QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.", + QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.", } - return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') + return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") class QueueMessage(BaseModel): """ QueueMessage abstract entity """ + task_id: str app_mode: str event: AppQueueEvent @@ -387,6 +416,7 @@ class MessageQueueMessage(QueueMessage): """ MessageQueueMessage entity """ + message_id: str conversation_id: str @@ -395,6 +425,7 @@ class WorkflowQueueMessage(QueueMessage): """ WorkflowQueueMessage entity """ + pass @@ -402,6 +433,7 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent): """ QueueParallelBranchRunStartedEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED parallel_id: str @@ -418,6 +450,7 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent): """ QueueParallelBranchRunSucceededEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED parallel_id: str @@ -434,6 +467,7 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent): """ QueueParallelBranchRunFailedEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED parallel_id: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7cab6ca4e0..49e5f55ebc 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -12,6 +12,7 @@ class TaskState(BaseModel): """ TaskState entity """ + metadata: dict = {} @@ -19,6 +20,7 @@ class EasyUITaskState(TaskState): """ EasyUITaskState entity """ + llm_result: LLMResult @@ -26,6 +28,7 @@ class WorkflowTaskState(TaskState): """ WorkflowTaskState entity """ + answer: str = "" @@ -33,6 +36,7 @@ class StreamEvent(Enum): """ Stream event """ + PING = "ping" ERROR = "error" MESSAGE = "message" @@ -60,6 +64,7 @@ class StreamResponse(BaseModel): """ StreamResponse entity """ + event: StreamEvent task_id: str @@ -71,6 +76,7 @@ class ErrorStreamResponse(StreamResponse): """ ErrorStreamResponse entity """ + event: StreamEvent = StreamEvent.ERROR err: Exception model_config = ConfigDict(arbitrary_types_allowed=True) @@ -80,15 +86,18 @@ class MessageStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE id: str answer: str + from_variable_selector: Optional[list[str]] = None class MessageAudioStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE audio: str @@ -97,6 +106,7 @@ class MessageAudioEndStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE_END audio: str @@ -105,6 +115,7 @@ class MessageEndStreamResponse(StreamResponse): """ MessageEndStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = {} @@ -114,6 +125,7 @@ class MessageFileStreamResponse(StreamResponse): """ MessageFileStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_FILE id: str type: str @@ -125,6 +137,7 @@ class MessageReplaceStreamResponse(StreamResponse): """ MessageReplaceStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_REPLACE answer: str @@ -133,6 +146,7 @@ class AgentThoughtStreamResponse(StreamResponse): """ AgentThoughtStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int @@ -148,6 +162,7 @@ class AgentMessageStreamResponse(StreamResponse): """ AgentMessageStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_MESSAGE id: str answer: str @@ -162,6 +177,7 @@ class WorkflowStartStreamResponse(StreamResponse): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -182,6 +198,7 @@ class WorkflowFinishStreamResponse(StreamResponse): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -210,6 +227,7 @@ class NodeStartStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -249,7 +267,7 @@ class NodeStartStreamResponse(StreamResponse): "parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, - } + }, } @@ -262,6 +280,7 @@ class NodeFinishStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -315,9 +334,9 @@ class NodeFinishStreamResponse(StreamResponse): "parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, - } + }, } - + class ParallelBranchStartStreamResponse(StreamResponse): """ @@ -328,6 +347,7 @@ class ParallelBranchStartStreamResponse(StreamResponse): """ Data entity """ + parallel_id: str parallel_branch_id: str parent_parallel_id: Optional[str] = None @@ -349,6 +369,7 @@ class ParallelBranchFinishedStreamResponse(StreamResponse): """ Data entity """ + parallel_id: str parallel_branch_id: str parent_parallel_id: Optional[str] = None @@ -372,6 +393,7 @@ class IterationNodeStartStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -397,6 +419,7 @@ class IterationNodeNextStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -422,6 +445,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -454,7 +478,9 @@ class TextChunkStreamResponse(StreamResponse): """ Data entity """ + text: str + from_variable_selector: Optional[list[str]] = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -469,6 +495,7 @@ class TextReplaceStreamResponse(StreamResponse): """ Data entity """ + text: str event: StreamEvent = StreamEvent.TEXT_REPLACE @@ -479,6 +506,7 @@ class PingStreamResponse(StreamResponse): """ PingStreamResponse entity """ + event: StreamEvent = StreamEvent.PING @@ -486,6 +514,7 @@ class AppStreamResponse(BaseModel): """ AppStreamResponse entity """ + stream_response: StreamResponse @@ -493,6 +522,7 @@ class ChatbotAppStreamResponse(AppStreamResponse): """ ChatbotAppStreamResponse entity """ + conversation_id: str message_id: str created_at: int @@ -502,6 +532,7 @@ class CompletionAppStreamResponse(AppStreamResponse): """ CompletionAppStreamResponse entity """ + message_id: str created_at: int @@ -510,6 +541,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ + workflow_run_id: Optional[str] = None @@ -517,6 +549,7 @@ class AppBlockingResponse(BaseModel): """ AppBlockingResponse entity """ + task_id: str def to_dict(self) -> dict: @@ -532,6 +565,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str mode: str conversation_id: str @@ -552,6 +586,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str mode: str message_id: str @@ -571,6 +606,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str workflow_id: str status: str diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 19ff94de5e..77b6bb554c 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -13,11 +13,9 @@ logger = logging.getLogger(__name__) class AnnotationReplyFeature: - def query(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -27,8 +25,9 @@ class AnnotationReplyFeature: :param invoke_from: invoke from :return: """ - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_record.id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() + ) if not annotation_setting: return None @@ -41,55 +40,50 @@ class AnnotationReplyFeature: embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" ) dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) documents = vector.search_by_vector( - query=query, - top_k=1, - score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) if documents: - annotation_id = documents[0].metadata['annotation_id'] - score = documents[0].metadata['score'] + annotation_id = documents[0].metadata["annotation_id"] + score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: - if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: - from_source = 'api' + if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: + from_source = "api" else: - from_source = 'console' + from_source = "console" # insert annotation history - AppAnnotationService.add_annotation_history(annotation.id, - app_record.id, - annotation.question, - annotation.content, - query, - user_id, - message.id, - from_source, - score) + AppAnnotationService.add_annotation_history( + annotation.id, + app_record.id, + annotation.question, + annotation.content, + query, + user_id, + message.id, + from_source, + score, + ) return annotation except Exception as e: - logger.warning(f'Query annotation failed, exception: {str(e)}.') + logger.warning(f"Query annotation failed, exception: {str(e)}.") return None return None diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index b8f3e0e1f6..ba14b61201 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -8,8 +8,9 @@ logger = logging.getLogger(__name__) class HostingModerationFeature: - def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list[PromptMessage]) -> bool: + def check( + self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage] + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -23,9 +24,6 @@ class HostingModerationFeature: if isinstance(prompt_message.content, str): text += prompt_message.content + "\n" - moderation_result = moderation.check_moderation( - model_config, - text - ) + moderation_result = moderation.check_moderation(model_config, text) return moderation_result diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 570e3c003f..47643cfcbc 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict = {} - def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): + def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance @@ -27,13 +27,13 @@ class RateLimit: def __init__(self, client_id: str, max_active_requests: int): self.max_active_requests = max_active_requests - if hasattr(self, 'initialized'): + if hasattr(self, "initialized"): return self.initialized = True self.client_id = client_id self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id) self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id) - self.last_recalculate_time = float('-inf') + self.last_recalculate_time = float("-inf") self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): @@ -46,7 +46,7 @@ class RateLimit: pipe.execute() else: with redis_client.pipeline() as pipe: - self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8')) + self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) # flush max active requests (in-transit request list) @@ -54,8 +54,11 @@ class RateLimit: return request_details = redis_client.hgetall(self.active_requests_key) redis_client.expire(self.active_requests_key, timedelta(days=1)) - timeout_requests = [k for k, v in request_details.items() if - time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME] + timeout_requests = [ + k + for k, v in request_details.items() + if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME + ] if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) @@ -69,8 +72,10 @@ class RateLimit: active_requests_count = redis_client.hlen(self.active_requests_key) if active_requests_count >= self.max_active_requests: - raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum " - "concurrent requests allowed is {}.".format(self.max_active_requests)) + raise AppInvokeQuotaExceededError( + "Too many requests. Please try again later. The current maximum " + "concurrent requests allowed is {}.".format(self.max_active_requests) + ) redis_client.hset(self.active_requests_key, request_id, str(time.time())) return request_id @@ -116,5 +121,5 @@ class RateLimitGenerator: if not self.closed: self.closed = True self.rate_limit.exit(self.request_id) - if self.generator is not None and hasattr(self.generator, 'close'): + if self.generator is not None and hasattr(self.generator, "close"): self.generator.close() diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 7de06dfb96..652ef243b4 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -25,25 +25,25 @@ from .variables import ( ) __all__ = [ - 'IntegerVariable', - 'FloatVariable', - 'ObjectVariable', - 'SecretVariable', - 'StringVariable', - 'ArrayAnyVariable', - 'Variable', - 'SegmentType', - 'SegmentGroup', - 'Segment', - 'NoneSegment', - 'NoneVariable', - 'IntegerSegment', - 'FloatSegment', - 'ObjectSegment', - 'ArrayAnySegment', - 'StringSegment', - 'ArrayStringVariable', - 'ArrayNumberVariable', - 'ArrayObjectVariable', - 'ArraySegment', + "IntegerVariable", + "FloatVariable", + "ObjectVariable", + "SecretVariable", + "StringVariable", + "ArrayAnyVariable", + "Variable", + "SegmentType", + "SegmentGroup", + "Segment", + "NoneSegment", + "NoneVariable", + "IntegerSegment", + "FloatSegment", + "ObjectSegment", + "ArrayAnySegment", + "StringSegment", + "ArrayStringVariable", + "ArrayNumberVariable", + "ArrayObjectVariable", + "ArraySegment", ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index e6e9ce9774..40a69ed4eb 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -28,12 +28,12 @@ from .variables import ( def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: - if (value_type := mapping.get('value_type')) is None: - raise VariableError('missing value type') - if not mapping.get('name'): - raise VariableError('missing name') - if (value := mapping.get('value')) is None: - raise VariableError('missing value') + if (value_type := mapping.get("value_type")) is None: + raise VariableError("missing value type") + if not mapping.get("name"): + raise VariableError("missing name") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.NUMBER if isinstance(value, float): result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): - raise VariableError(f'invalid number value {value}') + raise VariableError(f"invalid number value {value}") case SegmentType.OBJECT if isinstance(value, dict): result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): @@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) case _: - raise VariableError(f'not supported value type {value_type}') + raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: - raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') + raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") return result @@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment: return ObjectSegment(value=value) if isinstance(value, list): return ArrayAnySegment(value=value) - raise ValueError(f'not supported value {value}') + raise ValueError(f"not supported value {value}") diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py index de6c796652..3c4d7046f4 100644 --- a/api/core/app/segments/parser.py +++ b/api/core/app/segments/parser.py @@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool from . import SegmentGroup, factory -VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") def convert_template(*, template: str, variable_pool: VariablePool): parts = re.split(VARIABLE_PATTERN, template) segments = [] for part in filter(lambda x: x, parts): - if '.' in part and (value := variable_pool.get(part.split('.'))): + if "." in part and (value := variable_pool.get(part.split("."))): segments.append(value) else: segments.append(factory.build_segment(part)) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py index b4ff09b6d3..b363255b2c 100644 --- a/api/core/app/segments/segment_group.py +++ b/api/core/app/segments/segment_group.py @@ -8,15 +8,15 @@ class SegmentGroup(Segment): @property def text(self): - return ''.join([segment.text for segment in self.value]) + return "".join([segment.text for segment in self.value]) @property def log(self): - return ''.join([segment.log for segment in self.value]) + return "".join([segment.log for segment in self.value]) @property def markdown(self): - return ''.join([segment.markdown for segment in self.value]) + return "".join([segment.markdown for segment in self.value]) def to_object(self): return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 5c713cac67..b26b3c8291 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -14,13 +14,14 @@ class Segment(BaseModel): value_type: SegmentType value: Any - @field_validator('value_type') + @field_validator("value_type") + @classmethod def validate_value_type(cls, value): """ This validator checks if the provided value is equal to the default value of the 'value_type' field. If the value is different, a ValueError is raised. """ - if value != cls.model_fields['value_type'].default: + if value != cls.model_fields["value_type"].default: raise ValueError("Cannot modify 'value_type'") return value @@ -50,15 +51,15 @@ class NoneSegment(Segment): @property def text(self) -> str: - return 'null' + return "null" @property def log(self) -> str: - return 'null' + return "null" @property def markdown(self) -> str: - return 'null' + return "null" class StringSegment(Segment): @@ -76,24 +77,21 @@ class IntegerSegment(Segment): value: int - - - class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT value: Mapping[str, Any] @property def text(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False) + return json.dumps(self.model_dump()["value"], ensure_ascii=False) @property def log(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) @property def markdown(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) class ArraySegment(Segment): @@ -101,11 +99,11 @@ class ArraySegment(Segment): def markdown(self) -> str: items = [] for item in self.value: - if hasattr(item, 'to_markdown'): + if hasattr(item, "to_markdown"): items.append(item.to_markdown()) else: items.append(str(item)) - return '\n'.join(items) + return "\n".join(items) class ArrayAnySegment(ArraySegment): @@ -126,4 +124,3 @@ class ArrayNumberSegment(ArraySegment): class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] - diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index cdd2b0b4b0..9cf0856df5 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -2,14 +2,14 @@ from enum import Enum class SegmentType(str, Enum): - NONE = 'none' - NUMBER = 'number' - STRING = 'string' - SECRET = 'secret' - ARRAY_ANY = 'array[any]' - ARRAY_STRING = 'array[string]' - ARRAY_NUMBER = 'array[number]' - ARRAY_OBJECT = 'array[object]' - OBJECT = 'object' + NONE = "none" + NUMBER = "number" + STRING = "string" + SECRET = "secret" + ARRAY_ANY = "array[any]" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + OBJECT = "object" - GROUP = 'group' + GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index 8fef707fcf..f0e403ab8d 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -23,11 +23,11 @@ class Variable(Segment): """ id: str = Field( - default='', + default="", description="Unique identity for variable. It's only used by environment variables now.", ) name: str - description: str = Field(default='', description='Description of the variable.') + description: str = Field(default="", description="Description of the variable.") class StringVariable(StringSegment, Variable): @@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 2f74a180d1..a43be5fdf2 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline: _task_state: TaskState _application_generate_entity: AppGenerateEntity - def __init__(self, application_generate_entity: AppGenerateEntity, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: AppGenerateEntity, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline: e = event.error if isinstance(e, InvokeAuthorizationError): - err = InvokeAuthorizationError('Incorrect API key provided') - elif isinstance(e, InvokeError) or isinstance(e, ValueError): + err = InvokeAuthorizationError("Incorrect API key provided") + elif isinstance(e, InvokeError | ValueError): err = e else: - err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) if message: refetch_message = db.session.query(Message).filter(Message.id == message.id).first() if refetch_message: err_desc = self._error_to_desc(err) - refetch_message.status = 'error' + refetch_message.status = "error" refetch_message.error = err_desc db.session.commit() @@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline: :return: """ if isinstance(e, QuotaExceededError): - return ("Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.") + return ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) - message = getattr(e, 'description', str(e)) + message = getattr(e, "description", str(e)) if not message: - message = 'Internal Server Error, please contact support.' + message = "Internal Server Error, please contact support." return message @@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline: :param e: exception :return: """ - return ErrorStreamResponse( - task_id=self._application_generate_entity.task_id, - err=e - ) + return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) def _ping_stream_response(self) -> PingStreamResponse: """ @@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline: return OutputModeration( tenant_id=app_config.tenant_id, app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager + rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), + queue_manager=self._queue_manager, ) def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: @@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline: self._output_moderation_handler.stop_thread() completion = self._output_moderation_handler.moderation_completion( - completion=completion, - public_event=False + completion=completion, public_event=False ) self._output_moderation_handler = None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 8d91a507a9..8f834b6458 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: EasyUITaskState - _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ] - def __init__(self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool) -> None: + _task_state: EasyUITaskState + _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + + def __init__( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), - usage=LLMUsage.empty_usage() + usage=LLMUsage.empty_usage(), ) ) self._conversation_name_generate_thread = None def process( - self, + self, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Process generate task pipeline. @@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[ - ChatbotAppBlockingResponse, - CompletionAppBlockingResponse - ]: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]: """ Process blocking response. :return: @@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = { - 'usage': jsonable_encoder(self._task_state.llm_result.usage) - } + extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: response = ChatbotAppBlockingResponse( @@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: """ To stream response. :return: @@ -198,37 +191,41 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan yield CompletionAppStreamResponse( message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech') - if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'): - publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None)) + text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + if ( + text_to_speech_dict + and text_to_speech_dict.get("autoPlay") == "enabled" + and text_to_speech_dict.get("enabled") + ): + publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id) + audio_response = self._listen_audio_msg(publisher, task_id) if audio_response: yield audio_response else: @@ -240,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: if publisher is None: break - audio = publisher.checkAndGetAudio() + audio = publisher.check_and_get_audio() if audio is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan break else: start_listener_time = time.time() - yield MessageAudioStreamResponse(audio=audio.audio, - task_id=task_id) - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message( - self, trace_manager: Optional[TraceQueueManager] = None - ) -> None: + def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( - self._model_config.mode, - self._task_state.llm_result.prompt_messages + self._model_config.mode, self._task_state.llm_result.prompt_messages ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ - if llm_result.message.content else '' + self._message.answer = ( + PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + if llm_result.message.content + else "" + ) self._message.answer_tokens = usage.completion_tokens self._message.answer_unit_price = usage.completion_unit_price self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price self._message.currency = usage.currency - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, - conversation_id=self._conversation.id, - message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id ) ) @@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [ - AppMode.AGENT_CHAT, - AppMode.CHAT - ] and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} + and self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model = model_config.model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_instance.get_llm_num_tokens( - self._task_state.llm_result.prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_instance.get_llm_num_tokens( - [self._task_state.llm_result.message] - ) + completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message]) credentials = model_config.credentials @@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) self._task_state.llm_result.usage = model_type_instance._calc_response_usage( - model, - credentials, - prompt_tokens, - completion_tokens + model, credentials, prompt_tokens, completion_tokens ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan Message end to stream response. :return: """ - self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) + self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan :return: """ return AgentMessageStreamResponse( - task_id=self._application_generate_entity.task_id, - id=message_id, - answer=answer + task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: @@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan :return: """ agent_thought: MessageAgentThought = ( - db.session.query(MessageAgentThought) - .filter(MessageAgentThought.id == event.agent_thought_id) - .first() + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) db.session.close() @@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan tool=agent_thought.tool, tool_labels=agent_thought.tool_labels, tool_input=agent_thought.tool_input, - message_files=agent_thought.files + message_files=agent_thought.files, ) return None @@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content), + ), ) - ), PublishFrom.TASK_PIPELINE + ), + PublishFrom.TASK_PIPELINE, ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 8ff50dd174..5872e00740 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService class MessageCycleManage: _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] _task_state: Union[EasyUITaskState, WorkflowTaskState] @@ -49,15 +46,18 @@ class MessageCycleManage: is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras - auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) + auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) if auto_generate_conversation_name and is_first_message: # start generate thread - thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'conversation_id': conversation.id, - 'query': query - }) + thread = Thread( + target=self._generate_conversation_name_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "conversation_id": conversation.id, + "query": query, + }, + ) thread.start() @@ -65,17 +65,10 @@ class MessageCycleManage: return None - def _generate_conversation_name_worker(self, - flask_app: Flask, - conversation_id: str, - query: str): + def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: return @@ -105,12 +98,9 @@ class MessageCycleManage: annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } + self._task_state.metadata["annotation_reply"] = { + "id": annotation.id, + "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, } return annotation @@ -124,7 +114,7 @@ class MessageCycleManage: :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata['retriever_resources'] = event.retriever_resources + self._task_state.metadata["retriever_resources"] = event.retriever_resources def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ @@ -132,27 +122,23 @@ class MessageCycleManage: :param event: event :return: """ - message_file = ( - db.session.query(MessageFile) - .filter(MessageFile.id == event.message_file_id) - .first() - ) + message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() if message_file: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url to local file - if message_file.url.startswith('http'): + if message_file.url.startswith("http"): url = message_file.url else: url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) @@ -161,13 +147,15 @@ class MessageCycleManage: task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or 'user', - url=url + belongs_to=message_file.belongs_to or "user", + url=url, ) return None - def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse: + def _message_to_stream_response( + self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + ) -> MessageStreamResponse: """ Message to stream response. :param answer: answer @@ -177,7 +165,8 @@ class MessageCycleManage: return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, - answer=answer + answer=answer, + from_variable_selector=from_variable_selector, ) def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: @@ -186,7 +175,4 @@ class MessageCycleManage: :param answer: answer :return: """ - return MessageReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - answer=answer - ) + return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index ed3225310a..f10189798f 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -70,14 +70,14 @@ class WorkflowCycleManage: inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): - if key.value == 'conversation': + if key.value == "conversation": continue - inputs[f'sys.{key.value}'] = value + inputs[f"sys.{key.value}"] = value inputs = WorkflowEntry.handle_special_values(inputs) - triggered_from= ( + triggered_from = ( WorkflowRunTriggeredFrom.DEBUGGING if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN @@ -185,20 +185,26 @@ class WorkflowCycleManage: db.session.commit() - running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value - ).all() + running_workflow_node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + ) + .all() + ) for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() + workflow_node_execution.elapsed_time = ( + workflow_node_execution.finished_at - workflow_node_execution.created_at + ).total_seconds() db.session.commit() db.session.refresh(workflow_run) @@ -216,7 +222,9 @@ class WorkflowCycleManage: return workflow_run - def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + def _handle_node_execution_start( + self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + ) -> WorkflowNodeExecution: # init workflow node execution workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.tenant_id = workflow_run.tenant_id @@ -333,16 +341,16 @@ class WorkflowCycleManage: created_by_account = workflow_run.created_by_account if created_by_account: created_by = { - 'id': created_by_account.id, - 'name': created_by_account.name, - 'email': created_by_account.email, + "id": created_by_account.id, + "name": created_by_account.name, + "email": created_by_account.email, } else: created_by_end_user = workflow_run.created_by_end_user if created_by_end_user: created_by = { - 'id': created_by_end_user.id, - 'user': created_by_end_user.session_id, + "id": created_by_end_user.id, + "user": created_by_end_user.session_id, } return WorkflowFinishStreamResponse( @@ -375,7 +383,7 @@ class WorkflowCycleManage: :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None response = NodeStartStreamResponse( @@ -401,7 +409,7 @@ class WorkflowCycleManage: # extras logic if event.node_type == NodeType.TOOL: node_data = cast(ToolNodeData, event.node_data) - response.data.extras['icon'] = ToolManager.get_tool_icon( + response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, provider_id=node_data.provider_id, @@ -410,10 +418,10 @@ class WorkflowCycleManage: return response def _workflow_node_finish_to_stream_response( - self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. @@ -422,9 +430,9 @@ class WorkflowCycleManage: :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None - + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -452,13 +460,10 @@ class WorkflowCycleManage: iteration_id=event.in_iteration_id, ), ) - + def _workflow_parallel_branch_start_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueParallelBranchRunStartedEvent - ) -> ParallelBranchStartStreamResponse: + self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: """ Workflow parallel branch start to stream response :param task_id: task id @@ -476,15 +481,15 @@ class WorkflowCycleManage: parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, created_at=int(time.time()), - ) + ), ) - + def _workflow_parallel_branch_finished_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent - ) -> ParallelBranchFinishedStreamResponse: + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, + ) -> ParallelBranchFinishedStreamResponse: """ Workflow parallel branch finished to stream response :param task_id: task id @@ -501,18 +506,15 @@ class WorkflowCycleManage: parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, - status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', + status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, created_at=int(time.time()), - ) + ), ) def _workflow_iteration_start_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueIterationStartEvent - ) -> IterationNodeStartStreamResponse: + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: """ Workflow iteration start to stream response :param task_id: task id @@ -534,10 +536,12 @@ class WorkflowCycleManage: metadata=event.metadata or {}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) - def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: + def _workflow_iteration_next_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + ) -> IterationNodeNextStreamResponse: """ Workflow iteration next to stream response :param task_id: task id @@ -559,10 +563,12 @@ class WorkflowCycleManage: extras={}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) - def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: + def _workflow_iteration_completed_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + ) -> IterationNodeCompletedStreamResponse: """ Workflow iteration completed to stream response :param task_id: task id @@ -585,13 +591,13 @@ class WorkflowCycleManage: status=WorkflowNodeExecutionStatus.SUCCEEDED, error=None, elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), - total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, + total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), steps=event.steps, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: @@ -643,7 +649,7 @@ class WorkflowCycleManage: return None if isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: + if "__variant" in value and value["__variant"] == FileVar.__name__: return value elif isinstance(value, FileVar): return value.to_dict() @@ -656,11 +662,10 @@ class WorkflowCycleManage: :param workflow_run_id: workflow run id :return: """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() if not workflow_run: - raise Exception(f'Workflow run not found: {workflow_run_id}') + raise Exception(f"Workflow run not found: {workflow_run_id}") return workflow_run @@ -683,6 +688,6 @@ class WorkflowCycleManage: ) if not workflow_node_execution: - raise Exception(f'Workflow node execution not found: {node_execution_id}') + raise Exception(f"Workflow node execution not found: {node_execution_id}") - return workflow_node_execution \ No newline at end of file + return workflow_node_execution diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 4c246b230d..f31b43cb7f 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = { "red": "31;1", } + def get_colored_text(text: str, color: str) -> str: """Get colored text.""" color_str = _TEXT_COLOR_MAPPING[color] return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None -) -> None: +def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) if file: file.flush() # ensure all printed content are written to file + class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = '' + + color: Optional[str] = "" current_loop: int = 1 def __init__(self, color: Optional[str] = None) -> None: super().__init__() """Initialize callback handler.""" # use a specific color is not specified - self.color = color or 'green' + self.color = color or "green" self.current_loop = 1 def on_tool_start( @@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel): tool_outputs: Iterable[ToolInvokeMessage] | str, message_id: Optional[str] = None, timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> None: """If not the final action, print out observation.""" print_text("\n[on_tool_end]\n", color=self.color) @@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel): ) ) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red') + print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") - def on_agent_start( - self, thought: str - ) -> None: + def on_agent_start(self, thought: str) -> None: """Run on agent start.""" if thought: - print_text("\n[on_agent_start] \nCurrent Loop: " + \ - str(self.current_loop) + \ - "\nThought: " + thought + "\n", color=self.color) + print_text( + "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n", + color=self.color, + ) else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - def on_agent_finish( - self, color: Optional[str] = None, **kwargs: Any - ) -> None: + def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: """Run on agent end.""" print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) @@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel): @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" @property def ignore_chat_model(self) -> bool: """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8e1f496b22..7cf472d984 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,4 +1,3 @@ - from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent @@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: AppQueueManager, - app_id: str, - message_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__( + self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom + ) -> None: self._queue_manager = queue_manager self._app_id = app_id self._message_id = message_id @@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=self._app_id, - created_by_role=('account' - if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), - created_by=self._user_id + created_by_role=( + "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + ), + created_by=self._user_id, ) db.session.add(dataset_query) @@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler: """Handle tool end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() @@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler: for item in resource: dataset_retriever_resource = DatasetRetrieverResource( message_id=self._message_id, - position=item.get('position'), - dataset_id=item.get('dataset_id'), - dataset_name=item.get('dataset_name'), - document_id=item.get('document_id'), - document_name=item.get('document_name'), - data_source_type=item.get('data_source_type'), - segment_id=item.get('segment_id'), - score=item.get('score') if 'score' in item else None, - hit_count=item.get('hit_count') if 'hit_count' else None, - word_count=item.get('word_count') if 'word_count' in item else None, - segment_position=item.get('segment_position') if 'segment_position' in item else None, - index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, - content=item.get('content'), - retriever_from=item.get('retriever_from'), - created_by=self._user_id + position=item.get("position"), + dataset_id=item.get("dataset_id"), + dataset_name=item.get("dataset_name"), + document_id=item.get("document_id"), + document_name=item.get("document_name"), + data_source_type=item.get("data_source_type"), + segment_id=item.get("segment_id"), + score=item.get("score") if "score" in item else None, + hit_count=item.get("hit_count") if "hit_count" in item else None, + word_count=item.get("word_count") if "word_count" in item else None, + segment_position=item.get("segment_position") if "segment_position" in item else None, + index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, + content=item.get("content"), + retriever_from=item.get("retriever_from"), + created_by=self._user_id, ) db.session.add(dataset_retriever_resource) db.session.commit() self._queue_manager.publish( - QueueRetrieverResourcesEvent(retriever_resources=resource), - PublishFrom.APPLICATION_MANAGER + QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 84bab7e1a3..8ac12f72f2 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): - """Callback Handler that prints to std out.""" \ No newline at end of file + """Callback Handler that prints to std out.""" diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index b7e0cc0c2b..4cc793b0d7 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider).first() + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider + ) + .first() + ) if embedding: text_embeddings[i] = embedding.get_embedding() else: @@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings): embedding_queue_embeddings = [] try: model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) - model_schema = model_type_instance.get_model_schema(self._model_instance.model, - self._model_instance.credentials) - max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 + model_schema = model_type_instance.get_model_schema( + self._model_instance.model, self._model_instance.credentials + ) + max_chunks = ( + model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + else 1 + ) for i in range(0, len(embedding_queue_texts), max_chunks): - batch_texts = embedding_queue_texts[i:i + max_chunks] + batch_texts = embedding_queue_texts[i : i + max_chunks] - embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, - user=self._user - ) + embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user) for vector in embedding_result.embeddings: try: @@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings): except IntegrityError: db.session.rollback() except Exception as e: - logging.exception('Failed transform embedding: ', e) + logging.exception("Failed transform embedding: ", e) cache_embeddings = [] try: for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): text_embeddings[i] = embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: - embedding_cache = Embedding(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider) + embedding_cache = Embedding( + model_name=self._model_instance.model, + hash=hash, + provider_name=self._model_instance.provider, + ) embedding_cache.set_embedding(embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) @@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings): db.session.rollback() except Exception as ex: db.session.rollback() - logger.error('Failed to embed documents: ', ex) + logger.error("Failed to embed documents: ", ex) raise ex return text_embeddings @@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings): """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) - embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}" embedding = redis_client.get(embedding_cache_key) if embedding: redis_client.expire(embedding_cache_key, 600) return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) try: - embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], - user=self._user - ) + embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user) embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() @@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings): except IntegrityError: db.session.rollback() except: - logging.exception('Failed to add embedding to redis') + logging.exception("Failed to add embedding to redis") return embedding_results diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 0cdf8670c4..656bf4aa72 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -2,7 +2,7 @@ from enum import Enum class PlanningStrategy(Enum): - ROUTER = 'router' - REACT_ROUTER = 'react_router' - REACT = 'react' - FUNCTION_CALL = 'function_call' + ROUTER = "router" + REACT_ROUTER = "react_router" + REACT = "react" + FUNCTION_CALL = "function_call" diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py index 370aeee463..10bc9f6ed7 100644 --- a/api/core/entities/message_entities.py +++ b/api/core/entities/message_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class PromptMessageFileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel): class ImagePromptMessageFile(PromptMessageFile): class DETAIL(enum.Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageFileType = PromptMessageFileType.IMAGE detail: DETAIL = DETAIL.LOW diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 22a21ecf93..9ed5528e43 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -12,6 +12,7 @@ class ModelStatus(Enum): """ Enum class for model status. """ + ACTIVE = "active" NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" @@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel): """ Simple provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -40,7 +42,7 @@ class SimpleModelProviderEntity(BaseModel): label=provider_entity.label, icon_small=provider_entity.icon_small, icon_large=provider_entity.icon_large, - supported_model_types=provider_entity.supported_model_types + supported_model_types=provider_entity.supported_model_types, ) @@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel): """ Model class for model response. """ + status: ModelStatus load_balancing_enabled: bool = False @@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ Model with provider entity. """ + provider: SimpleModelProviderEntity @@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel): """ Default model provider entity. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: DefaultModelProviderEntity diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 778ef2e1ac..4797b69b85 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel): """ Model class for provider configuration. """ + tenant_id: str provider: ProviderEntity preferred_provider_type: ProviderType @@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel): original_provider_configurate_methods[self.provider.provider].append(configurate_method) if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - if (any(len(quota_configuration.restrict_models) > 0 - for quota_configuration in self.system_configuration.quota_configurations) - and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): + if ( + any( + len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations + ) + and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods + ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: @@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel): if self.model_settings: # check if model is disabled by admin for model_setting in self.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: if not model_setting.enabled: - raise ValueError(f'Model {model} is disabled.') + raise ValueError(f"Model {model} is disabled.") if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] @@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel): copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: - if (restrict_model.model_type == model_type - and restrict_model.model == model - and restrict_model.base_model_name): - copy_credentials['base_model_name'] = restrict_model.base_model_name + if ( + restrict_model.model_type == model_type + and restrict_model.model == model + and restrict_model.base_model_name + ): + copy_credentials["base_model_name"] = restrict_model.base_model_name return copy_credentials else: @@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel): current_quota_type = self.system_configuration.current_quota_type current_quota_configuration = next( - (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), - None + (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) - return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ - SystemConfigurationStatus.QUOTA_EXCEEDED + return ( + SystemConfigurationStatus.ACTIVE + if current_quota_configuration.is_valid + else SystemConfigurationStatus.QUOTA_EXCEEDED + ) def is_custom_configuration_available(self) -> bool: """ Check custom configuration available. :return: """ - return (self.custom_configuration.provider is not None - or len(self.custom_configuration.models) > 0) + return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: """ @@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel): return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [], ) def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: @@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [] ) if provider_record: @@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel): # fix origin data if provider_record.encrypted_config: if not provider_record.encrypted_config.startswith("{"): - original_credentials = { - "openai_api_key": provider_record.encrypted_config - } + original_credentials = {"openai_api_key": provider_record.encrypted_config} else: original_credentials = json.loads(provider_record.encrypted_config) else: @@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel): credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, - credentials=credentials + provider=self.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, provider_type=ProviderType.CUSTOM.value, encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER ) provider_model_credentials_cache.delete() @@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # delete provider if provider_record: @@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() - def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ - -> Optional[dict]: + def get_custom_model_credentials( + self, model_type: ModelType, model: str, obfuscated: bool = False + ) -> Optional[dict]: """ Get custom model credentials. @@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel): return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [], ) return None - def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> tuple[ProviderModel, dict]: + def custom_model_credentials_validate( + self, model_type: ModelType, model: str, credentials: dict + ) -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [] ) if provider_model_record: try: - original_credentials = json.loads( - provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + original_credentials = ( + json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + ) except JSONDecodeError: original_credentials = {} @@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel): credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, - model_type=model_type, - model=model, - credentials=credentials + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) for key, value in credentials.items(): @@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel): model_name=model, model_type=model_type.to_origin_model_type(), encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_model_record) db.session.commit() @@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # delete provider model if provider_model_record: @@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = True @@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=True + enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = False @@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=False + enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - return db.session.query(ProviderModelSetting) \ + return ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ + load_balancing_config_count = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).count() + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .count() + ) if load_balancing_config_count <= 1: - raise ValueError('Model load balancing configuration must be more than 1.') + raise ValueError("Model load balancing configuration must be more than 1.") - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = True @@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=True + load_balancing_enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = False @@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=False + load_balancing_enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel): return # get preferred provider - preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ + preferred_model_provider = ( + db.session.query(TenantPreferredModelProvider) .filter( - TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider - ).first() + TenantPreferredModelProvider.tenant_id == self.tenant_id, + TenantPreferredModelProvider.provider_name == self.provider.provider, + ) + .first() + ) if preferred_model_provider: preferred_model_provider.preferred_provider_type = provider_type.value @@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value + preferred_provider_type=provider_type.value, ) db.session.add(preferred_model_provider) @@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel): :return: """ # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( - credential_form_schemas - ) + credential_secret_variables = self.extract_secret_variables(credential_form_schemas) # Obfuscate provider credentials copy_credentials = credentials.copy() @@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel): return copy_credentials - def get_provider_model(self, model_type: ModelType, - model: str, - only_active: bool = False) -> Optional[ModelWithProviderEntity]: + def get_provider_model( + self, model_type: ModelType, model: str, only_active: bool = False + ) -> Optional[ModelWithProviderEntity]: """ Get provider model. :param model_type: model type @@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel): return None - def get_provider_models(self, model_type: Optional[ModelType] = None, - only_active: bool = False) -> list[ModelWithProviderEntity]: + def get_provider_models( + self, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type @@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel): if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) if only_active: @@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel): # resort provider_models return sorted(provider_models, key=lambda x: x.model_type.value) - def _get_system_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_system_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get system provider models. @@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel): if should_use_custom_model: if original_provider_configurate_methods[self.provider.provider] == [ - ConfigurateMethod.CUSTOMIZABLE_MODEL]: + ConfigurateMethod.CUSTOMIZABLE_MODEL + ]: # only customizable model for restrict_model in restrict_models: copy_credentials = self.system_configuration.credentials.copy() if restrict_model.base_model_name: - copy_credentials['base_model_name'] = restrict_model.base_model_name + copy_credentials["base_model_name"] = restrict_model.base_model_name try: - custom_model_schema = ( - provider_instance.get_model_instance(restrict_model.model_type) - .get_customizable_model_schema_from_credentials( - restrict_model.model, - copy_credentials - ) - ) + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel): continue status = ModelStatus.ACTIVE - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel): return provider_models - def _get_custom_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_custom_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel): deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel): continue try: - custom_model_schema = ( - provider_instance.get_model_instance(model_configuration.model_type) - .get_customizable_model_schema_from_credentials( - model_configuration.model, - model_configuration.credentials - ) + custom_model_schema = provider_instance.get_model_instance( + model_configuration.model_type + ).get_customizable_model_schema_from_credentials( + model_configuration.model, model_configuration.credentials ) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE load_balancing_enabled = False - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel): deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ + tenant_id: str configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - only_active: bool = False) \ - -> list[ModelWithProviderEntity]: + def get_models( + self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get available models. @@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel): """ Provider model bundle. """ + configuration: ProviderConfiguration provider_instance: ModelProvider model_type_instance: AIModel # pydantic configs - model_config = ConfigDict(arbitrary_types_allowed=True, - protected_namespaces=()) + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 6c616d1aad..88b16f13af 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -10,18 +10,19 @@ from models.provider import ProviderQuotaType class QuotaUnit(Enum): - TIMES = 'times' - TOKENS = 'tokens' - CREDITS = 'credits' + TIMES = "times" + TOKENS = "tokens" + CREDITS = "credits" class SystemConfigurationStatus(Enum): """ Enum class for system configuration status. """ - ACTIVE = 'active' - QUOTA_EXCEEDED = 'quota-exceeded' - UNSUPPORTED = 'unsupported' + + ACTIVE = "active" + QUOTA_EXCEEDED = "quota-exceeded" + UNSUPPORTED = "unsupported" class RestrictModel(BaseModel): @@ -37,6 +38,7 @@ class QuotaConfiguration(BaseModel): """ Model class for provider quota configuration. """ + quota_type: ProviderQuotaType quota_unit: QuotaUnit quota_limit: int @@ -49,6 +51,7 @@ class SystemConfiguration(BaseModel): """ Model class for provider system configuration. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -59,6 +62,7 @@ class CustomProviderConfiguration(BaseModel): """ Model class for provider custom configuration. """ + credentials: dict @@ -66,6 +70,7 @@ class CustomModelConfiguration(BaseModel): """ Model class for provider custom model configuration. """ + model: str model_type: ModelType credentials: dict @@ -78,6 +83,7 @@ class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. """ + provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] @@ -86,6 +92,7 @@ class ModelLoadBalancingConfiguration(BaseModel): """ Class for model load balancing configuration. """ + id: str name: str credentials: dict @@ -95,6 +102,7 @@ class ModelSettings(BaseModel): """ Model class for model settings. """ + model: str model_type: ModelType enabled: bool = True @@ -103,6 +111,7 @@ class ModelSettings(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + class BasicProviderConfig(BaseModel): """ Base model class for common provider settings like credentials @@ -131,6 +140,7 @@ class BasicProviderConfig(BaseModel): type: Type = Field(..., description="The type of the credentials") name: str = Field(..., description="The name of the credentials") + class ProviderConfig(BasicProviderConfig): """ Model class for common provider settings like credentials diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 53323a2eeb..3b186476eb 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -3,6 +3,7 @@ from typing import Optional class LLMError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -11,6 +12,7 @@ class LLMError(Exception): class LLMBadRequestError(LLMError): """Raised when the LLM returns bad request.""" + description = "Bad Request" @@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception): """ Custom exception raised when the provider token is not initialized. """ + description = "Provider Token Not Init" def __init__(self, *args, **kwargs): @@ -28,6 +31,7 @@ class QuotaExceededError(Exception): """ Custom exception raised when the quota for a provider has been exceeded. """ + description = "Quota Exceeded" @@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception): """ Custom exception raised when the quota for an app has been exceeded. """ + description = "App Invoke Quota Exceeded" @@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception): """ Custom exception raised when the model not support """ + description = "Model Currently Not Support" class InvokeRateLimitError(Exception): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 4db7a99973..38cebb6b6b 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -20,10 +20,7 @@ class APIBasedExtensionRequestor: :param params: the request params :return: the response json """ - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer {}".format(self.api_key) - } + headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)} url = self.api_endpoint @@ -32,20 +29,17 @@ class APIBasedExtensionRequestor: proxies = None if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxies = { - 'http': dify_config.SSRF_PROXY_HTTP_URL, - 'https': dify_config.SSRF_PROXY_HTTPS_URL, + "http": dify_config.SSRF_PROXY_HTTP_URL, + "https": dify_config.SSRF_PROXY_HTTPS_URL, } response = requests.request( - method='POST', + method="POST", url=url, - json={ - 'point': point.value, - 'params': params - }, + json={"point": point.value, "params": params}, headers=headers, timeout=self.timeout, - proxies=proxies + proxies=proxies, ) except requests.exceptions.Timeout: raise ValueError("request timeout") @@ -53,9 +47,8 @@ class APIBasedExtensionRequestor: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError("request error, status_code: {}, content: {}".format( - response.status_code, - response.text[:100] - )) + raise ValueError( + "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) + ) return response.json() diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 8d73aa2b8b..97dbaf2026 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -3,6 +3,7 @@ import importlib.util import json import logging import os +from pathlib import Path from typing import Any, Optional from pydantic import BaseModel @@ -11,8 +12,8 @@ from core.helper.position_helper import sort_to_dict_by_position_map class ExtensionModule(enum.Enum): - MODERATION = 'moderation' - EXTERNAL_DATA_TOOL = 'external_data_tool' + MODERATION = "moderation" + EXTERNAL_DATA_TOOL = "external_data_tool" class ModuleExtension(BaseModel): @@ -41,12 +42,12 @@ class Extensible: position_map = {} # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') + current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") current_dir_path = os.path.dirname(current_path) # traverse subdirectories for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith('__'): + if subdir_name.startswith("__"): continue subdir_path = os.path.join(current_dir_path, subdir_name) @@ -58,21 +59,20 @@ class Extensible: # in the front-end page and business logic, there are special treatments. builtin = False position = None - if '__builtin__' in file_names: + if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, '__builtin__') + builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): - with open(builtin_file_path, encoding='utf-8') as f: - position = int(f.read().strip()) + position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) position_map[extension_name] = position - if (extension_name + '.py') not in file_names: + if (extension_name + ".py") not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") continue # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + '.py') + py_path = os.path.join(subdir_path, extension_name + ".py") spec = importlib.util.spec_from_file_location(extension_name, py_path) if not spec or not spec.loader: raise Exception(f"Failed to load module {extension_name} from {py_path}") @@ -91,25 +91,29 @@ class Extensible: json_data = {} if not builtin: - if 'schema.json' not in file_names: + if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, 'schema.json') + json_path = os.path.join(subdir_path, "schema.json") json_data = {} if os.path.exists(json_path): - with open(json_path, encoding='utf-8') as f: + with open(json_path, encoding="utf-8") as f: json_data = json.load(f) - extensions.append(ModuleExtension( - extension_class=extension_class, - name=extension_name, - label=json_data.get('label'), - form_schema=json_data.get('form_schema'), - builtin=builtin, - position=position - )) + extensions.append( + ModuleExtension( + extension_class=extension_class, + name=extension_name, + label=json_data.get("label"), + form_schema=json_data.get("form_schema"), + builtin=builtin, + position=position, + ) + ) - sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) + sorted_extensions = sort_to_dict_by_position_map( + position_map=position_map, data=extensions, name_func=lambda x: x.name + ) return sorted_extensions diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 29e892c58a..3da170455e 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -6,10 +6,7 @@ from core.moderation.base import Moderation class Extension: __module_extensions: dict[str, dict[str, ModuleExtension]] = {} - module_classes = { - ExtensionModule.MODERATION: Moderation, - ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool - } + module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool} def init(self): for module, module_class in self.module_classes.items(): diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 58c82502ea..54ec97a493 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -30,10 +30,11 @@ class ApiExternalDataTool(ExternalDataTool): raise ValueError("api_based_extension_id is required") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: raise ValueError("api_based_extension_id is invalid") @@ -50,47 +51,42 @@ class ApiExternalDataTool(ExternalDataTool): api_based_extension_id = self.config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == self.tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(self.variable)) + raise ValueError( + "[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid".format(self.variable) + ) # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=self.tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key) try: # request api - requestor = APIBasedExtensionRequestor( - api_endpoint=api_based_extension.api_endpoint, - api_key=api_key - ) + requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key) except Exception as e: - raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format( - self.variable, - e - )) + raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e)) - response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ - 'app_id': self.app_id, - 'tool_variable': self.variable, - 'inputs': inputs, - 'query': query - }) + response_json = requestor.request( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query}, + ) - if 'result' not in response_json: - raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response" - .format(self.variable)) + if "result" not in response_json: + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result not found in response".format( + self.variable + ) + ) - if not isinstance(response_json['result'], str): - raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string" - .format(self.variable)) + if not isinstance(response_json["result"], str): + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable) + ) - return response_json['result'] + return response_json["result"] diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 8601cb34e7..84b94e117f 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -12,11 +12,14 @@ logger = logging.getLogger(__name__) class ExternalDataFetch: - def fetch(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fetch( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -38,7 +41,7 @@ class ExternalDataFetch: app_id, tool, inputs, - query + query, ) futures[future] = tool @@ -50,12 +53,15 @@ class ExternalDataFetch: inputs.update(results) return inputs - def _query_external_data_tool(self, flask_app: Flask, - tenant_id: str, - app_id: str, - external_data_tool: ExternalDataVariableEntity, - inputs: dict, - query: str) -> tuple[Optional[str], Optional[str]]: + def _query_external_data_tool( + self, + flask_app: Flask, + tenant_id: str, + app_id: str, + external_data_tool: ExternalDataVariableEntity, + inputs: dict, + query: str, + ) -> tuple[Optional[str], Optional[str]]: """ Query external data tool. :param flask_app: flask app @@ -72,17 +78,10 @@ class ExternalDataFetch: tool_config = external_data_tool.config external_data_tool_factory = ExternalDataToolFactory( - name=tool_type, - tenant_id=tenant_id, - app_id=app_id, - variable=tool_variable, - config=tool_config + name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config ) # query external data tool - result = external_data_tool_factory.query( - inputs=inputs, - query=query - ) + result = external_data_tool_factory.query(inputs=inputs, query=query) return tool_variable, result diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 979f243af6..2872109859 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -5,14 +5,10 @@ from extensions.ext_code_based_extension import code_based_extension class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( - tenant_id=tenant_id, - app_id=app_id, - variable=variable, - config=config + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 3959f4b4a0..5c4e694025 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel): """ File Upload Entity. """ + image_config: Optional[dict[str, Any]] = None class FileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -28,9 +29,9 @@ class FileType(enum.Enum): class FileTransferMethod(enum.Enum): - REMOTE_URL = 'remote_url' - LOCAL_FILE = 'local_file' - TOOL_FILE = 'tool_file' + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" @staticmethod def value_of(value): @@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum): return member raise ValueError(f"No matching enum found for value '{value}'") + class FileBelongsTo(enum.Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" @staticmethod def value_of(value): @@ -65,16 +67,16 @@ class FileVar(BaseModel): def to_dict(self) -> dict: return { - '__variant': self.__class__.__name__, - 'tenant_id': self.tenant_id, - 'type': self.type.value, - 'transfer_method': self.transfer_method.value, - 'url': self.preview_url, - 'remote_url': self.url, - 'related_id': self.related_id, - 'filename': self.filename, - 'extension': self.extension, - 'mime_type': self.mime_type, + "__variant": self.__class__.__name__, + "tenant_id": self.tenant_id, + "type": self.type.value, + "transfer_method": self.transfer_method.value, + "url": self.preview_url, + "remote_url": self.url, + "related_id": self.related_id, + "filename": self.filename, + "extension": self.extension, + "mime_type": self.mime_type, } def to_markdown(self) -> str: @@ -86,7 +88,7 @@ class FileVar(BaseModel): if self.type == FileType.IMAGE: text = f'![{self.filename or ""}]({preview_url})' else: - text = f'[{self.filename or preview_url}]({preview_url})' + text = f"[{self.filename or preview_url}]({preview_url})" return text @@ -115,28 +117,29 @@ class FileVar(BaseModel): return ImagePromptMessageContent( data=self.data, detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW + if image_config.get("detail") == "high" + else ImagePromptMessageContent.DETAIL.LOW, ) def _get_data(self, force_url: bool = False) -> Optional[str]: from models.model import UploadFile + if self.type == FileType.IMAGE: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == self.related_id, - UploadFile.tenant_id == self.tenant_id - ).first()) - - return UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=force_url + upload_file = ( + db.session.query(UploadFile) + .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) + .first() ) + + return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) elif self.transfer_method == FileTransferMethod.TOOL_FILE: extension = self.extension # add sign url - return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension) + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=extension + ) return None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 085ff07cfd..83059b216e 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS class MessageFileParser: - def __init__(self, tenant_id: str, app_id: str) -> None: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, - user: Union[Account, EndUser]) -> list[FileVar]: + def validate_and_transform_files_arg( + self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] + ) -> list[FileVar]: """ validate and transform files arg @@ -30,22 +30,22 @@ class MessageFileParser: """ for file in files: if not isinstance(file, dict): - raise ValueError('Invalid file format, must be dict') - if not file.get('type'): - raise ValueError('Missing file type') - FileType.value_of(file.get('type')) - if not file.get('transfer_method'): - raise ValueError('Missing file transfer method') - FileTransferMethod.value_of(file.get('transfer_method')) - if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: - if not file.get('url'): - raise ValueError('Missing file url') - if not file.get('url').startswith('http'): - raise ValueError('Invalid file url') - if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): - raise ValueError('Missing file upload_file_id') - if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'): - raise ValueError('Missing file tool_file_id') + raise ValueError("Invalid file format, must be dict") + if not file.get("type"): + raise ValueError("Missing file type") + FileType.value_of(file.get("type")) + if not file.get("transfer_method"): + raise ValueError("Missing file transfer method") + FileTransferMethod.value_of(file.get("transfer_method")) + if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: + if not file.get("url"): + raise ValueError("Missing file url") + if not file.get("url").startswith("http"): + raise ValueError("Invalid file url") + if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): + raise ValueError("Missing file upload_file_id") + if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): + raise ValueError("Missing file tool_file_id") # transform files to file objs type_file_objs = self._to_file_objs(files, file_extra_config) @@ -62,17 +62,17 @@ class MessageFileParser: continue # Validate number of files - if len(files) > image_config['number_limits']: + if len(files) > image_config["number_limits"]: raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") for file_obj in file_objs: # Validate transfer method - if file_obj.transfer_method.value not in image_config['transfer_methods']: - raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') + if file_obj.transfer_method.value not in image_config["transfer_methods"]: + raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") # Validate file type if file_obj.type != FileType.IMAGE: - raise ValueError(f'Invalid file type: {file_obj.type}') + raise ValueError(f"Invalid file type: {file_obj.type}") if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: # check remote url valid and is image @@ -81,18 +81,21 @@ class MessageFileParser: raise ValueError(error) elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: # get upload file from upload_file_id - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - UploadFile.extension.in_(IMAGE_EXTENSIONS) - ).first()) + upload_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == file_obj.related_id, + UploadFile.tenant_id == self.tenant_id, + UploadFile.created_by == user.id, + UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + UploadFile.extension.in_(IMAGE_EXTENSIONS), + ) + .first() + ) # check upload file is belong to tenant and user if not upload_file: - raise ValueError('Invalid upload file') + raise ValueError("Invalid upload file") new_files.append(file_obj) @@ -113,8 +116,9 @@ class MessageFileParser: # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]: + def _to_file_objs( + self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig + ) -> dict[FileType, list[FileVar]]: """ transform files to file objs @@ -152,23 +156,23 @@ class MessageFileParser: :return: """ if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) + transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) if transfer_method != FileTransferMethod.TOOL_FILE: return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config + url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config, ) return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, url=None, - related_id=file.get('tool_file_id'), - extra_config=file_extra_config + related_id=file.get("tool_file_id"), + extra_config=file_extra_config, ) else: return FileVar( @@ -178,29 +182,30 @@ class MessageFileParser: transfer_method=FileTransferMethod.value_of(file.transfer_method), url=file.url, related_id=file.upload_file_id or None, - extra_config=file_extra_config + extra_config=file_extra_config, ) def _check_image_remote_url(self, url): try: headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" } def is_s3_presigned_url(url): try: parsed_url = urlparse(url) - if 'amazonaws.com' not in parsed_url.netloc: + if "amazonaws.com" not in parsed_url.netloc: return False query_params = parse_qs(parsed_url.query) - required_params = ['Signature', 'Expires'] + required_params = ["Signature", "Expires"] for param in required_params: if param not in query_params: return False - if not query_params['Expires'][0].isdigit(): + if not query_params["Expires"][0].isdigit(): return False - signature = query_params['Signature'][0] - if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature): + signature = query_params["Signature"][0] + if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): return False return True except Exception: diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index 98226e89c0..4d113a9cc2 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -7,7 +7,8 @@ tool_file_manager: dict[str, Any] = { 'manager': None } + class ToolFileParser: @staticmethod - def get_tool_file_manager() -> 'ToolFileManager': - return tool_file_manager['manager'] \ No newline at end of file + def get_tool_file_manager() -> "ToolFileManager": + return tool_file_manager["manager"] diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index 737a11e426..a8c1fd4d02 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -9,7 +9,7 @@ from typing import Optional from configs import dify_config from extensions.ext_storage import storage -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) @@ -22,18 +22,18 @@ class UploadFileParser: if upload_file.extension not in IMAGE_EXTENSIONS: return None - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: return cls.get_signed_temp_image_url(upload_file.id) else: # get image file base64 try: data = storage.load(upload_file.key) except FileNotFoundError: - logging.error(f'File not found: {upload_file.key}') + logging.error(f"File not found: {upload_file.key}") return None - encoded_string = base64.b64encode(data).decode('utf-8') - return f'data:{upload_file.mime_type};base64,{encoded_string}' + encoded_string = base64.b64encode(data).decode("utf-8") + return f"data:{upload_file.mime_type};base64,{encoded_string}" @classmethod def get_signed_temp_image_url(cls, upload_file_id) -> str: @@ -44,7 +44,7 @@ class UploadFileParser: :return: """ base_url = dify_config.FILES_URL - image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' + image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4662ebb47a..4932284540 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer logger = logging.getLogger(__name__) -class CodeExecutionException(Exception): + +class CodeExecutionError(Exception): pass + class CodeExecutionResponse(BaseModel): class Data(BaseModel): stdout: Optional[str] = None @@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel): class CodeLanguage(str, Enum): - PYTHON3 = 'python3' - JINJA2 = 'jinja2' - JAVASCRIPT = 'javascript' + PYTHON3 = "python3" + JINJA2 = "jinja2" + JAVASCRIPT = "javascript" class CodeExecutor: @@ -45,71 +47,74 @@ class CodeExecutor: } code_language_to_running_language = { - CodeLanguage.JAVASCRIPT: 'nodejs', + CodeLanguage.JAVASCRIPT: "nodejs", CodeLanguage.JINJA2: CodeLanguage.PYTHON3, CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, } - supported_dependencies_languages: set[CodeLanguage] = { - CodeLanguage.PYTHON3 - } + supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3} @classmethod - def execute_code(cls, - language: CodeLanguage, - preload: str, - code: str) -> str: + def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: """ Execute code :param language: code language :param code: code :return: """ - url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run' + url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run" - headers = { - 'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY - } + headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} data = { - 'language': cls.code_language_to_running_language.get(language), - 'code': code, - 'preload': preload, - 'enable_network': True + "language": cls.code_language_to_running_language.get(language), + "code": code, + "preload": preload, + "enable_network": True, } try: - response = post(str(url), json=data, headers=headers, - timeout=Timeout( - connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, - read=dify_config.CODE_EXECUTION_READ_TIMEOUT, - write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, - pool=None)) + response = post( + str(url), + json=data, + headers=headers, + timeout=Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None, + ), + ) if response.status_code == 503: - raise CodeExecutionException('Code execution service is unavailable') + raise CodeExecutionError("Code execution service is unavailable") elif response.status_code != 200: - raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running') - except CodeExecutionException as e: + raise Exception( + f"Failed to execute code, got status code {response.status_code}," + f" please check if the sandbox service is running" + ) + except CodeExecutionError as e: raise e except Exception as e: - raise CodeExecutionException('Failed to execute code, which is likely a network issue,' - ' please check if the sandbox service is running.' - f' ( Error: {str(e)} )') + raise CodeExecutionError( + "Failed to execute code, which is likely a network issue," + " please check if the sandbox service is running." + f" ( Error: {str(e)} )" + ) try: response = response.json() except: - raise CodeExecutionException('Failed to parse response') + raise CodeExecutionError("Failed to parse response") - if (code := response.get('code')) != 0: - raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") + if (code := response.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") response = CodeExecutionResponse(**response) if response.data.error: - raise CodeExecutionException(response.data.error) + raise CodeExecutionError(response.data.error) - return response.data.stdout or '' + return response.data.stdout or "" @classmethod def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: @@ -122,13 +127,13 @@ class CodeExecutor: """ template_transformer = cls.code_template_transformers.get(language) if not template_transformer: - raise CodeExecutionException(f'Unsupported language {language}') + raise CodeExecutionError(f"Unsupported language {language}") runner, preload = template_transformer.transform_caller(code, inputs) try: response = cls.execute_code(language, preload, runner) - except CodeExecutionException as e: + except CodeExecutionError as e: raise e return template_transformer.transform_response(response) diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 3f099b7ac5..e233a596b9 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel): return { "type": "code", "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], + "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}], "code_language": cls.get_language(), "code": cls.get_default_code(), - "outputs": { - "result": { - "type": "string", - "children": None - } - } - } + "outputs": {"result": {"type": "string", "children": None}}, + }, } diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py index a157fcc6d1..ae324b83a9 100644 --- a/api/core/helper/code_executor/javascript/javascript_code_provider.py +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider): result: arg1 + arg2 } } - """) + """ + ) diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index a4d2551972..d67a0903aa 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer): var output_json = JSON.stringify(output_obj) var result = `<>${{output_json}}<>` console.log(result) - """) + """ + ) return runner_script diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index f1e5da584c..db2eb5ebb6 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -10,8 +10,6 @@ class Jinja2Formatter: :param inputs: inputs :return: """ - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=template, inputs=inputs - ) + result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - return result['result'] + return result["result"] diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index b8cb29600e..63d58edbc7 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): :param response: response :return: """ - return { - 'result': cls.extract_result_str_from_response(response) - } + return {"result": cls.extract_result_str_from_response(response)} @classmethod def get_runner_script(cls) -> str: diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index 923724b49d..9cca8af7c6 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider): return { "result": arg1 + arg2, } - """) + """ + ) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index cf66558b65..6f016f27bc 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,9 +5,9 @@ from base64 import b64encode class TemplateTransformer(ABC): - _code_placeholder: str = '{{code}}' - _inputs_placeholder: str = '{{inputs}}' - _result_tag: str = '<>' + _code_placeholder: str = "{{code}}" + _inputs_placeholder: str = "{{inputs}}" + _result_tag: str = "<>" @classmethod def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: @@ -24,9 +24,9 @@ class TemplateTransformer(ABC): @classmethod def extract_result_str_from_response(cls, response: str) -> str: - result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL) + result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: - raise ValueError('Failed to parse result') + raise ValueError("Failed to parse result") result = result.group(1) return result @@ -50,7 +50,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: dict) -> str: inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() - input_base64_encoded = b64encode(inputs_json_str).decode('utf-8') + input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded @classmethod @@ -67,4 +67,4 @@ class TemplateTransformer(ABC): """ Get preload script """ - return '' + return "" diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 5e5deb86b4..96341a1b78 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -8,14 +8,15 @@ def obfuscated_token(token: str): if not token: return token if len(token) <= 8: - return '*' * 20 - return token[:6] + '*' * 12 + token[-2:] + return "*" * 20 + return token[:6] + "*" * 12 + token[-2:] def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): - raise ValueError(f'Tenant with id {tenant_id} not found') + raise ValueError(f"Tenant with id {tenant_id} not found") encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 29cb4acc7d..5e274f8916 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -25,7 +25,7 @@ class ProviderCredentialsCache: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 20feae8554..b880590de2 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -12,19 +12,20 @@ logger = logging.getLogger(__name__) def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config - if (moderation_config and moderation_config.enabled is True - and 'openai' in hosting_configuration.provider_map - and hosting_configuration.provider_map['openai'].enabled is True + if ( + moderation_config + and moderation_config.enabled is True + and "openai" in hosting_configuration.provider_map + and hosting_configuration.provider_map["openai"].enabled is True ): using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type provider_name = model_config.provider - if using_provider_type == ProviderType.SYSTEM \ - and provider_name in moderation_config.providers: - hosting_openai_config = hosting_configuration.provider_map['openai'] + if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: + hosting_openai_config = hosting_configuration.provider_map["openai"] # 2000 text per chunk length = 2000 - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] if len(text_chunks) == 0: return True @@ -34,15 +35,13 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: model_type_instance = OpenAIModerationModel() moderation_result = model_type_instance.invoke( - model='text-moderation-stable', - credentials=hosting_openai_config.credentials, - text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk ) if moderation_result is True: return True except Exception as ex: logger.exception(ex) - raise InvokeBadRequestError('Rate limit exceeded, please try again later.') + raise InvokeBadRequestError("Rate limit exceeded, please try again later.") return False diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 2000577a40..e6e1491548 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -37,8 +37,9 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type] """ Get all the subclasses of the parent type from the module """ - classes = [x for _, x in vars(mod).items() - if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)] + classes = [ + x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type) + ] return classes @@ -56,6 +57,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}') + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") case _: - raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') \ No newline at end of file + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 32e3806231..3efdc8aa47 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -73,10 +73,10 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) def is_filtered( - include_set: set[str], - exclude_set: set[str], - data: Any, - name_func: Callable[[Any], str], + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], ) -> bool: """ Check if the object should be filtered out. @@ -102,9 +102,9 @@ def is_filtered( def sort_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> list[Any]: """ Sort the objects by the position map. @@ -117,13 +117,13 @@ def sort_by_position_map( if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) def sort_to_dict_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> OrderedDict[str, Any]: """ Sort the objects into a ordered dict by the position map. diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 14ca8e943c..4e6d58904e 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,31 +1,34 @@ """ Proxy requests to avoid SSRF """ + import logging import os import time import httpx -SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') -SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') -SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') -SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) +SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "") +SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") +SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") +SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) -proxies = { - 'http://': SSRF_PROXY_HTTP_URL, - 'https://': SSRF_PROXY_HTTPS_URL -} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None +proxies = ( + {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} + if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL + else None +) BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: kwargs["follow_redirects"] = allow_redirects - + retries = 0 while retries <= max_retries: try: @@ -52,24 +55,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('GET', url, max_retries=max_retries, **kwargs) + return make_request("GET", url, max_retries=max_retries, **kwargs) def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('POST', url, max_retries=max_retries, **kwargs) + return make_request("POST", url, max_retries=max_retries, **kwargs) def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PUT', url, max_retries=max_retries, **kwargs) + return make_request("PUT", url, max_retries=max_retries, **kwargs) def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PATCH', url, max_retries=max_retries, **kwargs) + return make_request("PATCH", url, max_retries=max_retries, **kwargs) def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('DELETE', url, max_retries=max_retries, **kwargs) + return make_request("DELETE", url, max_retries=max_retries, **kwargs) def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('HEAD', url, max_retries=max_retries, **kwargs) + return make_request("HEAD", url, max_retries=max_retries, **kwargs) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index a6f486e81d..e848b46c56 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -9,15 +9,15 @@ from extensions.ext_redis import redis_client class ToolParameterCacheType(Enum): PARAMETER = "tool_parameter" + class ToolParameterCache: - def __init__(self, - tenant_id: str, - provider: str, - tool_name: str, - cache_type: ToolParameterCacheType, - identity_id: str - ): - self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}" + def __init__( + self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str + ): + self.cache_key = ( + f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + f":identity_id:{identity_id}" + ) def get(self) -> Optional[dict]: """ @@ -28,7 +28,7 @@ class ToolParameterCache: cached_tool_parameter = redis_client.get(self.cache_key) if cached_tool_parameter: try: - cached_tool_parameter = cached_tool_parameter.decode('utf-8') + cached_tool_parameter = cached_tool_parameter.decode("utf-8") cached_tool_parameter = json.loads(cached_tool_parameter) except JSONDecodeError: return None @@ -52,4 +52,4 @@ class ToolParameterCache: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 2777367963..2cc5d89727 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -10,6 +10,7 @@ class ToolProviderCredentialsCacheType(Enum): PROVIDER = "tool_provider" ENDPOINT = "endpoint" + class ToolProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" @@ -23,7 +24,7 @@ class ToolProviderCredentialsCache: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None @@ -47,4 +48,4 @@ class ToolProviderCredentialsCache: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index ddcd751286..eeeccc2349 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -46,7 +46,7 @@ class HostingConfiguration: def init_app(self, app: Flask) -> None: config = app.config - if config.get('EDITION') != 'CLOUD': + if config.get("EDITION") != "CLOUD": return self.provider_map["azure_openai"] = self.init_azure_openai(config) @@ -65,7 +65,7 @@ class HostingConfiguration: credentials = { "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), - "base_model_name": "gpt-35-turbo" + "base_model_name": "gpt-35-turbo", } quotas = [] @@ -77,26 +77,45 @@ class HostingConfiguration: RestrictModel(model="gpt-4o", base_model_name="gpt-4o", model_type=ModelType.LLM), RestrictModel(model="gpt-4o-mini", base_model_name="gpt-4o-mini", model_type=ModelType.LLM), RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM), + RestrictModel( + model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM + ), RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), - RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING), - ] + RestrictModel( + model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM + ), + RestrictModel( + model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM + ), + RestrictModel( + model="text-embedding-ada-002", + base_model_name="text-embedding-ada-002", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-small", + base_model_name="text-embedding-3-small", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-large", + base_model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], ) quotas.append(trial_quota) - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -110,17 +129,12 @@ class HostingConfiguration: if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit, - restrict_models=trial_models - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") - paid_quota = PaidHostingQuota( - restrict_models=paid_models - ) + paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) if len(quotas) > 0: @@ -134,12 +148,7 @@ class HostingConfiguration: if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -153,9 +162,7 @@ class HostingConfiguration: if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) quotas.append(trial_quota) if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): @@ -170,12 +177,7 @@ class HostingConfiguration: if app_config.get("HOSTED_ANTHROPIC_API_BASE"): credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -192,7 +194,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -210,7 +212,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -228,7 +230,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -238,21 +240,19 @@ class HostingConfiguration: @staticmethod def init_moderation_config(app_config: Config) -> HostedModerationConfig: - if app_config.get("HOSTED_MODERATION_ENABLED") \ - and app_config.get("HOSTED_MODERATION_PROVIDERS"): + if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"): return HostedModerationConfig( - enabled=True, - providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',') + enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",") ) - return HostedModerationConfig( - enabled=False - ) + return HostedModerationConfig(enabled=False) @staticmethod def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: models_str = app_config.get(env_var) models_list = models_str.split(",") if models_str else [] - return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if - model_name.strip()] - + return [ + RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) + for model_name in models_list + if model_name.strip() + ] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index df563f609b..af20df41b1 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -39,7 +39,6 @@ from services.feature_service import FeatureService class IndexingRunner: - def __init__(self): self.storage = storage self.model_manager = ModelManager() @@ -49,25 +48,26 @@ class IndexingRunner: for dataset_document in dataset_documents: try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) @@ -76,20 +76,20 @@ class IndexingRunner: index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, - documents=documents + documents=documents, ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: - logging.warning('Document deleted, document id: {}'.format(dataset_document.id)) + logging.warning("Document deleted, document id: {}".format(dataset_document.id)) except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -98,26 +98,25 @@ class IndexingRunner: """Run the indexing process when the index_status is splitting.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() for document_segment in document_segments: db.session.delete(document_segment) db.session.commit() # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -125,28 +124,26 @@ class IndexingRunner: text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) # load self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -155,17 +152,14 @@ class IndexingRunner: """Run the indexing process when the index_status is indexing.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() documents = [] @@ -180,42 +174,48 @@ class IndexingRunner: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) documents.append(document) # build index # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: + def indexing_estimate( + self, + tenant_id: str, + extract_settings: list[ExtractSetting], + tmp_processing_rule: dict, + doc_form: str = None, + doc_language: str = "English", + dataset_id: str = None, + indexing_technique: str = "economy", + ) -> dict: """ Estimate the indexing for the document. """ @@ -229,18 +229,16 @@ class IndexingRunner: embedding_model_instance = None if dataset_id: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise ValueError('Dataset not found.') - if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': + raise ValueError("Dataset not found.") + if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -248,7 +246,7 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, @@ -263,8 +261,7 @@ class IndexingRunner: text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) + mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) # get splitter @@ -272,9 +269,7 @@ class IndexingRunner: # split to documents documents = self._split_to_documents_for_estimate( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule + text_docs=text_docs, splitter=splitter, processing_rule=processing_rule ) total_segments += len(documents) @@ -282,110 +277,110 @@ class IndexingRunner: if len(preview_texts) < 5: preview_texts.append(document.page_content) - if doc_form and doc_form == 'qa_model': - + if doc_form and doc_form == "qa_model": if len(preview_texts) > 0: # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], - doc_language) + response = LLMGenerator.generate_qa_document( + current_user.current_tenant_id, preview_texts[0], doc_language + ) document_qa_list = self.format_split_text(response) - return { - "total_segments": total_segments * 20, - "qa_preview": document_qa_list, - "preview": preview_texts - } - return { - "total_segments": total_segments, - "preview": preview_texts - } + return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} + return {"total_segments": total_segments, "preview": preview_texts} - def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ - -> list[Document]: + def _extract( + self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict + ) -> list[Document]: # load file - if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: + if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: return [] data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == 'upload_file': - if not data_source_info or 'upload_file_id' not in data_source_info: + if dataset_document.data_source_type == "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info['upload_file_id']). \ - one_or_none() + file_detail = ( + db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() + ) if file_detail: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=dataset_document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'notion_import': - if (not data_source_info or 'notion_workspace_id' not in data_source_info - or 'notion_page_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], "document": dataset_document, - "tenant_id": dataset_document.tenant_id + "tenant_id": dataset_document.tenant_id, }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'website_crawl': - if (not data_source_info or 'provider' not in data_source_info - or 'url' not in data_source_info or 'job_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): raise ValueError("no website import info found") extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], "tenant_id": dataset_document.tenant_id, - "url": data_source_info['url'], - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata['document_id'] = dataset_document.id - text_doc.metadata['dataset_id'] = dataset_document.dataset_id + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @staticmethod def filter_string(text): - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) return text @staticmethod - def _get_splitter(processing_rule: DatasetProcessRule, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter( + processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ @@ -399,10 +394,10 @@ class IndexingRunner: separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") - if segmentation.get('chunk_overlap'): - chunk_overlap = segmentation['chunk_overlap'] + if segmentation.get("chunk_overlap"): + chunk_overlap = segmentation["chunk_overlap"] else: chunk_overlap = 0 @@ -411,22 +406,27 @@ class IndexingRunner: chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter - def _step_split(self, text_docs: list[Document], splitter: TextSplitter, - dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ - -> list[Document]: + def _step_split( + self, + text_docs: list[Document], + splitter: TextSplitter, + dataset: Dataset, + dataset_document: DatasetDocument, + processing_rule: DatasetProcessRule, + ) -> list[Document]: """ Split the text documents into documents and save them to the document segment. """ @@ -436,14 +436,12 @@ class IndexingRunner: processing_rule=processing_rule, tenant_id=dataset.tenant_id, document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language + document_language=dataset_document.doc_language, ) # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -457,7 +455,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -465,15 +463,21 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) return documents - def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> list[Document]: + def _split_to_documents( + self, + text_docs: list[Document], + splitter: TextSplitter, + processing_rule: DatasetProcessRule, + tenant_id: str, + document_form: str, + document_language: str, + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -488,12 +492,11 @@ class IndexingRunner: documents = splitter.split_documents([text_doc]) split_documents = [] for document_node in documents: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -506,15 +509,21 @@ class IndexingRunner: split_documents.append(document_node) all_documents.extend(split_documents) # processing qa document - if document_form == 'qa_model': + if document_form == "qa_model": for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents, - 'document_language': document_language}) + document_format_thread = threading.Thread( + target=self.format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": tenant_id, + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": document_language, + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -533,12 +542,14 @@ class IndexingRunner: document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.model_copy()) + qa_document = Document( + page_content=result["question"], metadata=document_node.metadata.model_copy() + ) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -546,8 +557,9 @@ class IndexingRunner: all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> list[Document]: + def _split_to_documents_for_estimate( + self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -567,8 +579,8 @@ class IndexingRunner: doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document.page_content) - document.metadata['doc_id'] = doc_id - document.metadata['doc_hash'] = hash + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -586,23 +598,23 @@ class IndexingRunner: else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} - if 'pre_processing_rules' in rules: + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text @@ -611,27 +623,26 @@ class IndexingRunner: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] - def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, - dataset_document: DatasetDocument, documents: list[Document]) -> None: + def _load( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + dataset_document: DatasetDocument, + documents: list[Document], + ) -> None: """ insert index and update document/segment status to completed """ embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # chunk nodes by chunk size @@ -640,18 +651,27 @@ class IndexingRunner: chunk_size = 10 # create keyword index - create_keyword_thread = threading.Thread(target=self._process_keyword_index, - args=(current_app._get_current_object(), - dataset.id, dataset_document.id, documents)) + create_keyword_thread = threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + ) create_keyword_thread.start() - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] for i in range(0, len(documents), chunk_size): - chunk_documents = documents[i:i + chunk_size] - futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, - chunk_documents, dataset, - dataset_document, embedding_model_instance)) + chunk_documents = documents[i : i + chunk_size] + futures.append( + executor.submit( + self._process_chunk, + current_app._get_current_object(), + index_processor, + chunk_documents, + dataset, + dataset_document, + embedding_model_instance, + ) + ) for future in futures: tokens += future.result() @@ -668,7 +688,7 @@ class IndexingRunner: DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, - } + }, ) @staticmethod @@ -679,23 +699,26 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != 'high_quality': - document_ids = [document.metadata['doc_id'] for document in documents] + if dataset.indexing_technique != "high_quality": + document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() - def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, - embedding_model_instance): + def _process_chunk( + self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + ): with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) @@ -703,26 +726,26 @@ class IndexingRunner: tokens = 0 if embedding_model_instance: tokens += sum( - embedding_model_instance.get_text_embedding_num_tokens( - [document.page_content] - ) + embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) for document in chunk_documents ) # load index index_processor.load(dataset, chunk_documents, with_keywords=False) - document_ids = [document.metadata['doc_id'] for document in chunk_documents] + document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() @@ -730,27 +753,26 @@ class IndexingRunner: @staticmethod def _check_document_paused_status(document_id: str): - indexing_cache_key = 'document_{}_is_paused'.format(document_id) + indexing_cache_key = "document_{}_is_paused".format(document_id) result = redis_client.get(indexing_cache_key) if result: - raise DocumentIsPausedException() + raise DocumentIsPausedError() @staticmethod - def _update_document_index_status(document_id: str, after_indexing_status: str, - extra_update_params: Optional[dict] = None) -> None: + def _update_document_index_status( + document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + ) -> None: """ Update the document indexing status. """ count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() if count > 0: - raise DocumentIsPausedException() + raise DocumentIsPausedError() document = DatasetDocument.query.filter_by(id=document_id).first() if not document: - raise DocumentIsDeletedPausedException() + raise DocumentIsDeletedPausedError() - update_params = { - DatasetDocument.indexing_status: after_indexing_status - } + update_params = {DatasetDocument.indexing_status: after_indexing_status} if extra_update_params: update_params.update(extra_update_params) @@ -780,7 +802,7 @@ class IndexingRunner: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index @@ -788,17 +810,23 @@ class IndexingRunner: index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) - def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, - text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]: + def _transform( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + text_docs: list[Document], + doc_language: str, + process_rule: dict, + ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -806,18 +834,20 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) - documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, - process_rule=process_rule, tenant_id=dataset.tenant_id, - doc_language=doc_language) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=process_rule, + tenant_id=dataset.tenant_id, + doc_language=doc_language, + ) return documents def _load_segments(self, dataset, dataset_document, documents): # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -831,7 +861,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -839,15 +869,15 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) pass -class DocumentIsPausedException(Exception): +class DocumentIsPausedError(Exception): pass -class DocumentIsDeletedPausedException(Exception): +class DocumentIsDeletedPausedError(Exception): pass diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 8c13b4a45c..78a6d6e683 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -43,21 +43,16 @@ class LLMGenerator: with measure_time() as timer: response = model_instance.invoke_llm( - prompt_messages=prompts, - model_parameters={ - "max_tokens": 100, - "temperature": 1 - }, - stream=False + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False ) answer = response.message.content - cleaned_answer = re.sub(r'^.*(\{.*\}).*$', r'\1', answer, flags=re.DOTALL) + cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) result_dict = json.loads(cleaned_answer) - answer = result_dict['Your Output'] + answer = result_dict["Your Output"] name = answer.strip() if len(name) > 75: - name = name[:75] + '...' + name = name[:75] + "..." # get tracing instance trace_manager = TraceQueueManager(app_id=app_id) @@ -79,14 +74,9 @@ class LLMGenerator: output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - template="{{histories}}\n{{format_instructions}}\nquestions:\n" - ) + prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") - prompt = prompt_template.format({ - "histories": histories, - "format_instructions": format_instructions - }) + prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: model_manager = ModelManager() @@ -101,12 +91,7 @@ class LLMGenerator: try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - "max_tokens": 256, - "temperature": 0 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False ) questions = output_parser.parse(response.message.content) @@ -119,32 +104,24 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512) -> dict: + def generate_rule_config( + cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 + ) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" - rule_config = { - "prompt": "", - "variables": [], - "opening_statement": "", - "error": "" - } - model_parameters = { - "max_tokens": rule_config_max_tokens, - "temperature": 0.01 - } + rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} + model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} if no_variable: - prompt_template = PromptTemplateParser( - WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE - ) + prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate)] @@ -158,13 +135,11 @@ class LLMGenerator: try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) rule_config["prompt"] = response.message.content - + except InvokeError as e: error = str(e) error_step = "generate rule config" @@ -179,24 +154,18 @@ class LLMGenerator: # get rule config prompt, parameter and statement prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - prompt_generate - ) + prompt_template = PromptTemplateParser(prompt_generate) - parameter_template = PromptTemplateParser( - parameter_generate - ) + parameter_template = PromptTemplateParser(parameter_generate) - statement_template = PromptTemplateParser( - statement_generate - ) + statement_template = PromptTemplateParser(statement_generate) # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] @@ -213,9 +182,7 @@ class LLMGenerator: try: # the first step to generate the task prompt prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) except InvokeError as e: error = str(e) @@ -230,7 +197,7 @@ class LLMGenerator: inputs={ "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] @@ -240,15 +207,13 @@ class LLMGenerator: "TASK_DESCRIPTION": instruction, "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False ) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) except InvokeError as e: @@ -257,9 +222,7 @@ class LLMGenerator: try: statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False ) rule_config["opening_statement"] = statement_content.message.content except InvokeError as e: @@ -284,18 +247,10 @@ class LLMGenerator: model_type=ModelType.LLM, ) - prompt_messages = [ - SystemPromptMessage(content=prompt), - UserPromptMessage(content=query) - ] + prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - 'temperature': 0.01, - "max_tokens": 2000 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False ) answer = response.message.content diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 6a60f8de80..1e743f1757 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserException(Exception): +class OutputParserError(Exception): pass diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index 8856f0c685..0c7683b16d 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -1,6 +1,6 @@ from typing import Any -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import ( RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -10,9 +10,12 @@ from libs.json_in_md_parser import parse_and_check_json_markdown class RuleConfigGeneratorOutputParser: - def get_format_instructions(self) -> tuple[str, str, str]: - return RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE + return ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) def parse(self, text: str) -> Any: try: @@ -21,16 +24,9 @@ class RuleConfigGeneratorOutputParser: if not isinstance(parsed["prompt"], str): raise ValueError("Expected 'prompt' to be a string.") if not isinstance(parsed["variables"], list): - raise ValueError( - "Expected 'variables' to be a list." - ) + raise ValueError("Expected 'variables' to be a list.") if not isinstance(parsed["opening_statement"], str): - raise ValueError( - "Expected 'opening_statement' to be a str." - ) + raise ValueError("Expected 'opening_statement' to be a str.") return parsed except Exception as e: - raise OutputParserException( - f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}" - ) - + raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 3f046c68fc..182aeed98f 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -6,7 +6,6 @@ from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCT class SuggestedQuestionsAfterAnswerOutputParser: - def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -15,7 +14,7 @@ class SuggestedQuestionsAfterAnswerOutputParser: if action_match is not None: json_obj = json.loads(action_match.group(0).strip()) else: - json_obj= [] + json_obj = [] print(f"Could not parse LLM output: {text}") return json_obj diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index dbd6e26c7c..c40b6d1808 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -59,26 +59,29 @@ User Input: yo, 你今天咋样? } User Input: -""" +""" # noqa: E501 SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keeping each question under 20 characters.\n" - "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" + "MAKE SURE your output is the SAME language as the Assistant's latest response" + "(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" "The output must be an array in JSON format following the specified schema:\n" - "[\"question1\",\"question2\",\"question3\"]\n" + '["question1","question2","question3"]\n' ) GENERATOR_QA_PROMPT = ( - ' The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step.' - 'Step 1: Understand and summarize the main content of this text.\n' - 'Step 2: What key information or concepts are mentioned in this text?\n' - 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' - 'Step 4: Generate questions and answers based on these key information and concepts.\n' - ' The questions should be clear and detailed, and the answers should be detailed and complete. ' - 'You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n' - ' Use the following format: Q1:\nA1:\nQ2:\nA2:...\n' - '' + " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" + " in the long text. Please think step by step." + "Step 1: Understand and summarize the main content of this text.\n" + "Step 2: What key information or concepts are mentioned in this text?\n" + "Step 3: Decompose or combine multiple pieces of information and concepts.\n" + "Step 4: Generate questions and answers based on these key information and concepts.\n" + " The questions should be clear and detailed, and the answers should be detailed and complete. " + "You must answer in {language}, in a style that is clear and detailed in {language}." + " No language other than {language} should be used. \n" + " Use the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "" ) WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ @@ -94,7 +97,7 @@ Based on task description, please create a well-structured prompt template that - Use the same language as task description. - Output in ``` xml ``` and start with Please generate the full prompt template with at least 300 words and output only the prompt template. -""" +""" # noqa: E501 RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ Here is a task description for which I would like you to create a high-quality prompt template for: @@ -109,7 +112,7 @@ Based on task description, please create a well-structured prompt template that - Use the same language as task description. - Output in ``` xml ``` and start with Please generate the full prompt template and output only the prompt template. -""" +""" # noqa: E501 RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """ I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. @@ -134,7 +137,7 @@ Inside XML tags, there is a text that I should extract parameters ### Answer I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text. -""" +""" # noqa: E501 RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """ @@ -150,4 +153,4 @@ Welcome! I'm here to assist you with any questions or issues you might have with Here is the task description: {{INPUT_TEXT}} You just need to generate the output -""" +""" # noqa: E501 diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index b33d4dd7cb..d3185c3b11 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -21,8 +21,9 @@ class TokenBufferMemory: self.conversation = conversation self.model_instance = model_instance - def get_history_prompt_messages(self, max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> list[PromptMessage]: + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> list[PromptMessage]: """ Get history prompt messages. :param max_token_limit: max token limit @@ -31,52 +32,41 @@ class TokenBufferMemory: app_record = self.conversation.app # fetch limited messages, and return reversed - query = db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id - ).filter( - Message.conversation_id == self.conversation.id, - Message.answer != '' - ).order_by(Message.created_at.desc()) + query = ( + db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id) + .filter(Message.conversation_id == self.conversation.id, Message.answer != "") + .order_by(Message.created_at.desc()) + ) if message_limit and message_limit > 0: - message_limit = message_limit if message_limit <= 500 else 500 + message_limit = min(message_limit, 500) else: message_limit = 500 messages = query.limit(message_limit).all() messages = list(reversed(messages)) - message_file_parser = MessageFileParser( - tenant_id=app_record.tenant_id, - app_id=app_record.id - ) + message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) prompt_messages = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: - workflow_run = (db.session.query(WorkflowRun) - .filter(WorkflowRun.id == message.workflow_run_id).first()) + workflow_run = ( + db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() + ) if workflow_run: file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, - is_vision=False + workflow_run.workflow.features_dict, is_vision=False ) if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config - ) + file_objs = message_file_parser.transform_message_files(files, file_extra_config) else: file_objs = [] @@ -97,24 +87,23 @@ class TokenBufferMemory: return [] # prune the chat message if it exceeds the max token limit - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: pruned_memory = [] - while curr_message_tokens > max_token_limit and len(prompt_messages)>1: + while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: pruned_memory.append(prompt_messages.pop(0)) - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> str: + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ) -> str: """ Get history prompt text. :param human_prefix: human prefix @@ -123,10 +112,7 @@ class TokenBufferMemory: :param message_limit: message limit :return: """ - prompt_messages = self.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit - ) + prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit) string_messages = [] for m in prompt_messages: diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index bba004a32a..92da53c9a4 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -18,12 +18,21 @@ class Callback: Base class for callbacks. Only for LLM. """ + raise_error: bool = False - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -39,10 +48,19 @@ class Callback: """ raise NotImplementedError() - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -59,10 +77,19 @@ class Callback: """ raise NotImplementedError() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -79,10 +106,19 @@ class Callback: """ raise NotImplementedError() - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -99,9 +135,7 @@ class Callback: """ raise NotImplementedError() - def print_text( - self, text: str, color: Optional[str] = None, end: str = "" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 0406853b88..3b6b825244 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -10,11 +10,20 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) + class LoggingCallback(Callback): - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -28,40 +37,49 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_before_invoke]\n", color='blue') - self.print_text(f"Model: {model}\n", color='blue') - self.print_text("Parameters:\n", color='blue') + self.print_text("\n[on_llm_before_invoke]\n", color="blue") + self.print_text(f"Model: {model}\n", color="blue") + self.print_text("Parameters:\n", color="blue") for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color='blue') + self.print_text(f"\t{key}: {value}\n", color="blue") if stop: - self.print_text(f"\tstop: {stop}\n", color='blue') + self.print_text(f"\tstop: {stop}\n", color="blue") if tools: - self.print_text("\tTools:\n", color='blue') + self.print_text("\tTools:\n", color="blue") for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color='blue') + self.print_text(f"\t\t{tool.name}\n", color="blue") - self.print_text(f"Stream: {stream}\n", color='blue') + self.print_text(f"Stream: {stream}\n", color="blue") if user: - self.print_text(f"User: {user}\n", color='blue') + self.print_text(f"User: {user}\n", color="blue") - self.print_text("Prompt messages:\n", color='blue') + self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color='blue') + self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue') - self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue') + self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") + self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") if stream: self.print_text("\n[on_llm_new_chunk]") - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -79,10 +97,19 @@ class LoggingCallback(Callback): sys.stdout.write(chunk.delta.message.content) sys.stdout.flush() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -97,24 +124,33 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_after_invoke]\n", color='yellow') - self.print_text(f"Content: {result.message.content}\n", color='yellow') + self.print_text("\n[on_llm_after_invoke]\n", color="yellow") + self.print_text(f"Content: {result.message.content}\n", color="yellow") if result.message.tool_calls: - self.print_text("Tool calls:\n", color='yellow') + self.print_text("Tool calls:\n", color="yellow") for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color='yellow') - self.print_text(f"\t{tool_call.function.name}\n", color='yellow') - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow') + self.print_text(f"\t{tool_call.id}\n", color="yellow") + self.print_text(f"\t{tool_call.function.name}\n", color="yellow") + self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - self.print_text(f"Model: {result.model}\n", color='yellow') - self.print_text(f"Usage: {result.usage}\n", color='yellow') - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow') + self.print_text(f"Model: {result.model}\n", color="yellow") + self.print_text(f"Usage: {result.usage}\n", color="yellow") + self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -129,5 +165,5 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_invoke_error]\n", color='red') + self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 175c13cfdc..659ad59bd6 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None en_US: str diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index e04d9fcbbb..4d0c9aa08f 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -2,123 +2,129 @@ from core.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { - 'label': { - 'en_US': 'Temperature', - 'zh_Hans': '温度', + "label": { + "en_US": "Temperature", + "zh_Hans": "温度", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.', - 'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。', + "type": "float", + "help": { + "en_US": "Controls randomness. Lower temperature results in less random completions." + " As the temperature approaches zero, the model will become deterministic and repetitive." + " Higher temperature results in more random completions.", + "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。" + "较高的温度会导致更多的随机完成。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_P: { - 'label': { - 'en_US': 'Top P', - 'zh_Hans': 'Top P', + "label": { + "en_US": "Top P", + "zh_Hans": "Top P", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.', - 'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。', + "type": "float", + "help": { + "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" + " are considered.", + "zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。", }, - 'required': False, - 'default': 1.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 1.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_K: { - 'label': { - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "label": { + "en_US": "Top K", + "zh_Hans": "Top K", }, - 'type': 'int', - 'help': { - 'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.', - 'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。', + "type": "int", + "help": { + "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", + "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", }, - 'required': False, - 'default': 50, - 'min': 1, - 'max': 100, - 'precision': 0, + "required": False, + "default": 50, + "min": 1, + "max": 100, + "precision": 0, }, DefaultParameterName.PRESENCE_PENALTY: { - 'label': { - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', + "label": { + "en_US": "Presence Penalty", + "zh_Hans": "存在惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens already in the text.', - 'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens already in the text.", + "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.FREQUENCY_PENALTY: { - 'label': { - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', + "label": { + "en_US": "Frequency Penalty", + "zh_Hans": "频率惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.', - 'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", + "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.MAX_TOKENS: { - 'label': { - 'en_US': 'Max Tokens', - 'zh_Hans': '最大标记', + "label": { + "en_US": "Max Tokens", + "zh_Hans": "最大标记", }, - 'type': 'int', - 'help': { - 'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.', - 'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。', + "type": "int", + "help": { + "en_US": "Specifies the upper limit on the length of generated results." + " If the generated results are truncated, you can increase this parameter.", + "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", }, - 'required': False, - 'default': 64, - 'min': 1, - 'max': 2048, - 'precision': 0, + "required": False, + "default": 64, + "min": 1, + "max": 2048, + "precision": 0, }, DefaultParameterName.RESPONSE_FORMAT: { - 'label': { - 'en_US': 'Response Format', - 'zh_Hans': '回复格式', + "label": { + "en_US": "Response Format", + "zh_Hans": "回复格式", }, - 'type': 'string', - 'help': { - 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', - 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + "type": "string", + "help": { + "en_US": "Set a response format, ensure the output from llm is a valid code block as possible," + " such as JSON, XML, etc.", + "zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等", }, - 'required': False, - 'options': ['JSON', 'XML'], + "required": False, + "options": ["JSON", "XML"], }, DefaultParameterName.JSON_SCHEMA: { - 'label': { - 'en_US': 'JSON Schema', + "label": { + "en_US": "JSON Schema", }, - 'type': 'text', - 'help': { - 'en_US': 'Set a response json schema will ensure LLM to adhere it.', - 'zh_Hans': '设置返回的json schema,llm将按照它返回', + "type": "text", + "help": { + "en_US": "Set a response json schema will ensure LLM to adhere it.", + "zh_Hans": "设置返回的json schema,llm将按照它返回", }, - 'required': False, + "required": False, }, } diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 59a4c103a2..52b590f66a 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -12,11 +12,12 @@ class LLMMode(Enum): """ Enum class for large language model mode. """ + COMPLETION = "completion" CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'LLMMode': + def value_of(cls, value: str) -> "LLMMode": """ Get value of given mode. @@ -26,13 +27,14 @@ class LLMMode(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class LLMUsage(ModelUsage): """ Model class for llm usage. """ + prompt_tokens: int prompt_unit_price: Decimal prompt_price_unit: Decimal @@ -50,20 +52,20 @@ class LLMUsage(ModelUsage): def empty_usage(cls): return cls( prompt_tokens=0, - prompt_unit_price=Decimal('0.0'), - prompt_price_unit=Decimal('0.0'), - prompt_price=Decimal('0.0'), + prompt_unit_price=Decimal("0.0"), + prompt_price_unit=Decimal("0.0"), + prompt_price=Decimal("0.0"), completion_tokens=0, - completion_unit_price=Decimal('0.0'), - completion_price_unit=Decimal('0.0'), - completion_price=Decimal('0.0'), + completion_unit_price=Decimal("0.0"), + completion_price_unit=Decimal("0.0"), + completion_price=Decimal("0.0"), total_tokens=0, - total_price=Decimal('0.0'), - currency='USD', - latency=0.0 + total_price=Decimal("0.0"), + currency="USD", + latency=0.0, ) - def plus(self, other: 'LLMUsage') -> 'LLMUsage': + def plus(self, other: "LLMUsage") -> "LLMUsage": """ Add two LLMUsage instances together. @@ -85,10 +87,10 @@ class LLMUsage(ModelUsage): total_tokens=self.total_tokens + other.total_tokens, total_price=self.total_price + other.total_price, currency=other.currency, - latency=self.latency + other.latency + latency=self.latency + other.latency, ) - def __add__(self, other: 'LLMUsage') -> 'LLMUsage': + def __add__(self, other: "LLMUsage") -> "LLMUsage": """ Overload the + operator to add two LLMUsage instances. @@ -97,10 +99,12 @@ class LLMUsage(ModelUsage): """ return self.plus(other) + class LLMResult(BaseModel): """ Model class for llm result. """ + model: str prompt_messages: list[PromptMessage] message: AssistantPromptMessage @@ -112,6 +116,7 @@ class LLMResultChunkDelta(BaseModel): """ Model class for llm result chunk delta. """ + index: int message: AssistantPromptMessage usage: Optional[LLMUsage] = None @@ -122,6 +127,7 @@ class LLMResultChunk(BaseModel): """ Model class for llm result chunk. """ + model: str prompt_messages: list[PromptMessage] system_fingerprint: Optional[str] = None @@ -132,4 +138,5 @@ class NumTokensResult(PriceInfo): """ Model class for number of tokens result. """ + tokens: int diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index e8e6963b56..e51bb18deb 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -9,13 +9,14 @@ class PromptMessageRole(Enum): """ Enum class for prompt message. """ + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" @classmethod - def value_of(cls, value: str) -> 'PromptMessageRole': + def value_of(cls, value: str) -> "PromptMessageRole": """ Get value of given mode. @@ -25,13 +26,14 @@ class PromptMessageRole(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt message type value {value}') + raise ValueError(f"invalid prompt message type value {value}") class PromptMessageTool(BaseModel): """ Model class for prompt message tool. """ + name: str description: str parameters: dict @@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel): """ Model class for prompt message function. """ - type: str = 'function' + + type: str = "function" function: PromptMessageTool @@ -49,14 +52,16 @@ class PromptMessageContentType(Enum): """ Enum class for prompt message content type. """ - TEXT = 'text' - IMAGE = 'image' + + TEXT = "text" + IMAGE = "image" class PromptMessageContent(BaseModel): """ Model class for prompt message content. """ + type: PromptMessageContentType data: str @@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent): """ Model class for text prompt message content. """ + type: PromptMessageContentType = PromptMessageContentType.TEXT @@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent): """ Model class for image prompt message content. """ + class DETAIL(Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageContentType = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ + role: PromptMessageRole content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None @@ -101,6 +109,7 @@ class UserPromptMessage(PromptMessage): """ Model class for user prompt message. """ + role: PromptMessageRole = PromptMessageRole.USER @@ -108,14 +117,17 @@ class AssistantPromptMessage(PromptMessage): """ Model class for assistant prompt message. """ + class ToolCall(BaseModel): """ Model class for assistant prompt message tool call. """ + class ToolCallFunction(BaseModel): """ Model class for assistant prompt message tool call function. """ + name: str arguments: str @@ -123,7 +135,7 @@ class AssistantPromptMessage(PromptMessage): type: str function: ToolCallFunction - @field_validator('id', mode='before') + @field_validator("id", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -145,10 +157,12 @@ class AssistantPromptMessage(PromptMessage): return True + class SystemPromptMessage(PromptMessage): """ Model class for system prompt message. """ + role: PromptMessageRole = PromptMessageRole.SYSTEM @@ -156,6 +170,7 @@ class ToolPromptMessage(PromptMessage): """ Model class for tool prompt message. """ + role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index d6377d7e88..52ea787c3a 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -11,6 +11,7 @@ class ModelType(Enum): """ Enum class for model type. """ + LLM = "llm" TEXT_EMBEDDING = "text-embedding" RERANK = "rerank" @@ -26,22 +27,22 @@ class ModelType(Enum): :return: model type """ - if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value: + if origin_model_type in {"text-generation", cls.LLM.value}: return cls.LLM - elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: return cls.TEXT_EMBEDDING - elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: + elif origin_model_type in {"reranking", cls.RERANK.value}: return cls.RERANK - elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: return cls.SPEECH2TEXT - elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: + elif origin_model_type in {"tts", cls.TTS.value}: return cls.TTS - elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type in {"text2img", cls.TEXT2IMG.value}: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION else: - raise ValueError(f'invalid origin model type {origin_model_type}') + raise ValueError(f"invalid origin model type {origin_model_type}") def to_origin_model_type(self) -> str: """ @@ -50,26 +51,28 @@ class ModelType(Enum): :return: origin model type """ if self == self.LLM: - return 'text-generation' + return "text-generation" elif self == self.TEXT_EMBEDDING: - return 'embeddings' + return "embeddings" elif self == self.RERANK: - return 'reranking' + return "reranking" elif self == self.SPEECH2TEXT: - return 'speech2text' + return "speech2text" elif self == self.TTS: - return 'tts' + return "tts" elif self == self.MODERATION: - return 'moderation' + return "moderation" elif self == self.TEXT2IMG: - return 'text2img' + return "text2img" else: - raise ValueError(f'invalid model type {self}') + raise ValueError(f"invalid model type {self}") + class FetchFrom(Enum): """ Enum class for fetch from. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -78,6 +81,7 @@ class ModelFeature(Enum): """ Enum class for llm feature. """ + TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" @@ -89,6 +93,7 @@ class DefaultParameterName(str, Enum): """ Enum class for parameter template variable. """ + TEMPERATURE = "temperature" TOP_P = "top_p" TOP_K = "top_k" @@ -99,7 +104,7 @@ class DefaultParameterName(str, Enum): JSON_SCHEMA = "json_schema" @classmethod - def value_of(cls, value: Any) -> 'DefaultParameterName': + def value_of(cls, value: Any) -> "DefaultParameterName": """ Get parameter name from value. @@ -109,13 +114,14 @@ class DefaultParameterName(str, Enum): for name in cls: if name.value == value: return name - raise ValueError(f'invalid parameter name {value}') + raise ValueError(f"invalid parameter name {value}") class ParameterType(Enum): """ Enum class for parameter type. """ + FLOAT = "float" INT = "int" STRING = "string" @@ -127,6 +133,7 @@ class ModelPropertyKey(Enum): """ Enum class for model property key. """ + MODE = "mode" CONTEXT_SIZE = "context_size" MAX_CHUNKS = "max_chunks" @@ -144,6 +151,7 @@ class ProviderModel(BaseModel): """ Model class for provider model. """ + model: str label: I18nObject model_type: ModelType @@ -158,6 +166,7 @@ class ParameterRule(BaseModel): """ Model class for parameter rule. """ + name: str use_template: Optional[str] = None label: I18nObject @@ -175,6 +184,7 @@ class PriceConfig(BaseModel): """ Model class for pricing info. """ + input: Decimal output: Optional[Decimal] = None unit: Decimal @@ -185,6 +195,7 @@ class AIModelEntity(ProviderModel): """ Model class for AI model. """ + parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None @@ -197,6 +208,7 @@ class PriceType(Enum): """ Enum class for price type. """ + INPUT = "input" OUTPUT = "output" @@ -205,6 +217,7 @@ class PriceInfo(BaseModel): """ Model class for price info. """ + unit_price: Decimal unit: Decimal total_amount: Decimal diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index f88f89d588..bfe861a97f 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -12,6 +12,7 @@ class ConfigurateMethod(Enum): """ Enum class for configurate method of provider model. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -20,6 +21,7 @@ class FormType(Enum): """ Enum class for form type. """ + TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" SELECT = "select" @@ -31,6 +33,7 @@ class FormShowOnObject(BaseModel): """ Model class for form show on. """ + variable: str value: str @@ -39,6 +42,7 @@ class FormOption(BaseModel): """ Model class for form option. """ + label: I18nObject value: str show_on: list[FormShowOnObject] = [] @@ -46,15 +50,14 @@ class FormOption(BaseModel): def __init__(self, **data): super().__init__(**data) if not self.label: - self.label = I18nObject( - en_US=self.value - ) + self.label = I18nObject(en_US=self.value) class CredentialFormSchema(BaseModel): """ Model class for credential form schema. """ + variable: str label: I18nObject type: FormType @@ -70,6 +73,7 @@ class ProviderCredentialSchema(BaseModel): """ Model class for provider credential schema. """ + credential_form_schemas: list[CredentialFormSchema] @@ -82,6 +86,7 @@ class ModelCredentialSchema(BaseModel): """ Model class for model credential schema. """ + model: FieldModelSchema credential_form_schemas: list[CredentialFormSchema] @@ -90,6 +95,7 @@ class SimpleProviderEntity(BaseModel): """ Simple model class for provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -102,6 +108,7 @@ class ProviderHelpEntity(BaseModel): """ Model class for provider help. """ + title: I18nObject url: I18nObject @@ -110,6 +117,7 @@ class ProviderEntity(BaseModel): """ Model class for provider. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -138,7 +146,7 @@ class ProviderEntity(BaseModel): icon_small=self.icon_small, icon_large=self.icon_large, supported_model_types=self.supported_model_types, - models=self.models + models=self.models, ) @@ -146,5 +154,6 @@ class ProviderConfig(BaseModel): """ Model class for provider config. """ + provider: str credentials: dict diff --git a/api/core/model_runtime/entities/rerank_entities.py b/api/core/model_runtime/entities/rerank_entities.py index d51efd2b3b..99709e1bcd 100644 --- a/api/core/model_runtime/entities/rerank_entities.py +++ b/api/core/model_runtime/entities/rerank_entities.py @@ -5,6 +5,7 @@ class RerankDocument(BaseModel): """ Model class for rerank document. """ + index: int text: str score: float @@ -14,5 +15,6 @@ class RerankResult(BaseModel): """ Model class for rerank result. """ + model: str docs: list[RerankDocument] diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 7be3def379..846b89d658 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage): """ Model class for embedding usage. """ + tokens: int total_tokens: int unit_price: Decimal @@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel): """ Model class for text embedding result. """ + model: str embeddings: list[list[float]] usage: EmbeddingUsage - diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 0513cfaf67..edfb19c7d0 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -3,6 +3,7 @@ from typing import Optional class InvokeError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -14,24 +15,29 @@ class InvokeError(Exception): class InvokeConnectionError(InvokeError): """Raised when the Invoke returns connection error.""" + description = "Connection Error" class InvokeServerUnavailableError(InvokeError): """Raised when the Invoke returns server unavailable error.""" + description = "Server Unavailable Error" class InvokeRateLimitError(InvokeError): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" class InvokeAuthorizationError(InvokeError): """Raised when the Invoke returns authorization error.""" + description = "Incorrect model credentials provided, please check and try again. " class InvokeBadRequestError(InvokeError): """Raised when the Invoke returns bad request.""" + description = "Bad Request Error" diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 8db79a52bb..7fcd2133f9 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception): """ Credentials validate failed error """ + pass diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 716bb63566..79a1d28ebe 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -66,12 +66,16 @@ class AIModel(ABC): :param error: model invoke error :return: unified error """ - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] for invoke_error, model_errors in self._invoke_error_mapping.items(): if isinstance(error, tuple(model_errors)): if invoke_error == InvokeAuthorizationError: - return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ") + return invoke_error( + description=( + f"[{provider_name}] Incorrect model credentials provided, please check and try again." + ) + ) return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}") @@ -115,7 +119,7 @@ class AIModel(ABC): if not price_config: raise ValueError(f"Price config not found for model {model}") total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) return PriceInfo( unit_price=unit_price, @@ -136,24 +140,26 @@ class AIModel(ABC): model_schemas = [] # get module name - model_type = self.__class__.__module__.split('.')[-1] + model_type = self.__class__.__module__.split(".")[-1] # get provider name - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] # get the path of current classes current_path = os.path.abspath(__file__) # get parent path of the current path - provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type) + provider_model_type_path = os.path.join( + os.path.dirname(os.path.dirname(current_path)), provider_name, model_type + ) # get all yaml files path under provider_model_type_path that do not start with __ model_schema_yaml_paths = [ os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) - if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') + if not model_schema_yaml.startswith("__") + and not model_schema_yaml.startswith("_") and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and model_schema_yaml.endswith(".yaml") ] # get _position.yaml file path @@ -165,10 +171,10 @@ class AIModel(ABC): yaml_data = load_yaml_file(model_schema_yaml_path) new_parameter_rules = [] - for parameter_rule in yaml_data.get('parameter_rules', []): - if 'use_template' in parameter_rule: + for parameter_rule in yaml_data.get("parameter_rules", []): + if "use_template" in parameter_rule: try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template']) + default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"]) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) copy_default_parameter_rule = default_parameter_rule.copy() copy_default_parameter_rule.update(parameter_rule) @@ -176,31 +182,26 @@ class AIModel(ABC): except ValueError: pass - if 'label' not in parameter_rule: - parameter_rule['label'] = { - 'zh_Hans': parameter_rule['name'], - 'en_US': parameter_rule['name'] - } + if "label" not in parameter_rule: + parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]} new_parameter_rules.append(parameter_rule) - yaml_data['parameter_rules'] = new_parameter_rules + yaml_data["parameter_rules"] = new_parameter_rules - if 'label' not in yaml_data: - yaml_data['label'] = { - 'zh_Hans': yaml_data['model'], - 'en_US': yaml_data['model'] - } + if "label" not in yaml_data: + yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]} - yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value + yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value try: # yaml_data to entity model_schema = AIModelEntity(**yaml_data) except Exception as e: model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml") - raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:' - f' {str(e)}') + raise Exception( + f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}: {str(e)}" + ) # cache model schema model_schemas.append(model_schema) @@ -235,7 +236,9 @@ class AIModel(ABC): return None - def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials( + self, model: str, credentials: Mapping + ) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -261,19 +264,19 @@ class AIModel(ABC): try: default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and 'max' in default_parameter_rule: - parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min and 'min' in default_parameter_rule: - parameter_rule.min = default_parameter_rule['min'] - if not parameter_rule.default and 'default' in default_parameter_rule: - parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision and 'precision' in default_parameter_rule: - parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required and 'required' in default_parameter_rule: - parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help and 'help' in default_parameter_rule: + if not parameter_rule.max and "max" in default_parameter_rule: + parameter_rule.max = default_parameter_rule["max"] + if not parameter_rule.min and "min" in default_parameter_rule: + parameter_rule.min = default_parameter_rule["min"] + if not parameter_rule.default and "default" in default_parameter_rule: + parameter_rule.default = default_parameter_rule["default"] + if not parameter_rule.precision and "precision" in default_parameter_rule: + parameter_rule.precision = default_parameter_rule["precision"] + if not parameter_rule.required and "required" in default_parameter_rule: + parameter_rule.required = default_parameter_rule["required"] + if not parameter_rule.help and "help" in default_parameter_rule: parameter_rule.help = I18nObject( - en_US=default_parameter_rule['help']['en_US'], + en_US=default_parameter_rule["help"]["en_US"], ) if ( parameter_rule.help diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index cfc8942c79..ba88cc1f38 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -35,16 +35,24 @@ class LargeLanguageModel(AIModel): """ Model class for large language model. """ + model_type: ModelType = ModelType.LLM # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -69,7 +77,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -82,7 +90,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) try: @@ -96,7 +104,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) else: result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -111,7 +119,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) raise self._transform_invoke_error(e) @@ -127,7 +135,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) elif isinstance(result, LLMResult): self._trigger_after_invoke_callbacks( @@ -140,15 +148,23 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) return result - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper, ensure the response is a code block with output markdown quote @@ -171,7 +187,7 @@ if you are not sure about the structure. {{instructions}} -""" +""" # noqa: E501 code_block = model_parameters.get("response_format", "") if not code_block: @@ -183,7 +199,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) model_parameters.pop("response_format") @@ -195,15 +211,16 @@ if you are not sure about the structure. if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", str(prompt_messages[0].content)) + content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content)) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.") + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last text message @@ -216,9 +233,7 @@ if you are not sure about the structure. break else: # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) + prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n")) response = self._invoke( model=model, @@ -228,33 +243,30 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if isinstance(response, Generator): first_chunk = next(response) + def new_generator(): yield first_chunk yield from response if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"): return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) else: return self._code_block_mode_stream_processor( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], - input_generator: Generator[LLMResultChunk, None, None] - ) -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor( + self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote @@ -303,16 +315,13 @@ if you are not sure about the structure. prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, - input_generator: Generator[LLMResultChunk, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor_with_backtick( + self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote. This version skips the language identifier that follows the opening triple backticks. @@ -378,18 +387,23 @@ if you are not sure about the structure. prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator: + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: """ Invoke result generator @@ -397,9 +411,7 @@ if you are not sure about the structure. :return: result generator """ callbacks = callbacks or [] - prompt_message = AssistantPromptMessage( - content="" - ) + prompt_message = AssistantPromptMessage(content="") usage = None system_fingerprint = None real_model = model @@ -418,7 +430,7 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) prompt_message.content += chunk.delta.message.content @@ -437,8 +449,8 @@ if you are not sure about the structure. model=real_model, prompt_messages=prompt_messages, message=prompt_message, - usage=usage if usage else LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint + usage=usage or LLMUsage.empty_usage(), + system_fingerprint=system_fingerprint, ), credentials=credentials, prompt_messages=prompt_messages, @@ -447,15 +459,21 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) @abstractmethod - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -472,8 +490,13 @@ if you are not sure about the structure. raise NotImplementedError @abstractmethod - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -519,7 +542,9 @@ if you are not sure about the structure. return mode - def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -539,10 +564,7 @@ if you are not sure about the structure. # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -558,16 +580,23 @@ if you are not sure about the structure. total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_before_invoke_callbacks( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger before invoke callbacks @@ -593,7 +622,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -601,11 +630,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}") - def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_new_chunk_callbacks( + self, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger new chunk callbacks @@ -632,7 +669,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -640,11 +677,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}") - def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_after_invoke_callbacks( + self, + model: str, + result: LLMResult, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger after invoke callbacks @@ -672,7 +717,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -680,11 +725,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}") - def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_invoke_error_callbacks( + self, + model: str, + ex: Exception, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger invoke error callbacks @@ -712,7 +765,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -758,11 +811,13 @@ if you are not sure about the structure. # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.FLOAT: if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") @@ -775,16 +830,20 @@ if you are not sure about the structure. else: if parameter_value != round(parameter_value, parameter_rule.precision): raise ValueError( - f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.") + f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}" + f" decimal places." + ) # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.BOOLEAN: if not isinstance(parameter_value, bool): raise ValueError(f"Model Parameter {parameter_name} should be bool.") diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 780460a3f7..4374093de4 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -29,32 +29,32 @@ class ModelProvider(ABC): def get_provider_schema(self) -> ProviderEntity: """ Get provider schema - + :return: provider schema """ if self.provider_schema: return self.provider_schema - + # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] # get the path of the model_provider classes base_path = os.path.abspath(__file__) current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) - + # read provider schema from yaml file - yaml_path = os.path.join(current_path, f'{provider_name}.yaml') + yaml_path = os.path.join(current_path, f"{provider_name}.yaml") yaml_data = load_yaml_file(yaml_path) - + try: # yaml_data to entity provider_schema = ProviderEntity(**yaml_data) except Exception as e: - raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}') + raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}") # cache schema self.provider_schema = provider_schema - + return provider_schema def models(self, model_type: ModelType) -> list[AIModelEntity]: @@ -92,15 +92,15 @@ class ModelProvider(ABC): # get the path of the model type classes base_path = os.path.abspath(__file__) - model_type_name = model_type.value.replace('-', '_') + model_type_name = model_type.value.replace("-", "_") model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name) - model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py') + model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py") if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path): - raise Exception(f'Invalid model type {model_type} for provider {provider_name}') + raise Exception(f"Invalid model type {model_type} for provider {provider_name}") # Dynamic loading {model_type_name}.py file and find the subclass of AIModel - parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) + parent_module = ".".join(self.__class__.__module__.split(".")[:-1]) mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 2b17f292c5..d04414ccb8 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -12,14 +12,13 @@ class ModerationModel(AIModel): """ Model class for moderation model. """ + model_type: ModelType = ModelType.MODERATION # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -37,9 +36,7 @@ class ModerationModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke large language model @@ -50,4 +47,3 @@ class ModerationModel(AIModel): :return: false if text is safe, true otherwise """ raise NotImplementedError - diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 2c86f25180..5fb9604742 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -11,12 +11,19 @@ class RerankModel(AIModel): """ Base Model class for rerank model. """ + model_type: ModelType = ModelType.RERANK - def invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -37,10 +44,16 @@ class RerankModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index 4fb11025fe..b6b0b73743 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -12,14 +12,13 @@ class Speech2TextModel(AIModel): """ Model class for speech2text model. """ + model_type: ModelType = ModelType.SPEECH2TEXT # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -35,9 +34,7 @@ class Speech2TextModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -59,4 +56,4 @@ class Speech2TextModel(AIModel): current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct the path to the audio file - return os.path.join(current_dir, 'audio.mp3') + return os.path.join(current_dir, "audio.mp3") diff --git a/api/core/model_runtime/model_providers/__base/text2img_model.py b/api/core/model_runtime/model_providers/__base/text2img_model.py index e0f1adb1c4..a5810e2f0e 100644 --- a/api/core/model_runtime/model_providers/__base/text2img_model.py +++ b/api/core/model_runtime/model_providers/__base/text2img_model.py @@ -11,14 +11,15 @@ class Text2ImageModel(AIModel): """ Model class for text2img model. """ + model_type: ModelType = ModelType.TEXT2IMG # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model @@ -36,9 +37,9 @@ class Text2ImageModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def _invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 381d2f6cd1..54a4486023 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -13,14 +13,15 @@ class TextEmbeddingModel(AIModel): """ Model class for text embedding model. """ + model_type: ModelType = ModelType.TEXT_EMBEDDING # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model @@ -38,9 +39,9 @@ class TextEmbeddingModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 6059b3f561..5fe6dda6ad 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -7,27 +7,28 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer _tokenizer = None _lock = Lock() + class GPT2Tokenizer: @staticmethod def _get_num_tokens_by_gpt2(text: str) -> int: """ - use gpt2 tokenizer to get num tokens + use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() tokens = _tokenizer.encode(text, verbose=False) return len(tokens) - + @staticmethod def get_num_tokens(text: str) -> int: return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - + @staticmethod def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'gpt2') + gpt2_tokenizer_path = join(dirname(base_path), "gpt2") _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - return _tokenizer \ No newline at end of file + return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 2dfd323a47..70be9322a7 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -15,13 +15,15 @@ class TTSModel(AIModel): """ Model class for TTS model. """ + model_type: ModelType = ModelType.TTS # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -35,14 +37,21 @@ class TTSModel(AIModel): :return: translated audio file """ try: - return self._invoke(model=model, credentials=credentials, user=user, - content_text=content_text, voice=voice, tenant_id=tenant_id) + return self._invoke( + model=model, + credentials=credentials, + user=user, + content_text=content_text, + voice=voice, + tenant_id=tenant_id, + ) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -71,10 +80,13 @@ class TTSModel(AIModel): if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: voices = model_schema.model_properties[ModelPropertyKey.VOICES] if language: - return [{'name': d['name'], 'value': d['mode']} for d in voices if - language and language in d.get('language')] + return [ + {"name": d["name"], "value": d["mode"]} + for d in voices + if language and language in d.get("language") + ] else: - return [{'name': d['name'], 'value': d['mode']} for d in voices] + return [{"name": d["name"], "value": d["mode"]} for d in voices] def _get_model_default_voice(self, model: str, credentials: dict) -> any: """ @@ -123,23 +135,23 @@ class TTSModel(AIModel): return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod - def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'): + def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): match = re.compile(pattern) tx = match.finditer(org_text) start = 0 result = [] - one_sentence = '' + one_sentence = "" for i in tx: end = i.regs[0][1] tmp = org_text[start:end] if len(one_sentence + tmp) > max_length: result.append(one_sentence) - one_sentence = '' + one_sentence = "" one_sentence += tmp start = end last_sens = org_text[start:] if last_sens: one_sentence += last_sens - if one_sentence != '': + if one_sentence != "": result.append(one_sentence) return result diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.py b/api/core/model_runtime/model_providers/anthropic/anthropic.py index 325c6c060e..5b12f04a3e 100644 --- a/api/core/model_runtime/model_providers/anthropic/anthropic.py +++ b/api/core/model_runtime/model_providers/anthropic/anthropic.py @@ -20,12 +20,9 @@ class AnthropicProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `claude-3-opus-20240229` model for validate, - model_instance.validate_credentials( - model='claude-3-opus-20240229', - credentials=credentials - ) + model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 81be1a06a7..46e1b415b8 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -51,15 +51,21 @@ if you are not sure about the structure. {{instructions}} -""" +""" # noqa: E501 class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -76,10 +82,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # invoke model return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -96,41 +109,39 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # transform model parameters from completion api of anthropic to chat api - if 'max_tokens_to_sample' in model_parameters: - model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + if "max_tokens_to_sample" in model_parameters: + model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample") # init model client client = Anthropic(**credentials_kwargs) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop if user: - extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user) system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get('max_tokens') > 4096: + if model_parameters.get("max_tokens") > 4096: extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" if tools: - extra_model_kwargs['tools'] = [ - self._transform_tool_prompt(tool) for tool in tools - ] + extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( model=model, messages=prompt_message_dicts, stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: # chat model @@ -140,22 +151,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if model_parameters.get('response_format'): + if model_parameters.get("response_format"): stop = stop or [] # chat model self._transform_chat_json_prompts( @@ -167,24 +186,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict: - return { - 'name': tool.name, - 'description': tool.description, - 'input_schema': tool.parameters - } + return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters} - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -197,22 +219,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -228,9 +258,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): tokens = client.count_tokens(prompt) tool_call_inner_prompts_tokens_map = { - 'claude-3-opus-20240229': 395, - 'claude-3-haiku-20240307': 264, - 'claude-3-sonnet-20240229': 159 + "claude-3-opus-20240229": 395, + "claude-3-haiku-20240307": 264, + "claude-3-sonnet-20240229": 159, } if model in tool_call_inner_prompts_tokens_map and tools: @@ -257,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): "temperature": 0, "max_tokens": 20, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: Union[Message, ToolsBetaMessage], + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm chat response @@ -274,22 +309,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[]) for content in response.content: - if content.type == 'text': + if content.type == "text": assistant_prompt_message.content += content.text - elif content.type == 'tool_use': + elif content.type == "tool_use": tool_call = AssistantPromptMessage.ToolCall( id=content.id, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content.name, - arguments=json.dumps(content.input) - ) + name=content.name, arguments=json.dumps(content.input) + ), ) assistant_prompt_message.tool_calls.append(tool_call) @@ -308,17 +339,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm chat stream response @@ -327,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -338,24 +366,23 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): for chunk in response: if isinstance(chunk, MessageStartEvent): - if hasattr(chunk, 'content_block'): + if hasattr(chunk, "content_block"): content_block = chunk.content_block if isinstance(content_block, dict): - if content_block.get('type') == 'tool_use': + if content_block.get("type") == "tool_use": tool_call = AssistantPromptMessage.ToolCall( - id=content_block.get('id'), - type='function', + id=content_block.get("id"), + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content_block.get('name'), - arguments='' - ) + name=content_block.get("name"), arguments="" + ), ) tool_calls.append(tool_call) - elif hasattr(chunk, 'delta'): + elif hasattr(chunk, "delta"): delta = chunk.delta if isinstance(delta, dict) and len(tool_calls) > 0: - if delta.get('type') == 'input_json_delta': - tool_calls[-1].function.arguments += delta.get('partial_json', '') + if delta.get("type") == "input_json_delta": + tool_calls[-1].function.arguments += delta.get("partial_json", "") elif chunk.message: return_model = chunk.message.model input_tokens = chunk.message.usage.input_tokens @@ -369,29 +396,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # transform empty tool call arguments to {} for tool_call in tool_calls: if not tool_call.function.arguments: - tool_call.function.arguments = '{}' + tool_call.function.arguments = "{}" yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text or "" full_assistant_content += chunk_text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=chunk_text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_text) index = chunk.index @@ -401,7 +423,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk.index, message=assistant_prompt_message, - ) + ), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -412,14 +434,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "api_key": credentials['anthropic_api_key'], + "api_key": credentials["anthropic_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } - if credentials.get('anthropic_api_url'): - credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/') - credentials_kwargs['base_url'] = credentials['anthropic_api_url'] + if credentials.get("anthropic_api_url"): + credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/") + credentials_kwargs["base_url"] = credentials["anthropic_api_url"] return credentials_kwargs @@ -452,10 +474,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -465,25 +484,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + raise ValueError( + f"Failed to fetch image data from url {message_content.data}, {ex}" + ) else: data_split = message_content.data.split(";base64,") mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) @@ -492,34 +511,28 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): content = [] if message.tool_calls: for tool_call in message.tool_calls: - content.append({ - "type": "tool_use", - "id": tool_call.id, - "name": tool_call.function.name, - "input": json.loads(tool_call.function.arguments) - }) + content.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), + } + ) if message.content: - content.append({ - "type": "text", - "text": message.content - }) - + content.append({"type": "text", "text": message.content}) + if prompt_message_dicts[-1]["role"] == "assistant": prompt_message_dicts[-1]["content"].extend(content) else: - prompt_message_dicts.append({ - "role": "assistant", - "content": content - }) + prompt_message_dicts.append({"role": "assistant", "content": content}) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [ + {"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content} + ], } prompt_message_dicts.append(message_dict) else: @@ -576,16 +589,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -601,24 +611,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - anthropic.APIConnectionError, - anthropic.APITimeoutError - ], - InvokeServerUnavailableError: [ - anthropic.InternalServerError - ], - InvokeRateLimitError: [ - anthropic.RateLimitError - ], - InvokeAuthorizationError: [ - anthropic.AuthenticationError, - anthropic.PermissionDeniedError - ], + InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError], + InvokeServerUnavailableError: [anthropic.InternalServerError], + InvokeRateLimitError: [anthropic.RateLimitError], + InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError], InvokeBadRequestError: [ anthropic.BadRequestError, anthropic.NotFoundError, anthropic.UnprocessableEntityError, - anthropic.APIError - ] + anthropic.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py index 42eae6c1e5..516ef8b295 100644 --- a/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py @@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel): model=real_model, prompt_messages=prompt_messages, message=prompt_message, - usage=usage if usage else LLMUsage.empty_usage(), + usage=usage or LLMUsage.empty_usage(), system_fingerprint=system_fingerprint, ), credentials=credentials, diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index 31c788d226..32a0269af4 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -15,10 +15,10 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN class _CommonAzureOpenAI: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION) + api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION) credentials_kwargs = { - "api_key": credentials['openai_api_key'], - "azure_endpoint": credentials['openai_api_base'], + "api_key": credentials["openai_api_key"], + "azure_endpoint": credentials["openai_api_base"], "api_version": api_version, "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, @@ -29,24 +29,14 @@ class _CommonAzureOpenAI: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], - InvokeAuthorizationError: [ - openai.AuthenticationError, - openai.PermissionDeniedError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index f4f7d964ef..0dada70cc5 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -14,11 +14,21 @@ from core.model_runtime.entities.model_entities import ( PriceConfig, ) -AZURE_OPENAI_API_VERSION = '2024-02-15-preview' +AZURE_OPENAI_API_VERSION = "2024-02-15-preview" + +AZURE_DEFAULT_PARAM_SEED_HELP = I18nObject( + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性," + "您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically," + " such that repeated requests with the same seed and parameters should return the same result." + " Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter" + " to monitor changes in the backend.", +) + def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: rule = ParameterRule( - name='max_tokens', + name="max_tokens", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], ) rule.default = default @@ -34,11 +44,11 @@ class AzureBaseModel(BaseModel): LLM_BASE_MODELS = [ AzureBaseModel( - base_model_name='gpt-35-turbo', + base_model_name="gpt-35-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -53,51 +63,47 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-16k', + base_model_name="gpt-35-turbo-16k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -112,37 +118,37 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=16385) + _get_max_tokens(default=512, min_val=1, max_val=16385), ], pricing=PriceConfig( input=0.003, output=0.004, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-0125', + base_model_name="gpt-35-turbo-0125", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -157,51 +163,47 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4', + base_model_name="gpt-4", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -216,67 +218,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=8192), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.03, output=0.06, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-32k', + base_model_name="gpt-4-32k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -291,67 +283,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=32768), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.06, output=0.12, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-0125-preview', + base_model_name="gpt-4-0125-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -366,67 +348,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-1106-preview', + base_model_name="gpt-4-1106-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -441,67 +413,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini', + base_model_name="gpt-4o-mini", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -517,67 +479,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini-2024-07-18', + base_model_name="gpt-4o-mini-2024-07-18", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -593,79 +545,67 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object', 'json_schema'] + options=["text", "json_object", "json_schema"], ), ParameterRule( - name='json_schema', - label=I18nObject( - en_US='JSON Schema' - ), - type='text', + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='设置返回的json schema,llm将按照它返回', - en_US='Set a response json schema will ensure LLM to adhere it.' + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), - required=False + required=False, ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o', + base_model_name="gpt-4o", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -681,67 +621,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-05-13', + base_model_name="gpt-4o-2024-05-13", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -757,67 +687,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-08-06', + base_model_name="gpt-4o-2024-08-06", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -833,79 +753,67 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object', 'json_schema'] + options=["text", "json_object", "json_schema"], ), ParameterRule( - name='json_schema', - label=I18nObject( - en_US='JSON Schema' - ), - type='text', + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='设置返回的json schema,llm将按照它返回', - en_US='Set a response json schema will ensure LLM to adhere it.' + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), - required=False + required=False, ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo', + base_model_name="gpt-4-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -921,67 +829,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo-2024-04-09', + base_model_name="gpt-4-turbo-2024-04-09", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -997,72 +895,60 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-vision-preview', + base_model_name="gpt-4-vision-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, - features=[ - ModelFeature.VISION - ], + features=[ModelFeature.VISION], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: LLMMode.CHAT.value, @@ -1070,67 +956,57 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' - ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-instruct', + base_model_name="gpt-35-turbo-instruct", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1140,19 +1016,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1161,16 +1037,16 @@ LLM_BASE_MODELS = [ input=0.0015, output=0.002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-davinci-003', + base_model_name="text-davinci-003", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1180,19 +1056,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1201,20 +1077,18 @@ LLM_BASE_MODELS = [ input=0.02, output=0.02, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] EMBEDDING_BASE_MODELS = [ AzureBaseModel( - base_model_name='text-embedding-ada-002', + base_model_name="text-embedding-ada-002", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1224,17 +1098,15 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.0001, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-small', + base_model_name="text-embedding-3-small", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1244,17 +1116,15 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-large', + base_model_name="text-embedding-3-large", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1264,135 +1134,129 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00013, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] SPEECH2TEXT_BASE_MODELS = [ AzureBaseModel( - base_model_name='whisper-1', + base_model_name="whisper-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={ ModelPropertyKey.FILE_UPLOAD_LIMIT: 25, - ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm' - } - ) + ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm", + }, + ), ) ] TTS_BASE_MODELS = [ AzureBaseModel( - base_model_name='tts-1', + base_model_name="tts-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='tts-1-hd', + base_model_name="tts-1-hd", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.03, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py index 68977b2266..2e3c6aab05 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class AzureOpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 700935b07b..867f9fec42 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -53,6 +53,12 @@ model_credential_schema: type: select required: true options: + - label: + en_US: 2024-08-01-preview + value: 2024-08-01-preview + - label: + en_US: 2024-07-01-preview + value: 2024-07-01-preview - label: en_US: 2024-05-01-preview value: 2024-05-01-preview diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index c0c782e42b..f0033ea051 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -34,16 +34,20 @@ logger = logging.getLogger(__name__) class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - - base_model_name = credentials.get('base_model_name') + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: @@ -56,7 +60,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -67,7 +71,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) def get_num_tokens( @@ -75,14 +79,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ) -> int: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not model_entity: - raise ValueError(f'Base Model Name {base_model_name} is invalid') + raise ValueError(f"Base Model Name {base_model_name} is invalid") model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) if model_mode == LLMMode.CHAT.value: @@ -92,21 +96,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # text completion model, do not support tool calling content = prompt_messages[0].content assert isinstance(content, str) - return self._num_tokens_from_string(credentials,content) + return self._num_tokens_from_string(credentials, content) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise CredentialsValidateFailedError('Base Model Name is required') + raise CredentialsValidateFailedError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not ai_model_entity: @@ -118,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -127,7 +131,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -137,33 +141,35 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) return ai_model_entity.entity if ai_model_entity else None - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -172,15 +178,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) def _handle_generate_response( - self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] ): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -209,24 +212,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return result def _handle_generate_stream_response( - self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] ) -> Generator: - full_text = '' + full_text = "" for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text or "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -254,8 +254,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -265,14 +265,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) response_format = model_parameters.get("response_format") @@ -293,7 +299,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] + extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] # extra_model_kwargs['functions'] = [{ # "name": tool.name, # "description": tool.description, @@ -301,10 +307,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # } for tool in tools] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # chat model messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] @@ -322,9 +328,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) def _handle_chat_generate_response( - self, model: str, credentials: dict, response: ChatCompletion, + self, + model: str, + credentials: dict, + response: ChatCompletion, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): assistant_message = response.choices[0].message assistant_message_tool_calls = assistant_message.tool_calls @@ -334,10 +343,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -369,13 +375,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): credentials: dict, response: Stream[ChatCompletionChunk], prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): index = 0 - full_assistant_content = '' + full_assistant_content = "" real_model = model system_fingerprint = None - completion = '' + completion = "" tool_calls = [] for chunk in response: if len(chunk.choices) == 0: @@ -386,7 +392,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if delta.delta is None: continue - # extract tool calls from response self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) @@ -395,16 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" real_model = chunk.model system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content if delta.delta.content else '' + completion += delta.delta.content or "" yield LLMResultChunk( model=real_model, @@ -413,17 +415,15 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) - index += 0 + index += 1 # calculate num tokens prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - full_assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + full_assistant_prompt_message = AssistantPromptMessage(content=completion) completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) # transform usage @@ -434,27 +434,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, system_fingerprint=system_fingerprint, delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=''), - finish_reason='stop', - usage=usage - ) + index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage + ), ) @staticmethod - def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None: + def _update_tool_calls( + tool_calls: list[AssistantPromptMessage.ToolCall], + tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], + ) -> None: if tool_calls_response: for response_tool_call in tool_calls_response: if isinstance(response_tool_call, ChatCompletionMessageToolCall): function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) elif isinstance(response_tool_call, ChoiceDeltaToolCall): @@ -463,8 +460,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tool_calls[index].id = response_tool_call.id or tool_calls[index].id tool_calls[index].type = response_tool_call.type or tool_calls[index].type if response_tool_call.function: - tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name - tool_calls[index].function.arguments += response_tool_call.function.arguments or '' + tool_calls[index].function.name = ( + response_tool_call.function.name or tool_calls[index].function.name + ) + tool_calls[index].function.arguments += response_tool_call.function.arguments or "" else: assert response_tool_call.id is not None assert response_tool_call.type is not None @@ -473,13 +472,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): assert response_tool_call.function.arguments is not None function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) @@ -495,19 +491,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -525,7 +515,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): "role": "tool", "name": message.name, "content": message.content, - "tool_call_id": message.tool_call_id + "tool_call_id": message.tool_call_id, } else: raise ValueError(f"Got unknown type {message}") @@ -535,10 +525,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, credentials: dict, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: try: - encoding = tiktoken.encoding_for_model(credentials['base_model_name']) + encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") @@ -550,14 +541,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens def _num_tokens_from_messages( - self, credentials: dict, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials['base_model_name'] + model = credentials["base_model_name"] try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -591,10 +581,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -626,40 +616,39 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): @staticmethod def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters['title'])) - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters['type'])) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters['properties'].items(): + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) + num_tokens += len(encoding.encode(parameters["title"])) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode(parameters["type"])) + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters["properties"].items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index 8aebcb90e4..a2b14cf3db 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -15,9 +15,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -40,7 +38,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -65,10 +63,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): return response.text def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index e073bef014..d9cff8ecbb 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -16,19 +16,18 @@ from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - base_model_name = credentials['base_model_name'] + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + base_model_name = credentials["base_model_name"] credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -44,11 +43,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -56,10 +53,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): for i in _iter: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -75,10 +69,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,24 +79,16 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=base_model_name - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: if len(texts) == 0: return 0 try: - enc = tiktoken.encoding_for_model(credentials['base_model_name']) + enc = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: enc = tiktoken.get_encoding("cl100k_base") @@ -118,57 +101,52 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): return total_num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - if not self._get_ai_model_entity(credentials['base_model_name'], model): + if not self._get_ai_model_entity(credentials["base_model_name"], model): raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') try: credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity @staticmethod - def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +157,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index f9ddd86f68..af178703a0 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -17,8 +17,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -30,13 +31,12 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -50,14 +50,13 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name @@ -75,23 +74,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): if len(content_text) > max_length: sentences = self._split_text_into_sentences(content_text, max_length=max_length) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] - for index, future in enumerate(futures): - yield from future.result().__enter__().iter_bytes(1024) + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] + for future in futures: + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api @@ -108,10 +113,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): return response.read() def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None: for ai_model_entity in TTS_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.py b/api/core/model_runtime/model_providers/baichuan/baichuan.py index 71bd6b5d92..626fc811cf 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.py +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class BaichuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class BaichuanProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `baichuan2-turbo` model for validate, - model_instance.validate_credentials( - model='baichuan2-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml index c6c6c7e9e9..d9cd086e82 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml @@ -33,7 +33,7 @@ parameter_rules: - name: res_format label: zh_Hans: 回复格式 - en_US: response format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml index ee8a9ff0d5..58f9b39a43 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml @@ -33,7 +33,7 @@ parameter_rules: - name: res_format label: zh_Hans: 回复格式 - en_US: response format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml index e5e6aeb491..6a1135e165 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml @@ -33,7 +33,7 @@ parameter_rules: - name: res_format label: zh_Hans: 回复格式 - en_US: response format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 7549b2fb60..a7ca28d49d 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -4,17 +4,18 @@ import re class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: - return len(re.findall(r'[\u4e00-\u9fa5]', text)) + return len(re.findall(r"[\u4e00-\u9fa5]", text)) @classmethod def count_english_vocabularies(cls, text: str) -> int: # remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc. - text = re.sub(r'[^a-zA-Z0-9\s]', '', text) + text = re.sub(r"[^a-zA-Z0-9\s]", "", text) # count the number of words not characters return len(text.split()) - + @classmethod def _get_num_tokens(cls, text: str) -> int: - # tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return) + # tokens = number of Chinese characters + number of English words * 1.3 + # (for estimation only, subject to actual return) # https://platform.baichuan-ai.com/docs/text-Embedding - return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) \ No newline at end of file + return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index a8fd9dce91..d5fda73009 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -7,7 +7,7 @@ from requests import post from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -45,7 +45,7 @@ class BaichuanModel: parameters: dict[str, Any], tools: Optional[list[PromptMessageTool]] = None, ) -> dict[str, Any]: - if model in self._model_mapping.keys(): + if model in self._model_mapping: # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters. # we need to rename it to res_format to get its value if parameters.get("res_format") == "json_object": @@ -94,8 +94,7 @@ class BaichuanModel: timeout: int, tools: Optional[list[PromptMessageTool]] = None, ) -> Union[Iterator, dict]: - - if model in self._model_mapping.keys(): + if model in self._model_mapping: api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: raise BadRequestError(f"Unknown model: {model}") @@ -120,14 +119,12 @@ class BaichuanModel: err = resp["error"]["type"] msg = resp["error"]["message"] except Exception as e: - raise InternalServerError( - f"Failed to convert response to json: {e} with text: {response.text}" - ) + raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") if err == "invalid_api_key": raise InvalidAPIKeyError(msg) elif err == "insufficient_quota": - raise InsufficientAccountBalance(msg) + raise InsufficientAccountBalanceError(msg) elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) elif err == "invalid_request_error": diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 67d76b4a29..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): + +class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index 36c7003d1b..91a14bf100 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -38,17 +38,16 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor class BaichuanLanguageModel(LargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - user: str | None = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, ) -> LLMResult | Generator: return self._generate( model=model, @@ -60,17 +59,17 @@ class BaichuanLanguageModel(LargeLanguageModel): ) def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, ) -> int: return self._num_tokens_from_messages(prompt_messages) def _num_tokens_from_messages( - self, - messages: list[PromptMessage], + self, + messages: list[PromptMessage], ) -> int: """Calculate num tokens for baichuan model""" @@ -111,18 +110,13 @@ class BaichuanLanguageModel(LargeLanguageModel): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} else: raise ValueError(f"Unknown message type {type(message)}") @@ -146,15 +140,14 @@ class BaichuanLanguageModel(LargeLanguageModel): raise CredentialsValidateFailedError(f"Invalid API key: {e}") def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stream: bool = True, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stream: bool = True, ) -> LLMResult | Generator: - instance = BaichuanModel(api_key=credentials["api_key"]) messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] @@ -169,23 +162,19 @@ class BaichuanLanguageModel(LargeLanguageModel): ) if stream: - return self._handle_chat_generate_stream_response( - model, prompt_messages, credentials, response - ) + return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) - return self._handle_chat_generate_response( - model, prompt_messages, credentials, response - ) + return self._handle_chat_generate_response(model, prompt_messages, credentials, response) def _handle_chat_generate_response( - self, - model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: dict, + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: dict, ) -> LLMResult: choices = response.get("choices", []) - assistant_message = AssistantPromptMessage(content='', tool_calls=[]) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if choices and choices[0]["finish_reason"] == "tool_calls": for choice in choices: for tool_call in choice["message"]["tool_calls"]: @@ -194,7 +183,7 @@ class BaichuanLanguageModel(LargeLanguageModel): type=tool_call.get("type", ""), function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.get("function", {}).get("name", ""), - arguments=tool_call.get("function", {}).get("arguments", "") + arguments=tool_call.get("function", {}).get("arguments", ""), ), ) assistant_message.tool_calls.append(tool) @@ -228,11 +217,11 @@ class BaichuanLanguageModel(LargeLanguageModel): ) def _handle_chat_generate_stream_response( - self, - model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Iterator, + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Iterator, ) -> Generator: for line in response: if not line: @@ -260,9 +249,7 @@ class BaichuanLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=choice["delta"]["content"], tool_calls=[] - ), + message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]), finish_reason=stop_reason, ), ) @@ -302,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel): InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], InvokeBadRequestError: [BadRequestError, KeyError], diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 81bd58e3ce..779dfbb608 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -31,11 +31,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for BaiChuan text embedding model. """ - api_base: str = 'http://api.baichuan-ai.com/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "http://api.baichuan-ai.com/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -45,28 +46,23 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] - if model != 'baichuan-text-embedding': - raise ValueError('Invalid model name') + api_key = credentials["api_key"] + if model != "baichuan-text-embedding": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - + raise CredentialsValidateFailedError("api_key is required") + # split into chunks of batch size 16 chunks = [] for i in range(0, len(texts), 16): - chunks.append(texts[i:i + 16]) + chunks.append(texts[i : i + 16]) embeddings = [] token_usage = 0 for chunk in chunks: # embedding chunk - chunk_embeddings, chunk_usage = self.embedding( - model=model, - api_key=api_key, - texts=chunk, - user=user - ) + chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user) embeddings.extend(chunk_embeddings) token_usage += chunk_usage @@ -74,17 +70,14 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - - def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> tuple[list[list[float]], int]: + + def embedding( + self, model: str, api_key, texts: list[str], user: Optional[str] = None + ) -> tuple[list[list[float]], int]: """ Embed given texts @@ -95,56 +88,47 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'Baichuan-Text-Embedding', - 'input': texts - } + data = {"model": "Baichuan-Text-Embedding", "input": texts} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["code"] + msg = resp["error"]["message"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': - raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': - raise InvalidAuthenticationError(msg) - elif err and 'rate' in err: + elif err == "insufficient_quota": + raise InsufficientAccountBalanceError(msg) + elif err == "invalid_authentication": + raise InvalidAuthenticationError(msg) + elif err and "rate" in err: raise RateLimitReachedError(msg) - elif err and 'internal' in err: + elif err and "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - return [ - data['embedding'] for data in embeddings - ], usage['total_tokens'] - + return [data["embedding"] for data in embeddings], usage["total_tokens"] def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -170,32 +154,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -207,10 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -221,7 +194,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py index e99bc52ff8..1cfc1d199c 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.py +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class BedrockProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,13 +20,10 @@ class BedrockProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `amazon.titan-text-lite-v1` model by default for validating credentials - model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1') - model_instance.validate_credentials( - model=model_for_validation, - credentials=credentials - ) + model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1") + model_instance.validate_credentials(model=model_for_validation, credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml index 53657c08a9..c2d5eb6471 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.00025' output: '0.00125' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml index d083d31e30..f90fa04266 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.015' output: '0.075' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml index 5302231086..dad0d6b6b6 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.015' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml index 6995d2bf56..962def8011 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.015' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml index 1a3239c85e..70294e4ad3 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml @@ -45,6 +45,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.024' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml index 0343e3bbec..0a8ea61b6d 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml @@ -45,6 +45,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.024' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml new file mode 100644 index 0000000000..fe5f54de13 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml @@ -0,0 +1,59 @@ +model: eu.anthropic.claude-3-haiku-20240307-v1:0 +label: + en_US: Claude 3 Haiku(Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.00025' + output: '0.00125' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml new file mode 100644 index 0000000000..9f8d029a57 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml @@ -0,0 +1,58 @@ +model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0 +label: + en_US: Claude 3.5 Sonnet(Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml new file mode 100644 index 0000000000..bfaf5abb8e --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml @@ -0,0 +1,58 @@ +model: eu.anthropic.claude-3-sonnet-20240229-v1:0 +label: + en_US: Claude 3 Sonnet(Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index c325ac3cec..06a8606901 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -20,6 +20,7 @@ from botocore.exceptions import ( from PIL.Image import Image # local import +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -44,37 +45,87 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel logger = logging.getLogger(__name__) +ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + + +{{instructions}} + +""" # noqa: E501 + class BedrockLargeLanguageModel(LargeLanguageModel): - # please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html # TODO There is invoke issue: context limit on Cohere Model, will add them after fixed. - CONVERSE_API_ENABLED_MODEL_INFO=[ - {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False} + CONVERSE_API_ENABLED_MODEL_INFO = [ + {"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False}, ] @staticmethod def _find_model_info(model_id): for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO: - if model_id.startswith(model['prefix']): + if model_id.startswith(model["prefix"]): return model logger.info(f"current model id: {model_id} did not support by Converse API") return None - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if model_parameters.get("response_format"): + stop = stop or [] + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + response_format = model_parameters.pop("response_format") + format_prompt = SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) + ) + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = format_prompt + else: + prompt_messages.insert(0, format_prompt) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -88,17 +139,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - - model_info= BedrockLargeLanguageModel._find_model_info(model) + + model_info = BedrockLargeLanguageModel._find_model_info(model) if model_info: - model_info['model'] = model + model_info["model"] = model # invoke models via boto3 converse API - return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools) + return self._generate_with_converse( + model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools + ) # invoke other models via boto3 client return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: + def _generate_with_converse( + self, + model_info: dict, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model with converse API @@ -110,35 +172,39 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client(service_name='bedrock-runtime', - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"]) + bedrock_client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=credentials.get("aws_access_key_id"), + aws_secret_access_key=credentials.get("aws_secret_access_key"), + region_name=credentials["aws_region"], + ) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) parameters = { - 'modelId': model_info['model'], - 'messages': prompt_message_dicts, - 'inferenceConfig': inference_config, - 'additionalModelRequestFields': additional_model_fields, + "modelId": model_info["model"], + "messages": prompt_message_dicts, + "inferenceConfig": inference_config, + "additionalModelRequestFields": additional_model_fields, } - if model_info['support_system_prompts'] and system and len(system) > 0: - parameters['system'] = system + if model_info["support_system_prompts"] and system and len(system) > 0: + parameters["system"] = system - if model_info['support_tool_use'] and tools: - parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools) + if model_info["support_tool_use"] and tools: + parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools) try: if stream: response = bedrock_client.converse_stream(**parameters) - return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_stream_response( + model_info["model"], credentials, response, prompt_messages + ) else: response = bedrock_client.converse(**parameters) - return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: @@ -149,8 +215,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except Exception as ex: raise InvokeError(str(ex)) - def _handle_converse_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_converse_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -160,36 +228,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: full response chunk generator result """ - response_content = response['output']['message']['content'] + response_content = response["output"]["message"]["content"] # transform assistant message to prompt message - if response['stopReason'] == 'tool_use': + if response["stopReason"] == "tool_use": tool_calls = [] text, tool_use = self._extract_tool_use(response_content) tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=json.dumps(tool_use['input']) - ) + name=tool_use["name"], arguments=json.dumps(tool_use["input"]) + ), ) tool_calls.append(tool_call) - assistant_prompt_message = AssistantPromptMessage( - content=text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls) else: - assistant_prompt_message = AssistantPromptMessage( - content=response_content[0]['text'] - ) + assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"]) # calculate num tokens - if response['usage']: + if response["usage"]: # transform usage - prompt_tokens = response['usage']['inputTokens'] - completion_tokens = response['usage']['outputTokens'] + prompt_tokens = response["usage"]["inputTokens"] + completion_tokens = response["usage"]["outputTokens"] else: # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -206,20 +268,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel): ) return result - def _extract_tool_use(self, content:dict)-> tuple[str, dict]: + def _extract_tool_use(self, content: dict) -> tuple[str, dict]: tool_use = {} - text = '' + text = "" for item in content: - if 'toolUse' in item: - tool_use = item['toolUse'] - elif 'text' in item: - text = item['text'] + if "toolUse" in item: + tool_use = item["toolUse"] + elif "text" in item: + text = item["text"] else: raise ValueError(f"Got unknown item: {item}") return text, tool_use - def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_converse_stream_response( + self, + model: str, + credentials: dict, + response: dict, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -231,7 +298,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -240,87 +307,85 @@ class BedrockLargeLanguageModel(LargeLanguageModel): tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_use = {} - for chunk in response['stream']: - if 'messageStart' in chunk: + for chunk in response["stream"]: + if "messageStart" in chunk: return_model = model - elif 'messageStop' in chunk: - finish_reason = chunk['messageStop']['stopReason'] - elif 'contentBlockStart' in chunk: - tool = chunk['contentBlockStart']['start']['toolUse'] - tool_use['toolUseId'] = tool['toolUseId'] - tool_use['name'] = tool['name'] - elif 'metadata' in chunk: - input_tokens = chunk['metadata']['usage']['inputTokens'] - output_tokens = chunk['metadata']['usage']['outputTokens'] + elif "messageStop" in chunk: + finish_reason = chunk["messageStop"]["stopReason"] + elif "contentBlockStart" in chunk: + tool = chunk["contentBlockStart"]["start"]["toolUse"] + tool_use["toolUseId"] = tool["toolUseId"] + tool_use["name"] = tool["name"] + elif "metadata" in chunk: + input_tokens = chunk["metadata"]["usage"]["inputTokens"] + output_tokens = chunk["metadata"]["usage"]["outputTokens"] usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - elif 'contentBlockDelta' in chunk: - delta = chunk['contentBlockDelta']['delta'] - if 'text' in delta: - chunk_text = delta['text'] if delta['text'] else '' + elif "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "text" in delta: + chunk_text = delta["text"] or "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text or "", ) - index = chunk['contentBlockDelta']['contentBlockIndex'] + index = chunk["contentBlockDelta"]["contentBlockIndex"] yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index+1, + index=index + 1, message=assistant_prompt_message, - ) + ), ) - elif 'toolUse' in delta: - if 'input' not in tool_use: - tool_use['input'] = '' - tool_use['input'] += delta['toolUse']['input'] - elif 'contentBlockStop' in chunk: - if 'input' in tool_use: + elif "toolUse" in delta: + if "input" not in tool_use: + tool_use["input"] = "" + tool_use["input"] += delta["toolUse"]["input"] + elif "contentBlockStop" in chunk: + if "input" in tool_use: tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=tool_use['input'] - ) + name=tool_use["name"], arguments=tool_use["input"] + ), ) tool_calls.append(tool_call) tool_use = {} except Exception as ex: raise InvokeError(str(ex)) - - def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]: + + def _convert_converse_api_model_parameters( + self, model_parameters: dict, stop: Optional[list[str]] = None + ) -> tuple[dict, dict]: inference_config = {} additional_model_fields = {} - if 'max_tokens' in model_parameters: - inference_config['maxTokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters: + inference_config["maxTokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters: - inference_config['temperature'] = model_parameters['temperature'] - - if 'top_p' in model_parameters: - inference_config['topP'] = model_parameters['temperature'] + if "temperature" in model_parameters: + inference_config["temperature"] = model_parameters["temperature"] + + if "top_p" in model_parameters: + inference_config["topP"] = model_parameters["temperature"] if stop: - inference_config['stopSequences'] = stop - - if 'top_k' in model_parameters: - additional_model_fields['top_k'] = model_parameters['top_k'] - + inference_config["stopSequences"] = stop + + if "top_k" in model_parameters: + additional_model_fields["top_k"] = model_parameters["top_k"] + return inference_config, additional_model_fields def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: @@ -332,7 +397,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): prompt_message_dicts = [] for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() system.append({"text": message.content}) else: prompt_message_dicts.append(self._convert_prompt_message_to_dict(message)) @@ -349,15 +414,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): "toolSpec": { "name": tool.name, "description": tool.description, - "inputSchema": { - "json": tool.parameters - } + "inputSchema": {"json": tool.parameters}, } } ) tool_config["tools"] = configs return tool_config - + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ Convert PromptMessage to dict @@ -365,15 +428,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": "user", "content": [{'text': message.content}]} + message_dict = {"role": "user", "content": [{"text": message.content}]} else: sub_messages = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -384,7 +445,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): image_content = requests.get(url).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -393,17 +454,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel): base64_data = data_split[1] image_content = base64.b64decode(base64_data) - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { - "image": { - "format": mime_type.replace('image/', ''), - "source": { - "bytes": image_content - } - } + "image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}} } sub_messages.append(sub_message_dict) @@ -412,36 +470,46 @@ class BedrockLargeLanguageModel(LargeLanguageModel): message = cast(AssistantPromptMessage, message) if message.tool_calls: message_dict = { - "role": "assistant", "content":[{ - "toolUse": { - "toolUseId": message.tool_calls[0].id, - "name": message.tool_calls[0].function.name, - "input": json.loads(message.tool_calls[0].function.arguments) + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": message.tool_calls[0].id, + "name": message.tool_calls[0].function.name, + "input": json.loads(message.tool_calls[0].function.arguments), + } } - }] + ], } else: - message_dict = {"role": "assistant", "content": [{'text': message.content}]} + message_dict = {"role": "assistant", "content": [{"text": message.content}]} elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = [{'text': message.content}] + message_dict = [{"text": message.content}] elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "toolResult": { - "toolUseId": message.tool_call_id, - "content": [{"json": {"text": message.content}}] - } - }] + "content": [ + { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"json": {"text": message.content}}], + } + } + ], } else: raise ValueError(f"Got unknown type {message}") return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage] | str, + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -451,15 +519,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return:md = genai.GenerativeModel(model) """ - prefix = model.split('.')[0] - model_name = model.split('.')[1] - + prefix = model.split(".")[0] + model_name = model.split(".")[1] + if isinstance(prompt_messages, str): prompt = prompt_messages else: prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name) - return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -476,30 +543,36 @@ class BedrockLargeLanguageModel(LargeLanguageModel): "max_tokens": 32, } elif "ai21" in model: - # ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again. + # ValidationException: Malformed input request: #/temperature: expected type: Number, + # found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, + # please reformat your input and try again. required_params = { "temperature": 0.7, "topP": 0.9, "maxTokens": 32, } - + try: ping_message = UserPromptMessage(content="ping") - self._invoke(model=model, - credentials=credentials, - prompt_messages=[ping_message], - model_parameters=required_params, - stream=False) - + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ping_message], + model_parameters=required_params, + stream=False, + ) + except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_one_message_to_text( + self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Convert a single message to a string. @@ -514,7 +587,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): body = content - if (isinstance(content, list)): + if isinstance(content, list): body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT]) message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}" elif isinstance(message, AssistantPromptMessage): @@ -528,7 +601,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -537,27 +612,31 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message, model_prefix, model_name) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + def _create_payload( + self, + model: str, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} - model_prefix = model.split('.')[0] - model_name = model.split('.')[1] + model_prefix = model.split(".")[0] + model_name = model.split(".")[1] if model_prefix == "ai21": payload["temperature"] = model_parameters.get("temperature") @@ -571,21 +650,27 @@ class BedrockLargeLanguageModel(LargeLanguageModel): payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} if model_parameters.get("countPenalty"): payload["countPenalty"] = {model_parameters.get("countPenalty")} - + elif model_prefix == "cohere": - payload = { **model_parameters } + payload = {**model_parameters} payload["prompt"] = prompt_messages[0].content payload["stream"] = stream - + else: raise ValueError(f"Got unknown model prefix {model_prefix}") - + return payload - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -598,18 +683,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) runtime_client = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream) # need workaround for ai21 models which doesn't support streaming @@ -619,18 +702,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): invoke = runtime_client.invoke_model try: - body_jsonstr=json.dumps(payload) - response = invoke( - modelId=model, - contentType="application/json", - accept= "*/*", - body=body_jsonstr - ) + body_jsonstr = json.dumps(payload) + response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) @@ -639,15 +717,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except Exception as ex: raise InvokeError(str(ex)) - if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -657,7 +735,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) finish_reason = response_body.get("error") @@ -665,25 +743,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel): raise InvokeError(finish_reason) # get output text and calculate num tokens based on model / provider - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - output = response_body.get('completions')[0].get('data').get('text') + output = response_body.get("completions")[0].get("data").get("text") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) - + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) + elif model_prefix == "cohere": output = response_body.get("generations")[0].get("text") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - + completion_tokens = self.get_num_tokens(model, credentials, output or "") + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") # construct assistant message from output - assistant_prompt_message = AssistantPromptMessage( - content=output - ) + assistant_prompt_message = AssistantPromptMessage(content=output) # calculate usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -698,8 +774,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -709,65 +786,59 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) - content = response_body.get('completions')[0].get('data').get('text') - finish_reason = response_body.get('completions')[0].get('finish_reason') + content = response_body.get("completions")[0].get("data").get("text") + finish_reason = response_body.get("completions")[0].get("finish_reason") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - finish_reason=finish_reason, - usage=usage - ) - ) + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage + ), + ) return - - stream = response.get('body') + + stream = response.get("body") if not stream: - raise InvokeError('No response body') - + raise InvokeError("No response body") + index = -1 for event in stream: - chunk = event.get('chunk') - + chunk = event.get("chunk") + if not chunk: exception_name = next(iter(event)) full_ex_msg = f"{exception_name}: {event[exception_name]['message']}" raise self._map_client_to_invoke_error(exception_name, full_ex_msg) - payload = json.loads(chunk.get('bytes').decode()) + payload = json.loads(chunk.get("bytes").decode()) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "cohere": content_delta = payload.get("text") finish_reason = payload.get("finish_reason") - + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content = content_delta if content_delta else '', + content=content_delta or "", ) index += 1 - + if not finish_reason: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: @@ -777,18 +848,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage + ), ) - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -804,9 +872,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - + def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: """ Map client error to invoke error @@ -818,11 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in { + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml new file mode 100644 index 0000000000..58c1f05779 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml @@ -0,0 +1,59 @@ +model: us.anthropic.claude-3-haiku-20240307-v1:0 +label: + en_US: Claude 3 Haiku(Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.00025' + output: '0.00125' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml new file mode 100644 index 0000000000..6b9e1ec067 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml @@ -0,0 +1,59 @@ +model: us.anthropic.claude-3-opus-20240229-v1:0 +label: + en_US: Claude 3 Opus(Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.015' + output: '0.075' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml new file mode 100644 index 0000000000..f1e0d6c5a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml @@ -0,0 +1,58 @@ +model: us.anthropic.claude-3-5-sonnet-20240620-v1:0 +label: + en_US: Claude 3.5 Sonnet(Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml new file mode 100644 index 0000000000..dce50bf4b5 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml @@ -0,0 +1,58 @@ +model: us.anthropic.claude-3-sonnet-20240229-v1:0 +label: + en_US: Claude 3 Sonnet(Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index ef22a9c868..251170d1ae 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -27,12 +27,11 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE logger = logging.getLogger(__name__) + class BedrockTextEmbeddingModel(TextEmbeddingModel): - - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -42,67 +41,56 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) bedrock_runtime = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) embeddings = [] token_usage = 0 - - model_prefix = model.split('.')[0] - - if model_prefix == "amazon" : + + model_prefix = model.split(".")[0] + + if model_prefix == "amazon": for text in texts: body = { - "inputText": text, + "inputText": text, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend([response_body.get('embedding')]) - token_usage += response_body.get('inputTextTokenCount') - logger.warning(f'Total Tokens: {token_usage}') + embeddings.extend([response_body.get("embedding")]) + token_usage += response_body.get("inputTextTokenCount") + logger.warning(f"Total Tokens: {token_usage}") result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - if model_prefix == "cohere" : - input_type = 'search_document' if len(texts) > 1 else 'search_query' + if model_prefix == "cohere": + input_type = "search_document" if len(texts) > 1 else "search_query" for text in texts: body = { - "texts": [text], - "input_type": input_type, + "texts": [text], + "input_type": input_type, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend(response_body.get('embeddings')) + embeddings.extend(response_body.get("embeddings")) token_usage += len(text) result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - #others + # others raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -125,7 +113,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): :param credentials: model credentials :return: """ - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -141,19 +129,25 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - - def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + + def _create_payload( + self, + model_prefix: str, + texts: list[str], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} if model_prefix == "amazon": - payload['inputText'] = texts + payload["inputText"] = texts - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -165,10 +159,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +170,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -195,35 +186,41 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in { + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) return InvokeError(error_msg) - - def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): - accept = 'application/json' - content_type = 'application/json' + def _invoke_bedrock_embedding( + self, + model: str, + bedrock_runtime, + body: dict, + ): + accept = "application/json" + content_type = "application/json" try: response = bedrock_runtime.invoke_model( - body=json.dumps(body), - modelId=model, - accept=accept, - contentType=content_type + body=json.dumps(body), modelId=model, accept=accept, contentType=content_type ) - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) return response_body except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) diff --git a/api/core/model_runtime/model_providers/chatglm/chatglm.py b/api/core/model_runtime/model_providers/chatglm/chatglm.py index e9dd5794f3..71d9a15322 100644 --- a/api/core/model_runtime/model_providers/chatglm/chatglm.py +++ b/api/core/model_runtime/model_providers/chatglm/chatglm.py @@ -20,12 +20,9 @@ class ChatGLMProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `chatglm3-6b` model for validate, - model_instance.validate_credentials( - model='chatglm3-6b', - credentials=credentials - ) + model_instance.validate_credentials(model="chatglm3-6b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index e83d08af71..b3eeb48e22 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -43,12 +43,19 @@ from core.model_runtime.utils import helper logger = logging.getLogger(__name__) + class ChatGLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -71,11 +78,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -96,11 +108,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content="ping"), - ], model_parameters={ - "max_tokens": 16, - }) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ + UserPromptMessage(content="ping"), + ], + model_parameters={ + "max_tokens": 16, + }, + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @@ -124,24 +141,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -163,35 +180,31 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None: if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0: raise InvokeBadRequestError("ChatGLM2 does not support function calling") @@ -212,7 +225,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -223,12 +236,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "function", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -239,19 +252,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls - + def _to_client_kwargs(self, credentials: dict) -> dict: """ Convert invoke kwargs to client kwargs @@ -265,17 +273,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['api_base']) / 'v1') + "base_url": str(URL(credentials["api_base"]) / "v1"), } return client_kwargs - - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> Generator: - - full_response = '' + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -283,35 +294,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue - + # check if there is a tool call in the response function_calls = None if delta.delta.function_call: function_calls = [delta.delta.function_call] - assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) + assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or []) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content or "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -320,7 +333,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -335,11 +348,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) full_response += delta.delta.content - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -359,15 +376,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -378,7 +394,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) return response - + def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -395,17 +411,19 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer, As a temporary solution we use GPT2 tokenizer instead. """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) - + tokens_per_message = 3 tokens_per_name = 1 num_tokens = 0 @@ -414,10 +432,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "function_call": @@ -452,36 +470,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) diff --git a/api/core/model_runtime/model_providers/cohere/cohere.py b/api/core/model_runtime/model_providers/cohere/cohere.py index cfbcb94d26..8394a45fcf 100644 --- a/api/core/model_runtime/model_providers/cohere/cohere.py +++ b/api/core/model_runtime/model_providers/cohere/cohere.py @@ -20,12 +20,9 @@ class CohereProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.RERANK) # Use `rerank-english-v2.0` model for validate, - model_instance.validate_credentials( - model='rerank-english-v2.0', - credentials=credentials - ) + model_instance.validate_credentials(model="rerank-english-v2.0", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 89b04c0279..3863ad3308 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -55,11 +55,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): Model class for Cohere large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -85,7 +91,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: return self._generate( @@ -95,11 +101,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -136,30 +147,37 @@ class CohereLargeLanguageModel(LargeLanguageModel): self._chat_generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) else: self._generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -173,17 +191,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['end_sequences'] = stop + model_parameters["end_sequences"] = stop if stream: response = client.generate_stream( prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_stream_response(model, credentials, response, prompt_messages) @@ -192,14 +210,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Generation, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -212,9 +230,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): assistant_text = response.generations[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens prompt_tokens = int(response.meta.billed_units.input_tokens) @@ -225,17 +241,18 @@ class CohereLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[GenerateStreamedResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[GenerateStreamedResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -245,7 +262,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ index = 1 - full_assistant_content = '' + full_assistant_content = "" for chunk in response: if isinstance(chunk, GenerateStreamedResponse_TextGeneration): chunk = cast(GenerateStreamedResponse_TextGeneration, chunk) @@ -255,9 +272,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -267,7 +282,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -277,9 +292,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) completion_tokens = self._num_tokens_from_messages( - model, - credentials, - [AssistantPromptMessage(content=full_assistant_content)] + model, credentials, [AssistantPromptMessage(content=full_assistant_content)] ) # transform usage @@ -290,20 +303,27 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content=""), finish_reason=chunk.finish_reason, - usage=usage - ) + usage=usage, + ), ) break elif isinstance(chunk, GenerateStreamedResponse_StreamError): chunk = cast(GenerateStreamedResponse_StreamError, chunk) raise InvokeBadRequestError(chunk.err) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -318,27 +338,28 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['stop_sequences'] = stop + model_parameters["stop_sequences"] = stop if tools: if len(tools) == 1: raise ValueError("Cohere tool call requires at least two tools to be specified.") - model_parameters['tools'] = self._convert_tools(tools) + model_parameters["tools"] = self._convert_tools(tools) - message, chat_histories, tool_results \ - = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + message, chat_histories, tool_results = self._convert_prompt_messages_to_message_and_chat_histories( + prompt_messages + ) if tool_results: - model_parameters['tool_results'] = tool_results + model_parameters["tool_results"] = tool_results # chat model real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") if stream: response = client.chat_stream( @@ -346,7 +367,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) @@ -356,14 +377,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_chat_generate_response( + self, model: str, credentials: dict, response: NonStreamedChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -380,19 +401,15 @@ class CohereLargeLanguageModel(LargeLanguageModel): for cohere_tool_call in response.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text, tool_calls=tool_calls) # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) @@ -403,17 +420,18 @@ class CohereLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[StreamedChatResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[StreamedChatResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -423,17 +441,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ - def final_response(full_text: str, - tool_calls: list[AssistantPromptMessage.ToolCall], - index: int, - finish_reason: Optional[str] = None) -> LLMResultChunk: + def final_response( + full_text: str, + tool_calls: list[AssistantPromptMessage.ToolCall], + index: int, + finish_reason: Optional[str] = None, + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) - full_assistant_prompt_message = AssistantPromptMessage( - content=full_text, - tool_calls=tool_calls - ) + full_assistant_prompt_message = AssistantPromptMessage(content=full_text, tool_calls=tool_calls) completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) # transform usage @@ -444,14 +461,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content='', tool_calls=tool_calls), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) index = 1 - full_assistant_content = '' + full_assistant_content = "" tool_calls = [] for chunk in response: if isinstance(chunk, StreamedChatResponse_TextGeneration): @@ -462,9 +479,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -474,7 +489,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -484,11 +499,10 @@ class CohereLargeLanguageModel(LargeLanguageModel): for cohere_tool_call in chunk.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) elif isinstance(chunk, StreamedChatResponse_StreamEnd): @@ -496,8 +510,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason) index += 1 - def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: + def _convert_prompt_messages_to_message_and_chat_histories( + self, prompt_messages: list[PromptMessage] + ) -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -510,13 +525,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_message = cast(AssistantPromptMessage, prompt_message) if prompt_message.tool_calls: for tool_call in prompt_message.tool_calls: - latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem( - call=ToolCall( - name=tool_call.function.name, - parameters=json.loads(tool_call.function.arguments) - ), - outputs=[] - )) + latest_tool_call_n_outputs.append( + ChatStreamRequestToolResultsItem( + call=ToolCall( + name=tool_call.function.name, parameters=json.loads(tool_call.function.arguments) + ), + outputs=[], + ) + ) else: cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message) if cohere_prompt_message: @@ -529,12 +545,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): if tool_call_n_outputs.call.name == prompt_message.tool_call_id: latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem( call=ToolCall( - name=tool_call_n_outputs.call.name, - parameters=tool_call_n_outputs.call.parameters + name=tool_call_n_outputs.call.name, parameters=tool_call_n_outputs.call.parameters ), - outputs=[{ - "result": prompt_message.content - }] + outputs=[{"result": prompt_message.content}], ) break i += 1 @@ -556,7 +569,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): latest_message = chat_histories.pop() message = latest_message.message else: - raise ValueError('Prompt messages is empty') + raise ValueError("Prompt messages is empty") return message, chat_histories, latest_tool_call_n_outputs @@ -569,7 +582,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): if isinstance(message.content, str): chat_message = ChatMessage(role="USER", message=message.content) else: - sub_message_text = '' + sub_message_text = "" for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) @@ -597,8 +610,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): """ cohere_tools = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] parameter_definitions = {} for p_key, p_val in properties.items(): @@ -606,21 +619,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): if p_key in required_properties: required = True - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: [{', '.join(p_val['enum'])}]" parameter_definitions[p_key] = ToolParameterDefinitionsValue( - description=desc, - type=p_val['type'], - required=required + description=desc, type=p_val["type"], required=required ) cohere_tool = Tool( - name=tool.name, - description=tool.description, - parameter_definitions=parameter_definitions + name=tool.name, description=tool.description, parameter_definitions=parameter_definitions ) cohere_tools.append(cohere_tool) @@ -637,12 +645,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model - ) + response = client.tokenize(text=text, model=model) return len(response.tokens) @@ -658,30 +663,30 @@ class CohereLargeLanguageModel(LargeLanguageModel): real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") return self._num_tokens_from_string(real_model, credentials, message_str) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - Cohere supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + Cohere supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} - mode = credentials.get('mode') + mode = credentials.get("mode") - if mode == 'chat': - base_model_schema = model_map['command-light-chat'] + if mode == "chat": + base_model_schema = model_map["command-light-chat"] else: - base_model_schema = model_map['command-light'] + base_model_schema = model_map["command-light"] base_model_schema = cast(AIModelEntity, base_model_schema) @@ -691,16 +696,13 @@ class CohereLargeLanguageModel(LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) return entity @@ -716,22 +718,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index d2fdb30c6f..aba8fedbc0 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -21,10 +21,16 @@ class CohereRerankModel(RerankModel): Model class for Cohere rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -38,20 +44,17 @@ class CohereRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) response = client.rerank( query=query, documents=docs, model=model, top_n=top_n, return_documents=True, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) rerank_documents = [] @@ -70,10 +73,7 @@ class CohereRerankModel(RerankModel): else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -94,7 +94,7 @@ class CohereRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -110,22 +110,16 @@ class CohereRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 0540fb740f..a1c5e98118 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -24,9 +24,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): Model class for Cohere text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -46,14 +46,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - tokenize_response = self._tokenize( - model=model, - credentials=credentials, - text=text - ) + tokenize_response = self._tokenize(model=model, credentials=credentials, text=text) for j in range(0, len(tokenize_response), context_size): - tokens += [tokenize_response[j: j + context_size]] + tokens += [tokenize_response[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -62,9 +58,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=["".join(token) for token in tokens[i: i + max_chunks]] + model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]] ) used_tokens += embedding_used_tokens @@ -80,9 +74,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=[" "] + model=model, credentials=credentials, texts=[" "] ) used_tokens += embedding_used_tokens @@ -92,17 +84,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -116,14 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): if len(texts) == 0: return 0 - full_text = ' '.join(texts) + full_text = " ".join(texts) try: - response = self._tokenize( - model=model, - credentials=credentials, - text=full_text - ) + response = self._tokenize(model=model, credentials=credentials, text=full_text) except Exception as e: raise self._transform_invoke_error(e) @@ -141,14 +121,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): return [] # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model, - offline=False, - request_options=RequestOptions(max_retries=0) - ) + response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0)) return response.token_strings @@ -162,11 +137,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -180,14 +151,14 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: embeddings and used tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) # call embedding model response = client.embed( texts=texts, model=model, - input_type='search_document' if len(texts) > 1 else 'search_query', - request_options=RequestOptions(max_retries=1) + input_type="search_document" if len(texts) > 1 else "search_query", + request_options=RequestOptions(max_retries=1), ) return response.embeddings, int(response.meta.billed_units.input_tokens) @@ -203,10 +174,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -217,7 +185,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -233,22 +201,16 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.py b/api/core/model_runtime/model_providers/deepseek/deepseek.py index d61fd4ddc8..10feef8972 100644 --- a/api/core/model_runtime/model_providers/deepseek/deepseek.py +++ b/api/core/model_runtime/model_providers/deepseek/deepseek.py @@ -7,9 +7,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) - class DeepSeekProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -22,12 +20,9 @@ class DeepSeekProvider(ModelProvider): # Use `deepseek-chat` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='deepseek-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml index 6588a4b5e0..4973ac8ad6 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml @@ -62,7 +62,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index bdb3823b60..6d0a3ee262 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -13,12 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -27,10 +32,8 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -48,8 +51,9 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -69,10 +73,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -103,11 +107,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.deepseek.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.deepseek.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" - + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/fishaudio/__init__.py b/api/core/model_runtime/model_providers/fishaudio/__init__.py index 5f282702bb..e69de29bb2 100644 --- a/api/core/model_runtime/model_providers/fishaudio/__init__.py +++ b/api/core/model_runtime/model_providers/fishaudio/__init__.py @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py index 9f80996d9d..3bc4b533e0 100644 --- a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py +++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py @@ -1,4 +1,4 @@ -import logging +import logging from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -18,11 +18,9 @@ class FishAudioProvider(ModelProvider): """ try: model_instance = self.get_model_instance(ModelType.TTS) - model_instance.validate_credentials( - credentials=credentials - ) + model_instance.validate_credentials(credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py index 5b673ce186..895a7a914c 100644 --- a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional import httpx @@ -12,9 +12,7 @@ class FishAudioText2SpeechModel(TTSModel): Model class for Fish.audio Text to Speech model. """ - def get_tts_model_voices( - self, model: str, credentials: dict, language: Optional[str] = None - ) -> list: + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: api_base = credentials.get("api_base", "https://api.fish.audio") api_key = credentials.get("api_key") use_public_models = credentials.get("use_public_models", "false") == "true" @@ -68,9 +66,7 @@ class FishAudioText2SpeechModel(TTSModel): voice=voice, ) - def validate_credentials( - self, credentials: dict, user: Optional[str] = None - ) -> None: + def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None: """ Validate credentials for text2speech model @@ -91,9 +87,7 @@ class FishAudioText2SpeechModel(TTSModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming( - self, model: str, credentials: dict, content_text: str, voice: str - ) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ Invoke streaming text2speech model :param model: model name @@ -106,12 +100,10 @@ class FishAudioText2SpeechModel(TTSModel): try: word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: - sentences = self._split_text_into_sentences( - content_text, max_length=word_limit - ) + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) else: sentences = [content_text.strip()] - + for i in range(len(sentences)): yield from self._tts_invoke_streaming_sentence( credentials=credentials, content_text=sentences[i], voice=voice @@ -120,9 +112,7 @@ class FishAudioText2SpeechModel(TTSModel): except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _tts_invoke_streaming_sentence( - self, credentials: dict, content_text: str, voice: Optional[str] = None - ) -> any: + def _tts_invoke_streaming_sentence(self, credentials: dict, content_text: str, voice: Optional[str] = None) -> any: """ Invoke streaming text2speech model @@ -141,20 +131,14 @@ class FishAudioText2SpeechModel(TTSModel): with httpx.stream( "POST", api_url + "/v1/tts", - json={ - "text": content_text, - "reference_id": voice, - "latency": latency - }, + json={"text": content_text, "reference_id": voice, "latency": latency}, headers={ "Authorization": f"Bearer {api_key}", }, timeout=None, ) as response: if response.status_code != 200: - raise InvokeBadRequestError( - f"Error: {response.status_code} - {response.text}" - ) + raise InvokeBadRequestError(f"Error: {response.status_code} - {response.text}") yield from response.iter_bytes() @property diff --git a/api/core/model_runtime/model_providers/google/google.py b/api/core/model_runtime/model_providers/google/google.py index ba25c74e71..70f56a8337 100644 --- a/api/core/model_runtime/model_providers/google/google.py +++ b/api/core/model_runtime/model_providers/google/google.py @@ -20,12 +20,9 @@ class GoogleProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-pro` model for validate, - model_instance.validate_credentials( - model='gemini-pro', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-pro", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 11f9f32f96..3fc6787a44 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -6,10 +6,10 @@ from collections.abc import Generator from typing import Optional, Union, cast import google.ai.generativelanguage as glm -import google.api_core.exceptions as exceptions import google.generativeai as genai -import google.generativeai.client as client import requests +from google.api_core import exceptions +from google.generativeai import client from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part from PIL import Image @@ -45,16 +45,21 @@ if you are not sure about the structure. {{instructions}} -""" +""" # noqa: E501 class GoogleLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -70,9 +75,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -85,7 +95,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -95,13 +105,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -117,14 +124,16 @@ class GoogleLargeLanguageModel(LargeLanguageModel): type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -136,20 +145,25 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -163,14 +177,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop - google_model = genai.GenerativeModel( - model_name=model - ) + google_model = genai.GenerativeModel(model_name=model) history = [] @@ -180,7 +192,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): content = self._format_message_to_glm_content(last_msg) history.append(content) else: - for msg in prompt_messages: # makes message roles strictly alternating + for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) @@ -194,7 +206,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): google_model._client = new_custom_client - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, @@ -203,13 +215,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel): response = google_model.generate_content( contents=history, - generation_config=genai.types.GenerationConfig( - **config_kwargs - ), + generation_config=genai.types.GenerationConfig(**config_kwargs), stream=stream, safety_settings=safety_settings, tools=self._convert_tools_to_glm_tool(tools) if tools else None, - request_options={"timeout": 600} + request_options={"timeout": 600}, ) if stream: @@ -217,8 +227,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -229,9 +240,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -250,8 +259,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -264,9 +274,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): index = -1 for chunk in response: for part in chunk.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -275,36 +283,31 @@ class GoogleLargeLanguageModel(LargeLanguageModel): assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - + if not response._done: - # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -312,8 +315,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=str(chunk.candidates[0].finish_reason), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -328,17 +331,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel): content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") @@ -353,65 +352,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: glm Content representation of message """ if isinstance(message, UserPromptMessage): - glm_content = { - "role": "user", - "parts": [] - } - if (isinstance(message.content, str)): - glm_content['parts'].append(to_part(message.content)) + glm_content = {"role": "user", "parts": []} + if isinstance(message.content, str): + glm_content["parts"].append(to_part(message.content)) else: for c in message.content: if c.type == PromptMessageContentType.TEXT: - glm_content['parts'].append(to_part(c.data)) + glm_content["parts"].append(to_part(c.data)) elif c.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, c) if message_content.data.startswith("data:"): - metadata, base64_data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, base64_data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] else: # fetch image data from url try: image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") - blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}} - glm_content['parts'].append(blob) + blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} + glm_content["parts"].append(blob) return glm_content elif isinstance(message, AssistantPromptMessage): - glm_content = { - "role": "model", - "parts": [] - } + glm_content = {"role": "model", "parts": []} if message.content: - glm_content['parts'].append(to_part(message.content)) + glm_content["parts"].append(to_part(message.content)) if message.tool_calls: - glm_content["parts"].append(to_part(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))) + glm_content["parts"].append( + to_part( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ) return glm_content elif isinstance(message, SystemPromptMessage): - return { - "role": "user", - "parts": [to_part(message.content)] - } + return {"role": "user", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function", - "parts": [glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))] + "parts": [ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], } else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -423,25 +418,20 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: Invoke emd = genai.GenerativeModel(model) error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -457,5 +447,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel): exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/groq/groq.py b/api/core/model_runtime/model_providers/groq/groq.py index b3f37b3967..d0d5ff68f8 100644 --- a/api/core/model_runtime/model_providers/groq/groq.py +++ b/api/core/model_runtime/model_providers/groq/groq.py @@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) -class GroqProvider(ModelProvider): +class GroqProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ class GroqProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama3-8b-8192', - credentials=credentials - ) + model_instance.validate_credentials(model="llama3-8b-8192", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/groq/llm/llm.py b/api/core/model_runtime/model_providers/groq/llm/llm.py index 915f7a4e1a..352a7b519e 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llm.py +++ b/api/core/model_runtime/model_providers/groq/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,6 +27,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.groq.com/openai/v1' - + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.groq.com/openai/v1" diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index dd8ae526e6..3c4020b6ee 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError class _CommonHuggingfaceHub: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - HfHubHTTPError, - BadRequestError - ] - } + return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]} diff --git a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py index 15e2a4fed4..54d2a2bf39 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class HuggingfaceHubProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index f43a8aedaf..9d29237fdd 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -29,16 +29,23 @@ from core.model_runtime.model_providers.huggingface_hub._common import _CommonHu class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] - - if 'baichuan' in model.lower(): + if "baichuan" in model.lower(): stream = False response = client.text_generation( @@ -47,98 +54,100 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel stream=stream, model=model, stop_sequences=stop, - **model_parameters) + **model_parameters, + ) if stream: return self._handle_generate_stream_response(model, credentials, prompt_messages, response) return self._handle_generate_response(model, credentials, prompt_messages, response) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'): - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Access Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + credentials["task_type"] = self._get_hosted_model_task_type( + credentials["huggingfacehub_api_token"], model + ) - if credentials['task_type'] not in ("text2text-generation", "text-generation"): - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, ' - 'text-generation.') + if credentials["task_type"] not in {"text2text-generation", "text-generation"}: + raise CredentialsValidateFailedError( + "Huggingface Hub Task Type must be one of text2text-generation, text-generation." + ) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] try: - client.text_generation( - prompt='Who are you?', - stream=True, - model=model) + client.text_generation(prompt="Who are you?", stream=True, model=model) except BadRequestError as e: - raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. ' - 'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.') + raise CredentialsValidateFailedError( + "Only available for models running on with the `text-generation-inference`. " + "To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference." + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: LLMMode.COMPLETION.value - }, - parameter_rules=self._get_customizable_model_parameter_rules() + model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value}, + parameter_rules=self._get_customizable_model_parameter_rules(), ) return entity @staticmethod def _get_customizable_model_parameter_rules() -> list[ParameterRule]: - temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get( - DefaultParameterName.TEMPERATURE).copy() - temperature_rule_dict['name'] = 'temperature' + temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TEMPERATURE).copy() + temperature_rule_dict["name"] = "temperature" temperature_rule = ParameterRule(**temperature_rule_dict) temperature_rule.default = 0.5 top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy() - top_p_rule_dict['name'] = 'top_p' + top_p_rule_dict["name"] = "top_p" top_p_rule = ParameterRule(**top_p_rule_dict) top_p_rule.default = 0.5 top_k_rule = ParameterRule( - name='top_k', + name="top_k", label={ - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "en_US": "Top K", + "zh_Hans": "Top K", }, - type='int', + type="int", help={ - 'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.', - 'zh_Hans': '保留的最高概率词汇标记的数量。', + "en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", + "zh_Hans": "保留的最高概率词汇标记的数量。", }, required=False, default=2, @@ -148,15 +157,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) max_new_tokens = ParameterRule( - name='max_new_tokens', + name="max_new_tokens", label={ - 'en_US': 'Max New Tokens', - 'zh_Hans': '最大新标记', + "en_US": "Max New Tokens", + "zh_Hans": "最大新标记", }, - type='int', + type="int", help={ - 'en_US': 'Maximum number of generated tokens.', - 'zh_Hans': '生成的标记的最大数量。', + "en_US": "Maximum number of generated tokens.", + "zh_Hans": "生成的标记的最大数量。", }, required=False, default=20, @@ -166,30 +175,30 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) seed = ParameterRule( - name='seed', + name="seed", label={ - 'en_US': 'Random sampling seed', - 'zh_Hans': '随机采样种子', + "en_US": "Random sampling seed", + "zh_Hans": "随机采样种子", }, - type='int', + type="int", help={ - 'en_US': 'Random sampling seed.', - 'zh_Hans': '随机采样种子。', + "en_US": "Random sampling seed.", + "zh_Hans": "随机采样种子。", }, required=False, precision=0, ) repetition_penalty = ParameterRule( - name='repetition_penalty', + name="repetition_penalty", label={ - 'en_US': 'Repetition Penalty', - 'zh_Hans': '重复惩罚', + "en_US": "Repetition Penalty", + "zh_Hans": "重复惩罚", }, - type='float', + type="float", help={ - 'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.', - 'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。', + "en_US": "The parameter for repetition penalty. 1.0 means no penalty.", + "zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。", }, required=False, precision=1, @@ -197,11 +206,9 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty] - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - response: Generator) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: Generator + ) -> Generator: index = -1 for chunk in response: # skip special tokens @@ -210,9 +217,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=chunk.token.text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text) if chunk.details: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -240,15 +245,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ), ) - def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any + ) -> LLMResult: if isinstance(response, str): content = response else: content = response.generated_text - assistant_prompt_message = AssistantPromptMessage( - content=content - ) + assistant_prompt_message = AssistantPromptMessage(content=content) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -270,15 +275,14 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = ("text2text-generation", "text-generation") if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") @@ -287,10 +291,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 0f0c166f3e..4ad96c4233 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -13,40 +13,30 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' +HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/" class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) execute_model = model - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - execute_model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + execute_model = credentials["huggingfacehub_endpoint_url"] output = client.post( - json={ - "inputs": texts, - "options": { - "wait_for_model": False, - "use_cache": False - } - }, - model=execute_model) + json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model + ) embeddings = json.loads(output.decode()) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - embeddings=self._mean_pooling(embeddings), - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -56,52 +46,48 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingface_namespace' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingface_namespace" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub User Name / Organization Name must be provided." + ) - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") - if credentials['task_type'] != 'feature-extraction': - raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.') + if credentials["task_type"] != "feature-extraction": + raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.") self._check_endpoint_url_model_repository_name(credentials, model) - model = credentials['huggingfacehub_endpoint_url'] + model = credentials["huggingfacehub_endpoint_url"] - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model) else: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - client.feature_extraction(text='hello world', model=model) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + client.feature_extraction(text="hello world", model=model) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 10000, - 'max_chunks': 1 - } + model_properties={"context_size": 10000, "max_chunks": 1}, ) return entity @@ -128,24 +114,20 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = "feature-extraction" if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -156,7 +138,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -166,25 +148,26 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel try: url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}' headers = { - 'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}', - 'Content-Type': 'application/json' + "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}', + "Content-Type": "application/json", } response = requests.get(url=url, headers=headers) if response.status_code != 200: - raise ValueError('User Name or Organization Name is invalid.') + raise ValueError("User Name or Organization Name is invalid.") - model_repository_name = '' + model_repository_name = "" for item in response.json().get("items", []): - if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']: + if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]: model_repository_name = item.get("model", {}).get("repository") break if model_repository_name != model_name: raise ValueError( - f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.') + f"Model Name {model_name} is invalid. Please check it on the inference endpoints console." + ) except Exception as e: raise ValueError(str(e)) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py index 9454466250..97d7e28dc6 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class HuggingfaceTeiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index 34013426de..74a1dfc3ff 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -47,29 +47,28 @@ class HuggingfaceTeiRerankModel(RerankModel): """ if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") try: results = TeiHelper.invoke_rerank(server_url, query, docs) rerank_documents = [] - for result in results: + for result in results: rerank_document = RerankDocument( - index=result['index'], - text=result['text'], - score=result['score'], + index=result["index"], + text=result["text"], + score=result["score"], ) - if score_threshold is None or result['score'] >= score_threshold: + if score_threshold is None or result["score"] >= score_threshold: rerank_documents.append(rerank_document) if top_n is not None and len(rerank_documents) >= top_n: break return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -80,21 +79,21 @@ class HuggingfaceTeiRerankModel(RerankModel): :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) - if extra_args.model_type != 'reranker': - raise CredentialsValidateFailedError('Current model is not a rerank model') + if extra_args.model_type != "reranker": + raise CredentialsValidateFailedError("Current model is not a rerank model") - credentials['context_size'] = extra_args.max_input_length + credentials["context_size"] = extra_args.max_input_length self.invoke( model=model, credentials=credentials, - query='Whose kasumi', + query="Whose kasumi", docs=[ 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', - 'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ', - 'and she leads a team named PopiParty.', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", ], score_threshold=0.8, ) @@ -129,7 +128,7 @@ class HuggingfaceTeiRerankModel(RerankModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 2aa785c89d..81ab249214 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -31,16 +31,16 @@ class TeiHelper: with cache_lock: if model_name not in cache: cache[model_name] = { - 'expires': time() + 300, - 'value': TeiHelper._get_tei_extra_parameter(server_url), + "expires": time() + 300, + "value": TeiHelper._get_tei_extra_parameter(server_url), } - return cache[model_name]['value'] + return cache[model_name]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -52,40 +52,39 @@ class TeiHelper: get tei model extra parameter like model_type, max_input_length, max_batch_requests """ - url = str(URL(server_url) / 'info') + url = str(URL(server_url) / "info") - # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 + # this method is surrounded by a lock, and default requests may hang forever, + # so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) try: response = session.get(url, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: raise RuntimeError( - f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}' + f"get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}" ) response_json = response.json() - model_type = response_json.get('model_type', {}) + model_type = response_json.get("model_type", {}) if len(model_type.keys()) < 1: - raise RuntimeError('model_type is empty') + raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ['embedding', 'reranker']: - raise RuntimeError(f'invalid model_type: {model_type}') - - max_input_length = response_json.get('max_input_length', 512) - max_client_batch_size = response_json.get('max_client_batch_size', 1) + if model_type not in {"embedding", "reranker"}: + raise RuntimeError(f"invalid model_type: {model_type}") + + max_input_length = response_json.get("max_input_length", 512) + max_client_batch_size = response_json.get("max_client_batch_size", 1) return TeiModelExtraParameter( - model_type=model_type, - max_input_length=max_input_length, - max_client_batch_size=max_client_batch_size + model_type=model_type, max_input_length=max_input_length, max_client_batch_size=max_client_batch_size ) - + @staticmethod def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: """ @@ -116,12 +115,12 @@ class TeiHelper: :param texts: texts to tokenize """ resp = httpx.post( - f'{server_url}/tokenize', - json={'inputs': texts}, + f"{server_url}/tokenize", + json={"inputs": texts}, ) resp.raise_for_status() return resp.json() - + @staticmethod def invoke_embeddings(server_url: str, texts: list[str]) -> dict: """ @@ -149,8 +148,8 @@ class TeiHelper: """ # Use OpenAI compatible API here, which has usage tracking resp = httpx.post( - f'{server_url}/v1/embeddings', - json={'input': texts}, + f"{server_url}/v1/embeddings", + json={"input": texts}, ) resp.raise_for_status() return resp.json() @@ -173,11 +172,11 @@ class TeiHelper: :param texts: texts to rerank :param candidates: candidates to rerank """ - params = {'query': query, 'texts': docs, 'return_text': True} + params = {"query": query, "texts": docs, "return_text": True} response = httpx.post( - server_url + '/rerank', + server_url + "/rerank", json=params, ) - response.raise_for_status() + response.raise_for_status() return response.json() diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 6897b87f6d..55f3c25804 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -40,11 +40,9 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] - - if server_url.endswith('/'): - server_url = server_url[:-1] + server_url = credentials["server_url"] + server_url = server_url.removesuffix("/") # get model properties context_size = self._get_context_size(model, credentials) @@ -58,7 +56,6 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): - # Check if the number of tokens is larger than the context size num_tokens = len(tokenize_result) @@ -66,20 +63,22 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): # Find the best cutoff point pre_special_token_count = 0 for token in tokenize_result: - if token['special']: + if token["special"]: pre_special_token_count += 1 else: break - rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count + rest_special_token_count = ( + len([token for token in tokenize_result if token["special"]]) - pre_special_token_count + ) # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit token_cutoff = context_size - rest_special_token_count - 20 # Find the cutoff index cutpoint_token = tokenize_result[token_cutoff] - cutoff = cutpoint_token['start'] + cutoff = cutpoint_token["start"] - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -92,12 +91,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): for i in _iter: iter_texts = inputs[i : i + max_chunks] results = TeiHelper.invoke_embeddings(server_url, iter_texts) - embeddings = results['data'] - embeddings = [embedding['embedding'] for embedding in embeddings] + embeddings = results["data"] + embeddings = [embedding["embedding"] for embedding in embeddings] batched_embeddings.extend(embeddings) - usage = results['usage'] - used_tokens += usage['total_tokens'] + usage = results["usage"] + used_tokens += usage["total_tokens"] except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -117,10 +116,9 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): :return: """ num_tokens = 0 - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) num_tokens = sum(len(tokens) for tokens in batch_tokens) @@ -135,15 +133,15 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) print(extra_args) - if extra_args.model_type != 'embedding': - raise CredentialsValidateFailedError('Current model is not a embedding model') + if extra_args.model_type != "embedding": + raise CredentialsValidateFailedError("Current model is not a embedding model") - credentials['context_size'] = extra_args.max_input_length - credentials['max_chunks'] = extra_args.max_client_batch_size - self._invoke(model=model, credentials=credentials, texts=['ping']) + credentials["context_size"] = extra_args.max_input_length + credentials["max_chunks"] = extra_args.max_client_batch_size + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -195,8 +193,8 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ - ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)), - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py index 5a298d33ac..e65772e7dd 100644 --- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py +++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py @@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) -class HunyuanProvider(ModelProvider): +class HunyuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +19,9 @@ class HunyuanProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `hunyuan-standard` model for validate, - model_instance.validate_credentials( - model='hunyuan-standard', - credentials=credentials - ) + model_instance.validate_credentials(model="hunyuan-standard", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml index 2c1b981f85..ca8600a534 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml @@ -2,3 +2,4 @@ - hunyuan-standard - hunyuan-standard-256k - hunyuan-pro +- hunyuan-turbo diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml new file mode 100644 index 0000000000..4837fed4ba --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml @@ -0,0 +1,38 @@ +model: hunyuan-turbo +label: + zh_Hans: hunyuan-turbo + en_US: hunyuan-turbo +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.015' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 0bdf6ec005..b57e5e1c2b 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -23,21 +23,27 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) + class HunyuanLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = self._setup_hunyuan_client(credentials) request = models.ChatCompletionsRequest() messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages) custom_parameters = { - 'Temperature': model_parameters.get('temperature', 0.0), - 'TopP': model_parameters.get('top_p', 1.0), - 'EnableEnhancement': model_parameters.get('enable_enhance', True) + "Temperature": model_parameters.get("temperature", 0.0), + "TopP": model_parameters.get("top_p", 1.0), + "EnableEnhancement": model_parameters.get("enable_enhance", True), } params = { @@ -47,16 +53,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): **custom_parameters, } # add Tools and ToolChoice - if (tools and len(tools) > 0): - params['ToolChoice'] = "auto" - params['Tools'] = [{ - "Type": "function", - "Function": { - "Name": tool.name, - "Description": tool.description, - "Parameters": json.dumps(tool.parameters) + if tools and len(tools) > 0: + params["ToolChoice"] = "auto" + params["Tools"] = [ + { + "Type": "function", + "Function": { + "Name": tool.name, + "Description": tool.description, + "Parameters": json.dumps(tool.parameters), + }, } - } for tool in tools] + for tool in tools + ] request.from_json_string(json.dumps(params)) response = client.ChatCompletions(request) @@ -76,22 +85,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -106,92 +112,97 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): for message in prompt_messages: if isinstance(message, AssistantPromptMessage): tool_calls = message.tool_calls - if (tool_calls and len(tool_calls) > 0): + if tool_calls and len(tool_calls) > 0: dict_tool_calls = [ { "Id": tool_call.id, "Type": tool_call.type, "Function": { "Name": tool_call.function.name, - "Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}" - } - } for tool_call in tool_calls] - - dict_list.append({ - "Role": message.role.value, - # fix set content = "" while tool_call request - # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. - "Content": " ", # message.content if (message.content is not None) else "", - "ToolCalls": dict_tool_calls - }) + "Arguments": tool_call.function.arguments + if (tool_call.function.arguments == "") + else "{}", + }, + } + for tool_call in tool_calls + ] + + dict_list.append( + { + "Role": message.role.value, + # fix set content = "" while tool_call request + # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter + # message:Messages Content and Contents not allowed empty at the same time. + "Content": " ", # message.content if (message.content is not None) else "", + "ToolCalls": dict_tool_calls, + } + ) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) elif isinstance(message, ToolPromptMessage): - tool_execute_result = { "result": message.content } - content =json.dumps(tool_execute_result, ensure_ascii=False) - dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id }) + tool_execute_result = {"result": message.content} + content = json.dumps(tool_execute_result, ensure_ascii=False) + dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id}) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) return dict_list def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp): - tool_call = None tool_calls = [] for index, event in enumerate(resp): logging.debug("_handle_stream_chat_response, event: %s", event) - data_str = event['data'] + data_str = event["data"] data = json.loads(data_str) - choices = data.get('Choices', []) + choices = data.get("Choices", []) if not choices: continue choice = choices[0] - delta = choice.get('Delta', {}) - message_content = delta.get('Content', '') - finish_reason = choice.get('FinishReason', '') + delta = choice.get("Delta", {}) + message_content = delta.get("Content", "") + finish_reason = choice.get("FinishReason", "") - usage = data.get('Usage', {}) - prompt_tokens = usage.get('PromptTokens', 0) - completion_tokens = usage.get('CompletionTokens', 0) + usage = data.get("Usage", {}) + prompt_tokens = usage.get("PromptTokens", 0) + completion_tokens = usage.get("CompletionTokens", 0) - response_tool_calls = delta.get('ToolCalls') - if (response_tool_calls is not None): + response_tool_calls = delta.get("ToolCalls") + if response_tool_calls is not None: new_tool_calls = self._extract_response_tool_calls(response_tool_calls) - if (len(new_tool_calls) > 0): + if len(new_tool_calls) > 0: new_tool_call = new_tool_calls[0] - if (tool_call is None): tool_call = new_tool_call - elif (tool_call.id != new_tool_call.id): + if tool_call is None: + tool_call = new_tool_call + elif tool_call.id != new_tool_call.id: tool_calls.append(tool_call) tool_call = new_tool_call else: tool_call.function.name += new_tool_call.function.name tool_call.function.arguments += new_tool_call.function.arguments - if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0): + if tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0: tool_calls.append(tool_call) tool_call = None - assistant_prompt_message = AssistantPromptMessage( - content=message_content, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=message_content, tool_calls=[]) # rewrite content = "" while tool_call to avoid show content on web page - if (len(tool_calls) > 0): assistant_prompt_message.content = "" - + if len(tool_calls) > 0: + assistant_prompt_message.content = "" + # add tool_calls to assistant_prompt_message - if (finish_reason == 'tool_calls'): + if finish_reason == "tool_calls": assistant_prompt_message.tool_calls = tool_calls tool_call = None tool_calls = [] - if (len(finish_reason) > 0): + if len(finish_reason) > 0: usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) delta_chunk = LLMResultChunkDelta( index=index, - role=delta.get('Role', 'assistant'), + role=delta.get("Role", "assistant"), message=assistant_prompt_message, usage=usage, finish_reason=finish_reason, @@ -212,8 +223,9 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): ) def _handle_chat_response(self, credentials, model, prompt_messages, response): - usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens, - response.Usage.CompletionTokens) + usage = self._calc_response_usage( + model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens + ) assistant_prompt_message = AssistantPromptMessage() assistant_prompt_message.content = response.Choices[0].Message.Content result = LLMResult( @@ -225,8 +237,13 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return result - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if len(prompt_messages) == 0: return 0 prompt = self._convert_messages_to_prompt(prompt_messages) @@ -241,10 +258,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -287,10 +301,8 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return { InvokeError: [TencentCloudSDKException], } - - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -300,17 +312,14 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): tool_calls = [] if response_tool_calls: for response_tool_call in response_tool_calls: - response_function = response_tool_call.get('Function', {}) + response_function = response_tool_call.get("Function", {}) function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function.get('Name', ''), - arguments=response_function.get('Arguments', '') + name=response_function.get("Name", ""), arguments=response_function.get("Arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get('Id', 0), - type='function', - function=function + id=response_tool_call.get("Id", 0), type="function", function=function ) tool_calls.append(tool_call) - return tool_calls \ No newline at end of file + return tool_calls diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 64d8dcf795..1396e59e18 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -19,14 +19,15 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE logger = logging.getLogger(__name__) + class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for Hunyuan text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +38,9 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ - if model != 'hunyuan-embedding': - raise ValueError('Invalid model name') - + if model != "hunyuan-embedding": + raise ValueError("Invalid model name") + client = self._setup_hunyuan_client(credentials) embeddings = [] @@ -47,9 +48,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): for input in texts: request = models.GetEmbeddingRequest() - params = { - "Input": input - } + params = {"Input": input} request.from_json_string(json.dumps(params)) response = client.GetEmbedding(request) usage = response.Usage.TotalTokens @@ -60,11 +59,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result @@ -79,22 +74,19 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -102,7 +94,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): clientProfile.httpProfile = httpProfile client = hunyuan_client.HunyuanClient(cred, "", clientProfile) return client - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -114,10 +106,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -128,11 +117,11 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -146,7 +135,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): return { InvokeError: [TencentCloudSDKException], } - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -170,4 +159,4 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): # response = client.GetTokenCount(request) # num_tokens += response.TokenCount - return num_tokens \ No newline at end of file + return num_tokens diff --git a/api/core/model_runtime/model_providers/jina/jina.py b/api/core/model_runtime/model_providers/jina/jina.py index cde4313495..33977b6a33 100644 --- a/api/core/model_runtime/model_providers/jina/jina.py +++ b/api/core/model_runtime/model_providers/jina/jina.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class JinaProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class JinaProvider(ModelProvider): # Use `jina-embeddings-v2-base-en` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials=credentials - ) + model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index de7e038b9f..79ca68914f 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -22,9 +22,16 @@ class JinaRerankModel(RerankModel): Model class for Jina rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -40,37 +47,31 @@ class JinaRerankModel(RerankModel): if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.jina.ai/v1') - if base_url.endswith('/'): - base_url = base_url[:-1] + base_url = credentials.get("base_url", "https://api.jina.ai/v1") + base_url = base_url.removesuffix("/") try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) - response.raise_for_status() + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -81,7 +82,6 @@ class JinaRerankModel(RerankModel): :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -92,7 +92,7 @@ class JinaRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -105,23 +105,21 @@ class JinaRerankModel(RerankModel): return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index 50f8c73ed9..d80cbfa83d 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -14,19 +14,19 @@ class JinaTokenizer: with cls._lock: if cls._tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + gpt2_tokenizer_path = join(dirname(base_path), "tokenizer") cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) return cls._tokenizer @classmethod def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ - use jina tokenizer to get num tokens + use jina tokenizer to get num tokens """ tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - + @classmethod def get_num_tokens(cls, text: str) -> int: - return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file + return cls._get_num_tokens_by_jina_base(text) diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 23203491e6..ef12e534db 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -24,11 +24,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - api_base: str = 'https://api.jina.ai/v1' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.jina.ai/v1" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -38,29 +39,22 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") - base_url = credentials.get('base_url', self.api_base) - if base_url.endswith('/'): - base_url = base_url[:-1] + base_url = credentials.get("base_url", self.api_base) + base_url = base_url.removesuffix("/") - url = base_url + '/embeddings' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} def transform_jina_input_text(model, text): - if model == 'jina-clip-v1': + if model == "jina-clip-v1": return {"text": text} return text - data = { - 'model': model, - 'input': [transform_jina_input_text(model, text) for text in texts] - } + data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]} try: response = post(url, headers=headers, data=dumps(data)) @@ -70,7 +64,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -81,25 +75,20 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): raise InvokeBadRequestError(msg) except JSONDecodeError as e: raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -128,30 +117,18 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: - raise CredentialsValidateFailedError( - f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError, - InvokeBadRequestError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -165,10 +142,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,24 +153,21 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) return entity diff --git a/api/core/model_runtime/model_providers/leptonai/leptonai.py b/api/core/model_runtime/model_providers/leptonai/leptonai.py index b035c31ac5..34a55ff192 100644 --- a/api/core/model_runtime/model_providers/leptonai/leptonai.py +++ b/api/core/model_runtime/model_providers/leptonai/leptonai.py @@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) -class LeptonAIProvider(ModelProvider): +class LeptonAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ class LeptonAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama2-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="llama2-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/leptonai/llm/llm.py b/api/core/model_runtime/model_providers/leptonai/llm/llm.py index 523309bac5..3d69417e45 100644 --- a/api/core/model_runtime/model_providers/leptonai/llm/llm.py +++ b/api/core/model_runtime/model_providers/leptonai/llm/llm.py @@ -8,18 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_PREFIX_MAP = { - 'llama2-7b': 'llama2-7b', - 'gemma-7b': 'gemma-7b', - 'mistral-7b': 'mistral-7b', - 'mixtral-8x7b': 'mixtral-8x7b', - 'llama3-70b': 'llama3-70b', - 'llama2-13b': 'llama2-13b', - } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + "llama2-7b": "llama2-7b", + "gemma-7b": "gemma-7b", + "mistral-7b": "mistral-7b", + "mixtral-8x7b": "mixtral-8x7b", + "llama3-70b": "llama3-70b", + "llama2-13b": "llama2-13b", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -29,6 +36,5 @@ class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1' - \ No newline at end of file + credentials["mode"] = "chat" + credentials["endpoint_url"] = f"https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1" diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 1009995c58..e7295355f6 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -52,29 +52,48 @@ from core.model_runtime.utils import helper class LocalAILanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for baichuan model - LocalAI does not supports + Calculate num tokens for baichuan model + LocalAI does not supports """ def tokens(text: str): """ - We could not determine which tokenizer to use, cause the model is customized. - So we use gpt2 tokenizer to calculate the num tokens for convenience. + We could not determine which tokenizer to use, cause the model is customized. + So we use gpt2 tokenizer to calculate the num tokens for convenience. """ return self._get_num_tokens_by_gpt2(text) @@ -87,10 +106,10 @@ class LocalAILanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -142,30 +161,30 @@ class LocalAILanguageModel(LargeLanguageModel): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -180,102 +199,104 @@ class LocalAILanguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={ - 'max_tokens': 10, - }, stop=[], stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={ + "max_tokens": 10, + }, + stop=[], + stream=False, + ) except Exception as ex: - raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') + raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: completion_model = None - if credentials['completion_type'] == 'chat_completion': + if credentials["completion_type"] == "chat_completion": completion_model = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_model = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=2048, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - model_properties = { - ModelPropertyKey.MODE: completion_model, - } if completion_model else {} + model_properties = ( + { + ModelPropertyKey.MODE: completion_model, + } + if completion_model + else {} + ) - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get("context_size", "2048")) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, - parameter_rules=rules + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) # init model client client = OpenAI(**kwargs) model_name = model - completion_type = credentials['completion_type'] + completion_type = credentials["completion_type"] extra_model_kwargs = { "timeout": 60, } if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] - if completion_type == 'chat_completion': + if completion_type == "chat_completion": result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model_name, @@ -283,36 +304,32 @@ class LocalAILanguageModel(LargeLanguageModel): **model_parameters, **extra_model_kwargs, ) - elif completion_type == 'completion': + elif completion_type == "completion": result = client.completions.create( prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages), model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: raise ValueError(f"Unknown completion type {completion_type}") if stream: - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_response( - model=model, credentials=credentials, response=result, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, prompt_messages=prompt_messages ) return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) def _to_client_kwargs(self, credentials: dict) -> dict: @@ -322,13 +339,13 @@ class LocalAILanguageModel(LargeLanguageModel): :param credentials: credentials dict :return: client kwargs """ - if not credentials['server_url'].endswith('/'): - credentials['server_url'] += '/' + if not credentials["server_url"].endswith("/"): + credentials["server_url"] += "/" client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['server_url']) / 'v1'), + "base_url": str(URL(credentials["server_url"]) / "v1"), } return client_kwargs @@ -349,7 +366,7 @@ class LocalAILanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -359,11 +376,7 @@ class LocalAILanguageModel(LargeLanguageModel): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}], } else: raise ValueError(f"Unknown message type {type(message)}") @@ -374,27 +387,29 @@ class LocalAILanguageModel(LargeLanguageModel): """ Convert PromptMessage to completion prompts """ - prompts = '' + prompts = "" for message in messages: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" else: raise ValueError(f"Unknown message type {type(message)}") return prompts - def _handle_completion_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Completion, - ) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Completion, + ) -> LLMResult: """ Handle llm chat response @@ -411,18 +426,16 @@ class LocalAILanguageModel(LargeLanguageModel): assistant_message = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) ) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -434,11 +447,14 @@ class LocalAILanguageModel(LargeLanguageModel): return response - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ChatCompletion, - tools: list[PromptMessageTool]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: ChatCompletion, + tools: list[PromptMessageTool], + ) -> LLMResult: """ Handle llm chat response @@ -459,16 +475,14 @@ class LocalAILanguageModel(LargeLanguageModel): tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -480,12 +494,15 @@ class LocalAILanguageModel(LargeLanguageModel): return response - def _handle_completion_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[Completion], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_completion_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[Completion], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -494,17 +511,11 @@ class LocalAILanguageModel(LargeLanguageModel): delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) @@ -512,8 +523,12 @@ class LocalAILanguageModel(LargeLanguageModel): completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -523,7 +538,7 @@ class LocalAILanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -539,12 +554,15 @@ class LocalAILanguageModel(LargeLanguageModel): full_response += delta.text - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[ChatCompletionChunk], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[ChatCompletionChunk], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -552,7 +570,7 @@ class LocalAILanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -560,26 +578,28 @@ class LocalAILanguageModel(LargeLanguageModel): if delta.delta.function_call: function_calls = [delta.delta.function_call] - assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) + assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or []) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content or "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -589,7 +609,7 @@ class LocalAILanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -605,9 +625,9 @@ class LocalAILanguageModel(LargeLanguageModel): full_response += delta.delta.content - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -618,15 +638,10 @@ class LocalAILanguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls @@ -651,15 +666,9 @@ class LocalAILanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py index 6d2278fd54..4ff898052b 100644 --- a/api/core/model_runtime/model_providers/localai/localai.py +++ b/api/core/model_runtime/model_providers/localai/localai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class LocalAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py index c8ba9a6c7c..2b0f53bc19 100644 --- a/api/core/model_runtime/model_providers/localai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py @@ -25,9 +25,16 @@ class LocalaiRerankModel(RerankModel): LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -43,45 +50,37 @@ class LocalaiRerankModel(RerankModel): if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model - - if not server_url: - raise CredentialsValidateFailedError('server_url is required') - if not model_name: - raise CredentialsValidateFailedError('model_name is required') - - url = server_url - headers = { - 'Authorization': f"Bearer {credentials.get('api_key')}", - 'Content-Type': 'application/json' - } - data = { - "model": model_name, - "query": query, - "documents": docs, - "top_n": top_n - } + if not server_url: + raise CredentialsValidateFailedError("server_url is required") + if not model_name: + raise CredentialsValidateFailedError("model_name is required") + + url = server_url + headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"} + + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} try: - response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10) - response.raise_for_status() + response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=10) + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -92,7 +91,6 @@ class LocalaiRerankModel(RerankModel): :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -103,7 +101,7 @@ class LocalaiRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -116,21 +114,21 @@ class LocalaiRerankModel(RerankModel): return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={} + model_properties={}, ) return entity diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py index d7403aff4f..4b9d0f5bfe 100644 --- a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py @@ -32,8 +32,8 @@ class LocalAISpeech2text(Speech2TextModel): :param user: unique user id :return: text for given audio file """ - - url = str(URL(credentials['server_url']) / "v1/audio/transcriptions") + + url = str(URL(credentials["server_url"]) / "v1/audio/transcriptions") data = {"model": model} files = {"file": file} @@ -42,7 +42,7 @@ class LocalAISpeech2text(Speech2TextModel): prepared_request = session.prepare_request(request) response = session.send(prepared_request) - if 'error' in response.json(): + if "error" in response.json(): raise InvokeServerUnavailableError("Empty response") return response.json()["text"] @@ -58,7 +58,7 @@ class LocalAISpeech2text(Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -66,36 +66,24 @@ class LocalAISpeech2text(Speech2TextModel): @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError - ], + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 954c9d10f2..7d258be81e 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -24,9 +24,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,39 +38,33 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ if len(texts) != 1: - raise InvokeBadRequestError('Only one text is supported') + raise InvokeBadRequestError("Only one text is supported") - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - - url = server_url - headers = { - 'Authorization': 'Bearer 123', - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("model_name is required") - data = { - 'model': model_name, - 'input': texts[0] - } + url = server_url + headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"} + + data = {"model": model_name, "input": texts[0]} try: - response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - code = resp['error']['code'] - msg = resp['error']['message'] + code = resp["error"]["code"] + msg = resp["error"]["message"] if code == 500: raise InvokeServerUnavailableError(msg) - + if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -79,23 +74,21 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -114,7 +107,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): # use GPT2Tokenizer to get num tokens num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens - + def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema @@ -130,10 +123,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): features=[], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")), ModelPropertyKey.MAX_CHUNKS: 1, }, - parameter_rules=[] + parameter_rules=[], ) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -145,32 +138,22 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid credentials') + raise CredentialsValidateFailedError("Invalid credentials") except InvokeConnectionError as e: - raise CredentialsValidateFailedError(f'Invalid credentials: {e}') + raise CredentialsValidateFailedError(f"Invalid credentials: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -182,10 +165,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -196,7 +176,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 6c41e0d2a5..88cc0e8e0f 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -17,42 +17,48 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxChatCompletion: """ - Minimax Chat Completion API + Minimax Chat Completion API """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') - - url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}' + raise InvalidAPIKeyError("Invalid API key or group ID") + + url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - prompt = '你是一个什么都懂的专家' + prompt = "你是一个什么都懂的专家" - role_meta = { - 'user_name': '我', - 'bot_name': '专家' - } + role_meta = {"user_name": "我", "bot_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') - + raise BadRequestError("At least one message is required") + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: prompt = prompt_messages[0].content @@ -60,44 +66,43 @@ class MinimaxChatCompletion: # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') - - messages = [{ - 'sender_type': message.role, - 'text': message.content, - } for message in prompt_messages] + raise BadRequestError("At least one user message is required") - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + messages = [ + { + "sender_type": message.role, + "text": message.content, + } + for message in prompt_messages + ] + + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'prompt': prompt, - 'role_meta': role_meta, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "prompt": prompt, + "role_meta": role_meta, + "stream": stream, + **extra_kwargs, } try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) @@ -110,65 +115,52 @@ class MinimaxChatCompletion: def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) - if data['reply']: - total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens - } - message.stop_reason = data['choices'][0]['finish_reason'] + if data["reply"]: + total_tokens = data["usage"]["total_tokens"] + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.usage = {"prompt_tokens": 0, "completion_tokens": total_tokens, "total_tokens": total_tokens} + message.stop_reason = data["choices"][0]["finish_reason"] yield message return - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['delta'] - yield MinimaxMessage( - content=message, - role=MinimaxMessage.Role.ASSISTANT.value - ) \ No newline at end of file + message = choice["delta"] + yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 55747057c9..8b8fdbb6bd 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -17,86 +17,83 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxChatCompletionPro: """ - Minimax Chat Completion Pro API, supports function calling - however, we do not have enough time and energy to implement it, but the parameters are reserved + Minimax Chat Completion Pro API, supports function calling + however, we do not have enough time and energy to implement it, but the parameters are reserved """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') + raise InvalidAPIKeyError("Invalid API key or group ID") - url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}' + url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool: - extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info'] - - if model_parameters.get('plugin_web_search'): - extra_kwargs['plugins'] = [ - 'plugin_web_search' - ] + if "mask_sensitive_info" in model_parameters and type(model_parameters["mask_sensitive_info"]) == bool: + extra_kwargs["mask_sensitive_info"] = model_parameters["mask_sensitive_info"] - bot_setting = { - 'bot_name': '专家', - 'content': '你是一个什么都懂的专家' - } + if model_parameters.get("plugin_web_search"): + extra_kwargs["plugins"] = ["plugin_web_search"] - reply_constraints = { - 'sender_type': 'BOT', - 'sender_name': '专家' - } + bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"} + + reply_constraints = {"sender_type": "BOT", "sender_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') + raise BadRequestError("At least one message is required") if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: - bot_setting['content'] = prompt_messages[0].content + bot_setting["content"] = prompt_messages[0].content prompt_messages = prompt_messages[1:] # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') + raise BadRequestError("At least one user message is required") messages = [message.to_dict() for message in prompt_messages] - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'bot_setting': [bot_setting], - 'reply_constraints': reply_constraints, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "bot_setting": [bot_setting], + "reply_constraints": reply_constraints, + "stream": stream, + **extra_kwargs, } if tools: - body['functions'] = tools - body['function_call'] = {'type': 'auto'} + body["functions"] = tools + body["function_call"] = {"type": "auto"} try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) @@ -108,9 +105,9 @@ class MinimaxChatCompletionPro: return self._handle_chat_generate_response(response) def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) @@ -123,78 +120,72 @@ class MinimaxChatCompletionPro: def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) # final chunk - if data['reply'] or data.get('usage'): - total_tokens = data['usage']['total_tokens'] - minimax_message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) + if data["reply"] or data.get("usage"): + total_tokens = data["usage"]["total_tokens"] + minimax_message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") minimax_message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens + "prompt_tokens": 0, + "completion_tokens": total_tokens, + "total_tokens": total_tokens, } - minimax_message.stop_reason = data['choices'][0]['finish_reason'] + minimax_message.stop_reason = data["choices"][0]["finish_reason"] - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) > 0: for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append function_call message - if 'function_call' in message: - function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - function_call_message.function_call = message['function_call'] + if "function_call" in message: + function_call_message = MinimaxMessage(content="", role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message.function_call = message["function_call"] yield function_call_message yield minimax_message return # partial chunk - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append text message - if 'text' in message: - minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value) + if "text" in message: + minimax_message = MinimaxMessage(content=message["text"], role=MinimaxMessage.Role.ASSISTANT.value) yield minimax_message diff --git a/api/core/model_runtime/model_providers/minimax/llm/errors.py b/api/core/model_runtime/model_providers/minimax/llm/errors.py index d9d279e6ca..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/errors.py +++ b/api/core/model_runtime/model_providers/minimax/llm/errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index feeba75f49..4250c40cfb 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -34,18 +34,25 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { - 'abab6.5s-chat': MinimaxChatCompletionPro, - 'abab6.5-chat': MinimaxChatCompletionPro, - 'abab6-chat': MinimaxChatCompletionPro, - 'abab5.5s-chat': MinimaxChatCompletionPro, - 'abab5.5-chat': MinimaxChatCompletionPro, - 'abab5-chat': MinimaxChatCompletion + "abab6.5s-chat": MinimaxChatCompletionPro, + "abab6.5-chat": MinimaxChatCompletionPro, + "abab6-chat": MinimaxChatCompletionPro, + "abab5.5s-chat": MinimaxChatCompletionPro, + "abab5.5-chat": MinimaxChatCompletionPro, + "abab5-chat": MinimaxChatCompletion, } - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -53,82 +60,97 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): Validate credentials for Baichuan model """ if model not in self.model_apis: - raise CredentialsValidateFailedError(f'Invalid model: {model}') + raise CredentialsValidateFailedError(f"Invalid model: {model}") - if not credentials.get('minimax_api_key'): - raise CredentialsValidateFailedError('Invalid API key') + if not credentials.get("minimax_api_key"): + raise CredentialsValidateFailedError("Invalid API key") + + if not credentials.get("minimax_group_id"): + raise CredentialsValidateFailedError("Invalid group ID") - if not credentials.get('minimax_group_id'): - raise CredentialsValidateFailedError('Invalid group ID') - # ping instance = MinimaxChatCompletionPro() try: instance.generate( - model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'], - prompt_messages=[ - MinimaxMessage(content='ping', role='USER') - ], + model=model, + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], + prompt_messages=[MinimaxMessage(content="ping", role="USER")], model_parameters={}, - tools=[], stop=[], + tools=[], + stop=[], stream=False, - user='' + user="", ) except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for minimax model + Calculate num tokens for minimax model - not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way - to calculate the num tokens, so we use str() to convert the prompt to string + not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way + to calculate the num tokens, so we use str() to convert the prompt to string - Minimax does not provide their own tokenizer of adab5.5 and abab5 model - therefore, we use gpt2 tokenizer instead + Minimax does not provide their own tokenizer of adab5.5 and abab5 model + therefore, we use gpt2 tokenizer instead """ messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages] return self._get_num_tokens_by_gpt2(str(messages_dict)) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface + use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface """ client: MinimaxChatCompletionPro = self.model_apis[model]() if tools: - tools = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + tools = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] response = client.generate( model=model, - api_key=credentials['minimax_api_key'], - group_id=credentials['minimax_group_id'], + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages], model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage: """ - convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface + convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface """ if isinstance(prompt_message, SystemPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content) @@ -136,26 +158,27 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.function_call={ - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.function_call = { + "name": prompt_message.tool_calls[0].function.name, + "arguments": prompt_message.tool_calls[0].function.arguments, } return message return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) elif isinstance(prompt_message, ToolPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -166,31 +189,33 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[MinimaxMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[MinimaxMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) elif message.function_call: - if 'name' not in message.function_call or 'arguments' not in message.function_call: + if "name" not in message.function_call or "arguments" not in message.function_call: continue yield LLMResultChunk( @@ -199,15 +224,16 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content='', - tool_calls=[AssistantPromptMessage.ToolCall( - id='', - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.function_call['name'], - arguments=message.function_call['arguments'] + content="", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=message.function_call["name"], arguments=message.function_call["arguments"] + ), ) - )] + ], ), ), ) @@ -217,11 +243,8 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=message.content, tool_calls=[]), + finish_reason=message.stop_reason or None, ), ) @@ -236,22 +259,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index b33a7ca9ac..88ebe5e2e0 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -4,32 +4,27 @@ from typing import Any class MinimaxMessage: class Role(Enum): - USER = 'USER' - ASSISTANT = 'BOT' - SYSTEM = 'SYSTEM' - FUNCTION = 'FUNCTION' + USER = "USER" + ASSISTANT = "BOT" + SYSTEM = "SYSTEM" + FUNCTION = "FUNCTION" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" function_call: dict[str, Any] = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: - return { - 'sender_type': 'BOT', - 'sender_name': '专家', - 'text': '', - 'function_call': self.function_call - } - + return {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": self.function_call} + return { - 'sender_type': self.role, - 'sender_name': '我' if self.role == 'USER' else '专家', - 'text': self.content, + "sender_type": self.role, + "sender_name": "我" if self.role == "USER" else "专家", + "text": self.content, } - - def __init__(self, content: str, role: str = 'USER') -> None: + + def __init__(self, content: str, role: str = "USER") -> None: self.content = content - self.role = role \ No newline at end of file + self.role = role diff --git a/api/core/model_runtime/model_providers/minimax/minimax.py b/api/core/model_runtime/model_providers/minimax/minimax.py index 52f6c2f1d3..5a761903a1 100644 --- a/api/core/model_runtime/model_providers/minimax/minimax.py +++ b/api/core/model_runtime/model_providers/minimax/minimax.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class MinimaxProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class MinimaxProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `abab5.5-chat` model for validate, - model_instance.validate_credentials( - model='abab5.5-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="abab5.5-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise CredentialsValidateFailedError(f'{ex}') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise CredentialsValidateFailedError(f"{ex}") diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 85dc6ef51d..76fd1342bd 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -30,11 +30,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ Model class for Minimax text embedding model. """ - api_base: str = 'https://api.minimax.chat/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.minimax.chat/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -44,54 +45,43 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['minimax_api_key'] - group_id = credentials['minimax_group_id'] - if model != 'embo-01': - raise ValueError('Invalid model name') + api_key = credentials["minimax_api_key"] + group_id = credentials["minimax_group_id"] + if model != "embo-01": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - url = f'{self.api_base}?GroupId={group_id}' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("api_key is required") + url = f"{self.api_base}?GroupId={group_id}" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'embo-01', - 'texts': texts, - 'type': 'db' - } + data = {"model": "embo-01", "texts": texts, "type": "db"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json() # check if there is an error - if resp['base_resp']['status_code'] != 0: - code = resp['base_resp']['status_code'] - msg = resp['base_resp']['status_msg'] + if resp["base_resp"]["status_code"] != 0: + code = resp["base_resp"]["status_code"] + msg = resp["base_resp"]["status_msg"] self._handle_error(code, msg) - embeddings = resp['vectors'] - total_tokens = resp['total_tokens'] + embeddings = resp["vectors"] + total_tokens = resp["total_tokens"] except InvalidAuthenticationError: - raise InvalidAPIKeyError('Invalid api key') + raise InvalidAPIKeyError("Invalid api key") except KeyError as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -119,12 +109,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001: + if code in {1000, 1001}: raise InternalServerError(msg) elif code == 1002: raise RateLimitReachedError(msg) @@ -148,25 +138,17 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -178,10 +160,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -192,7 +171,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/mistralai/llm/llm.py b/api/core/model_runtime/model_providers/mistralai/llm/llm.py index 01ed8010de..da60bd7661 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/llm.py +++ b/api/core/model_runtime/model_providers/mistralai/llm/llm.py @@ -7,14 +7,19 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - + # mistral dose not support user/stop arguments stop = [] user = None @@ -27,5 +32,5 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.mistral.ai/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.mistral.ai/v1" diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.py b/api/core/model_runtime/model_providers/mistralai/mistralai.py index f1d825f6c6..7f9db8da1c 100644 --- a/api/core/model_runtime/model_providers/mistralai/mistralai.py +++ b/api/core/model_runtime/model_providers/mistralai/mistralai.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='open-mistral-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="open-mistral-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index c233596637..3ea46c2967 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,50 +55,50 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 4096)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + credentials["mode"] = "chat" + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["endpoint_url"] = "https://api.moonshot.cn/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ @@ -107,19 +113,13 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -129,14 +129,16 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -162,21 +164,26 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -186,11 +193,12 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -201,12 +209,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -220,9 +223,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -244,9 +247,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -255,21 +258,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -277,19 +280,18 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -305,26 +307,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml index 1078e84c59..59c0915ee9 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml @@ -24,7 +24,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml index 9c739d0501..724f2aa5a2 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml @@ -24,7 +24,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml index 187a86999e..5872295bfa 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml @@ -24,7 +24,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.py b/api/core/model_runtime/model_providers/moonshot/moonshot.py index 5654ae1459..4995e235f5 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.py +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MoonshotProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MoonshotProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='moonshot-v1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="moonshot-v1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/novita/llm/llm.py b/api/core/model_runtime/model_providers/novita/llm/llm.py index 7662bf914a..23367ed1b4 100644 --- a/api/core/model_runtime/model_providers/novita/llm/llm.py +++ b/api/core/model_runtime/model_providers/novita/llm/llm.py @@ -8,20 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - - credentials['endpoint_url'] = "https://api.novita.ai/v3/openai" - credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' } + credentials["endpoint_url"] = "https://api.novita.ai/v3/openai" + credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"} return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + def validate_credentials(self, model: str, credentials: dict) -> None: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) self._add_custom_parameters(credentials, model) @@ -29,21 +34,36 @@ class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' + credentials["mode"] = "chat" - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_customizable_model_schema(model, cred_with_endpoint) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/novita/novita.py b/api/core/model_runtime/model_providers/novita/novita.py index f1b7224605..76a75b01e2 100644 --- a/api/core/model_runtime/model_providers/novita/novita.py +++ b/api/core/model_runtime/model_providers/novita/novita.py @@ -20,12 +20,9 @@ class NovitaProvider(ModelProvider): # Use `meta-llama/llama-3-8b-instruct` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials=credentials - ) + model_instance.validate_credentials(model="meta-llama/llama-3-8b-instruct", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index bc42eaca65..1c98c6be6c 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -21,31 +21,36 @@ from core.model_runtime.utils import helper class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_SUFFIX_MAP = { - 'fuyu-8b': 'vlm/adept/fuyu-8b', - 'mistralai/mistral-large': '', - 'mistralai/mixtral-8x7b-instruct-v0.1': '', - 'mistralai/mixtral-8x22b-instruct-v0.1': '', - 'google/gemma-7b': '', - 'google/codegemma-7b': '', - 'snowflake/arctic':'', - 'meta/llama2-70b': '', - 'meta/llama3-8b-instruct': '', - 'meta/llama3-70b-instruct': '', - 'meta/llama-3.1-8b-instruct': '', - 'meta/llama-3.1-70b-instruct': '', - 'meta/llama-3.1-405b-instruct': '', - 'google/recurrentgemma-2b': '', - 'nvidia/nemotron-4-340b-instruct': '', - 'microsoft/phi-3-medium-128k-instruct':'', - 'microsoft/phi-3-mini-128k-instruct':'' + "fuyu-8b": "vlm/adept/fuyu-8b", + "mistralai/mistral-large": "", + "mistralai/mixtral-8x7b-instruct-v0.1": "", + "mistralai/mixtral-8x22b-instruct-v0.1": "", + "google/gemma-7b": "", + "google/codegemma-7b": "", + "snowflake/arctic": "", + "meta/llama2-70b": "", + "meta/llama3-8b-instruct": "", + "meta/llama3-70b-instruct": "", + "meta/llama-3.1-8b-instruct": "", + "meta/llama-3.1-70b-instruct": "", + "meta/llama-3.1-405b-instruct": "", + "google/recurrentgemma-2b": "", + "nvidia/nemotron-4-340b-instruct": "", + "microsoft/phi-3-medium-128k-instruct": "", + "microsoft/phi-3-mini-128k-instruct": "", } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) prompt_messages = self._transform_prompt_messages(prompt_messages) stop = [] @@ -60,16 +65,14 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): for i, p in enumerate(prompt_messages): if isinstance(p, UserPromptMessage) and isinstance(p.content, list): content = p.content - content_text = '' + content_text = "" for prompt_content in content: if prompt_content.type == PromptMessageContentType.TEXT: content_text += prompt_content.data else: content_text += f' ' - prompt_message = UserPromptMessage( - content=content_text - ) + prompt_message = UserPromptMessage(content=content_text) prompt_messages[i] = prompt_message return prompt_messages @@ -78,91 +81,87 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): self._validate_credentials(model, credentials) def _add_custom_parameters(self, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - - if self.MODEL_SUFFIX_MAP[model]: - credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}' - credentials.pop('endpoint_url') - else: - credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1' + credentials["mode"] = "chat" - credentials['stream_mode_delimiter'] = '\n' + if self.MODEL_SUFFIX_MAP[model]: + credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}" + credentials.pop("endpoint_url") + else: + credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1" + + credentials["stream_mode_delimiter"] = "\n" def _validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard. + Validate model credentials using requests to ensure compatibility with all providers following + OpenAI's API standard. :param model: model name :param credentials: model credentials :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -176,57 +175,51 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: - headers['Authorization'] = f'Bearer {api_key}' + headers["Authorization"] = f"Bearer {api_key}" if stream: - headers['Accept'] = 'text/event-stream' + headers["Accept"] = "text/event-stream" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") - # annotate tools with names, descriptions, etc. - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -240,16 +233,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if not response.ok: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") diff --git a/api/core/model_runtime/model_providers/nvidia/nvidia.py b/api/core/model_runtime/model_providers/nvidia/nvidia.py index e83f8badb5..058fa00346 100644 --- a/api/core/model_runtime/model_providers/nvidia/nvidia.py +++ b/api/core/model_runtime/model_providers/nvidia/nvidia.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='mistralai/mixtral-8x7b-instruct-v0.1', - credentials=credentials - ) + model_instance.validate_credentials(model="mistralai/mixtral-8x7b-instruct-v0.1", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py index 80c24b0555..fabebc67ab 100644 --- a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py @@ -22,11 +22,18 @@ class NvidiaRerankModel(RerankModel): """ def _sigmoid(self, logit: float) -> float: - return 1/(1+exp(-logit)) + return 1 / (1 + exp(-logit)) - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -60,9 +67,9 @@ class NvidiaRerankModel(RerankModel): results = response.json() rerank_documents = [] - for result in results['rankings']: - index = result['index'] - logit = result['logit'] + for result in results["rankings"]: + index = result["index"] + logit = result["logit"] rerank_document = RerankDocument( index=index, text=docs[index], @@ -110,5 +117,5 @@ class NvidiaRerankModel(RerankModel): InvokeServerUnavailableError: [requests.HTTPError], InvokeRateLimitError: [], InvokeAuthorizationError: [requests.HTTPError], - InvokeBadRequestError: [requests.RequestException] + InvokeBadRequestError: [requests.RequestException], } diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index a2adef400d..00cec265d5 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -22,12 +22,13 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Nvidia text embedding model. """ - api_base: str = 'https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings' - models: list[str] = ['NV-Embed-QA'] - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings" + models: list[str] = ["NV-Embed-QA"] + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,32 +38,25 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if model not in self.models: - raise InvokeBadRequestError('Invalid model name') + raise InvokeBadRequestError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': model, - 'input': texts[0], - 'input_type': 'query' - } + data = {"model": model, "input": texts[0], "input_type": "query"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -72,23 +66,21 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -117,30 +109,20 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -152,10 +134,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -166,7 +145,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py index f7b849fbe2..6ff380bdd9 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py @@ -9,4 +9,5 @@ class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel): """ Model class for NVIDIA NIM large language model. """ + pass diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py index 25ab3e8e20..ad890ada22 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class NVIDIANIMProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index 37787c459d..1e1fc5b3ea 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -33,31 +33,29 @@ logger = logging.getLogger(__name__) request_template = { "compartmentId": "", - "servingMode": { - "modelId": "cohere.command-r-plus", - "servingType": "ON_DEMAND" - }, + "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"}, "chatRequest": { "apiFormat": "COHERE", - #"preambleOverride": "You are a helpful assistant.", - #"message": "Hello!", - #"chatHistory": [], + # "preambleOverride": "You are a helpful assistant.", + # "message": "Hello!", + # "chatHistory": [], "maxTokens": 600, "isStream": False, "frequencyPenalty": 0, "presencePenalty": 0, "temperature": 1, - "topP": 0.75 - } + "topP": 0.75, + }, } oci_config_template = { - "user": "", - "fingerprint": "", - "tenancy": "", - "region": "", - "compartment_id": "", - "key_content": "" - } + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + class OCILargeLanguageModel(LargeLanguageModel): # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm @@ -100,11 +98,17 @@ class OCILargeLanguageModel(LargeLanguageModel): return False return feature["system"] - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -118,22 +122,27 @@ class OCILargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - #print("model"+"*"*20) - #print(model) - #print("credentials"+"*"*20) - #print(credentials) - #print("model_parameters"+"*"*20) - #print(model_parameters) - #print("prompt_messages"+"*"*200) - #print(prompt_messages) - #print("tools"+"*"*20) - #print(tools) + # print("model"+"*"*20) + # print(model) + # print("credentials"+"*"*20) + # print(credentials) + # print("model_parameters"+"*"*20) + # print(model_parameters) + # print("prompt_messages"+"*"*200) + # print(prompt_messages) + # print("tools"+"*"*20) + # print(tools) # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -147,8 +156,13 @@ class OCILargeLanguageModel(LargeLanguageModel): return self._get_num_tokens_by_gpt2(prompt) - def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_characters( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -169,10 +183,7 @@ class OCILargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() @@ -192,11 +203,17 @@ class OCILargeLanguageModel(LargeLanguageModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -218,10 +235,13 @@ class OCILargeLanguageModel(LargeLanguageModel): # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat oci_config = copy.deepcopy(oci_config_template) if "oci_config_content" in credentials: - oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8') + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") config_items = oci_config_content.split("/") if len(config_items) != 5: - raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))") + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode(" + "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) oci_config["user"] = config_items[0] oci_config["fingerprint"] = config_items[1] oci_config["tenancy"] = config_items[2] @@ -230,12 +250,12 @@ class OCILargeLanguageModel(LargeLanguageModel): else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") if "oci_key_content" in credentials: - oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8') + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") - #oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) + # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) compartment_id = oci_config["compartment_id"] client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) # call embedding model @@ -245,9 +265,9 @@ class OCILargeLanguageModel(LargeLanguageModel): chat_history = [] system_prompts = [] - #if "meta.llama" in model: + # if "meta.llama" in model: # request_args["chatRequest"]["apiFormat"] = "GENERIC" - request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600) + request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600) request_args["chatRequest"].update(model_parameters) frequency_penalty = model_parameters.get("frequencyPenalty", 0) presence_penalty = model_parameters.get("presencePenalty", 0) @@ -267,7 +287,7 @@ class OCILargeLanguageModel(LargeLanguageModel): if not valid_value: raise InvokeBadRequestError("Does not support function calling") if model.startswith("cohere"): - #print("run cohere " * 10) + # print("run cohere " * 10) for message in prompt_messages[:-1]: text = "" if isinstance(message.content, str): @@ -279,37 +299,37 @@ class OCILargeLanguageModel(LargeLanguageModel): if isinstance(message, SystemPromptMessage): if isinstance(message.content, str): system_prompts.append(message.content) - args = {"apiFormat": "COHERE", - "preambleOverride": ' '.join(system_prompts), - "message": prompt_messages[-1].content, - "chatHistory": chat_history, } + args = { + "apiFormat": "COHERE", + "preambleOverride": " ".join(system_prompts), + "message": prompt_messages[-1].content, + "chatHistory": chat_history, + } request_args["chatRequest"].update(args) elif model.startswith("meta"): - #print("run meta " * 10) + # print("run meta " * 10) meta_messages = [] for message in prompt_messages: text = message.content meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]}) - args = {"apiFormat": "GENERIC", - "messages": meta_messages, - "numGenerations": 1, - "topK": -1} + args = {"apiFormat": "GENERIC", "messages": meta_messages, "numGenerations": 1, "topK": -1} request_args["chatRequest"].update(args) if stream: request_args["chatRequest"]["isStream"] = True - #print("final request" + "|" * 20) - #print(request_args) + # print("final request" + "|" * 20) + # print(request_args) response = client.chat(request_args) - #print(vars(response)) + # print(vars(response)) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -320,9 +340,7 @@ class OCILargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.data.chat_response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.data.chat_response.text) # calculate num tokens prompt_tokens = self.get_num_characters(model, credentials, prompt_messages) @@ -341,8 +359,9 @@ class OCILargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -356,14 +375,12 @@ class OCILargeLanguageModel(LargeLanguageModel): events = response.data.events() for stream in events: chunk = json.loads(stream.data) - #print(chunk) - #chunk: {'apiFormat': 'COHERE', 'text': 'Hello'} + # print(chunk) + # chunk: {'apiFormat': 'COHERE', 'text': 'Hello'} - - - #for chunk in response: - #for part in chunk.parts: - #if part.function_call: + # for chunk in response: + # for part in chunk.parts: + # if part.function_call: # assistant_prompt_message.tool_calls = [ # AssistantPromptMessage.ToolCall( # id=part.function_call.name, @@ -376,9 +393,7 @@ class OCILargeLanguageModel(LargeLanguageModel): # ] if "finishReason" not in chunk: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if model.startswith("cohere"): if chunk["text"]: assistant_prompt_message.content += chunk["text"] @@ -389,10 +404,7 @@ class OCILargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: # calculate num tokens @@ -409,8 +421,8 @@ class OCILargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=str(chunk["finishReason"]), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -425,17 +437,13 @@ class OCILargeLanguageModel(LargeLanguageModel): content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") @@ -457,5 +465,5 @@ class OCILargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/oci/oci.py b/api/core/model_runtime/model_providers/oci/oci.py index 11d67790a0..e182d2d043 100644 --- a/api/core/model_runtime/model_providers/oci/oci.py +++ b/api/core/model_runtime/model_providers/oci/oci.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class OCIGENAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,14 +20,9 @@ class OCIGENAIProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `cohere.command-r-plus` model for validate, - model_instance.validate_credentials( - model='cohere.command-r-plus', - credentials=credentials - ) + model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex - - diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 5e0a85583e..80ad2be9f5 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -21,29 +21,28 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE request_template = { "compartmentId": "", - "servingMode": { - "modelId": "cohere.embed-english-light-v3.0", - "servingType": "ON_DEMAND" - }, + "servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"}, "truncate": "NONE", - "inputs": [""] + "inputs": [""], } oci_config_template = { - "user": "", - "fingerprint": "", - "tenancy": "", - "region": "", - "compartment_id": "", - "key_content": "" - } + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + + class OCITextEmbeddingModel(TextEmbeddingModel): """ Model class for Cohere text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -62,14 +61,13 @@ class OCITextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: cutoff = int(len(text) * (np.floor(context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -80,26 +78,16 @@ class OCITextEmbeddingModel(TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=inputs[i: i + max_chunks] + model=model, credentials=credentials, texts=inputs[i : i + max_chunks] ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -125,6 +113,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): for text in texts: characters += len(text) return characters + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials @@ -135,11 +124,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -157,10 +142,13 @@ class OCITextEmbeddingModel(TextEmbeddingModel): # initialize client oci_config = copy.deepcopy(oci_config_template) if "oci_config_content" in credentials: - oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8') + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") config_items = oci_config_content.split("/") if len(config_items) != 5: - raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))") + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode(" + "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) oci_config["user"] = config_items[0] oci_config["fingerprint"] = config_items[1] oci_config["tenancy"] = config_items[2] @@ -169,7 +157,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") if "oci_key_content" in credentials: - oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8') + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") @@ -195,10 +183,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -209,7 +194,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -224,19 +209,9 @@ class OCITextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 42a588e3dd..1ed77a2ee8 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -121,9 +121,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): text = "" for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data break return self._get_num_tokens_by_gpt2(text) @@ -145,13 +143,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): stream=False, ) except InvokeError as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {ex.description}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {str(ex)}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def _generate( self, @@ -201,9 +195,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if completion_type is LLMMode.CHAT: endpoint_url = urljoin(endpoint_url, "api/chat") - data["messages"] = [ - self._convert_prompt_message_to_dict(m) for m in prompt_messages - ] + data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] else: endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] @@ -216,14 +208,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel): images = [] for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) + message_content = cast(ImagePromptMessageContent, message_content) image_data = re.sub( r"^data:image\/[a-zA-Z]+;base64,", "", @@ -235,24 +223,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel): data["images"] = images # send a post request to validate the credentials - response = requests.post( - endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError( - f"API request failed with status code {response.status_code}: {response.text}" - ) + raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") if stream: - return self._handle_generate_stream_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) - return self._handle_generate_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) def _handle_generate_response( self, @@ -292,9 +272,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) # transform response result = LLMResult( @@ -335,9 +313,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) return LLMResultChunk( model=model, @@ -394,15 +370,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = chunk_json["eval_count"] else: # calculate num tokens - prompt_tokens = self._get_num_tokens_by_gpt2( - prompt_messages[0].content - ) + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( model=chunk_json["model"], @@ -439,17 +411,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): images = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) - image_data = re.sub( - r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) images.append(image_data) message_dict = {"role": "user", "content": text, "images": images} @@ -479,9 +445,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return num_tokens - def get_customizable_model_schema( - self, model: str, credentials: dict - ) -> AIModelEntity: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ Get customizable model schema. @@ -502,9 +466,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get("context_size", 4096) - ), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), }, parameter_rules=[ ParameterRule( @@ -568,9 +530,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): en_US="Maximum number of tokens to predict when generating text. " "(Default: 128, -1 = infinite generation, -2 = fill context)" ), - default=( - 512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128 - ), + default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128), min=-2, max=int(credentials.get("max_tokens", 4096)), ), @@ -612,22 +572,23 @@ class OllamaLargeLanguageModel(LargeLanguageModel): label=I18nObject(en_US="Size of context window"), type=ParameterType.INT, help=I18nObject( - en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)" + en_US="Sets the size of the context window used to generate the next token. (Default: 2048)" ), default=2048, min=1, ), ParameterRule( - name='num_gpu', + name="num_gpu", label=I18nObject(en_US="GPU Layers"), type=ParameterType.INT, - help=I18nObject(en_US="The number of layers to offload to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable." - "As long as a model fits into one gpu it stays in one. " - "It does not set the number of GPU(s). "), + help=I18nObject( + en_US="The number of layers to offload to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable." + "As long as a model fits into one gpu it stays in one. " + "It does not set the number of GPU(s). " + ), min=-1, - default=1 + default=1, ), ParameterRule( name="num_thread", @@ -678,9 +639,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel): type=ParameterType.STRING, help=I18nObject( en_US="Sets how long the model is kept in memory after generating a response. " - "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). " - "A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. " - "Valid time units are 's','m','h'. (Default: 5m)" + "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours)." + " A negative number keeps the model loaded indefinitely, and '0' unloads the model" + " immediately after generating a response." + " Valid time units are 's','m','h'. (Default: 5m)" ), ), ParameterRule( @@ -688,8 +650,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): label=I18nObject(en_US="Format"), type=ParameterType.STRING, help=I18nObject( - en_US="the format to return a response in." - " Currently the only accepted value is json." + en_US="the format to return a response in. Currently the only accepted value is json." ), options=["json"], ), diff --git a/api/core/model_runtime/model_providers/ollama/ollama.py b/api/core/model_runtime/model_providers/ollama/ollama.py index f8a17b98a0..115280193a 100644 --- a/api/core/model_runtime/model_providers/ollama/ollama.py +++ b/api/core/model_runtime/model_providers/ollama/ollama.py @@ -6,7 +6,6 @@ logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 8f7d54c516..b4c61d8a6d 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -37,9 +37,9 @@ class OllamaEmbeddingModel(TextEmbeddingModel): Model class for an Ollama text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,15 +51,13 @@ class OllamaEmbeddingModel(TextEmbeddingModel): """ # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials.get('base_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("base_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'api/embed') + endpoint_url = urljoin(endpoint_url, "api/embed") # get model properties context_size = self._get_context_size(model, credentials) @@ -67,53 +65,36 @@ class OllamaEmbeddingModel(TextEmbeddingModel): inputs = [] used_tokens = 0 - for i, text in enumerate(texts): + for text in texts: # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) # Prepare the payload for the request - payload = { - 'input': inputs, - 'model': model, - } + payload = {"input": inputs, "model": model, "options": {"use_mmap": True}} - # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300), - options={"use_mmap": "true"} - ) + # Make the request to the Ollama API + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings = response_data['embeddings'] + embeddings = response_data["embeddings"] embedding_used_tokens = self.get_num_tokens(model, credentials, inputs) used_tokens += embedding_used_tokens # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -135,19 +116,15 @@ class OllamaEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -155,15 +132,15 @@ class OllamaEmbeddingModel(TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -179,10 +156,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -193,7 +167,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -221,10 +195,10 @@ class OllamaEmbeddingModel(TextEmbeddingModel): ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 467a51daf2..2181bb4f08 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -22,7 +22,7 @@ class _CommonOpenAI: :return: """ credentials_kwargs = { - "api_key": credentials['openai_api_key'], + "api_key": credentials["openai_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } @@ -31,8 +31,8 @@ class _CommonOpenAI: openai_api_base = credentials["openai_api_base"].rstrip("/") credentials_kwargs["base_url"] = openai_api_base + "/v1" - if 'openai_organization' in credentials: - credentials_kwargs['organization'] = credentials['openai_organization'] + if "openai_organization" in credentials: + credentials_kwargs["organization"] = credentials["openai_organization"] return credentials_kwargs diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index ac7313aaa1..7501bc1164 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -5,6 +5,10 @@ - chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 +- o1-preview +- o1-preview-2024-09-12 +- o1-mini +- o1-mini-2024-09-12 - gpt-4-turbo - gpt-4-turbo-2024-04-09 - gpt-4-turbo-preview diff --git a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml index 98e236650c..b47449a49a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml index c1602b2efc..ffa725ec40 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml @@ -27,7 +27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml index 56ab965c39..21150fc3a6 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml @@ -27,7 +27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml index 6eb15e6c0d..d3a8ee535a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml @@ -27,7 +27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml index 007cfed0f3..ac4ec5840b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml index f4fa6317af..d775239770 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml index f92173ccfd..8358425e6d 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml index 6b36361efe..0234499164 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml @@ -41,7 +41,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml index c0350ae2c6..8d29cf0c04 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml index 575acb7fa2..b25ff6a812 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml @@ -41,7 +41,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml index a63b608423..07037c6643 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml @@ -38,7 +38,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml index a7a5bf3c86..f7b5138b7d 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml index f0d835cba2..b630d6f630 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml index 7e430c51a7..73b7f69700 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml index 03e28772e6..df38270f79 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml index 23dcf85085..5e3c94fbe2 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml index 4f141f772f..3090a9e090 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index dc85f7c9f2..d42fce528a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -37,18 +37,25 @@ if you are not sure about the structure. {{instructions}} -""" +""" # noqa: E501 + class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ Model class for OpenAI large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -64,8 +71,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -80,7 +87,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -91,26 +98,34 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) # transform response format - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model @@ -123,7 +138,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) else: self._transform_completion_json_prompts( @@ -135,9 +150,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -147,14 +162,21 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -167,25 +189,35 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - - def _transform_completion_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + + def _transform_completion_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -202,25 +234,30 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): break if user_message: - if prompt_messages[i].content[-11:] == 'Assistant: ': + if prompt_messages[i].content[-11:] == "Assistant: ": # now we are in the chat app, remove the last assistant message prompt_messages[i].content = prompt_messages[i].content[:-11] prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"Assistant:\n```{response_format}\n" else: prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"\n```{response_format}\n" - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -231,8 +268,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :return: """ # handle fine tune remote models - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] else: base_model = model @@ -262,14 +299,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # handle fine tune remote models base_model = model # fine-tuned model name likes ft:gpt-3.5-turbo-0613:personal::xxxxx - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # check if model exists remote_models = self.remote_models(credentials) remote_model_map = {model.model: model for model in remote_models} if model not in remote_model_map: - raise CredentialsValidateFailedError(f'Fine-tuned model {model} not found') + raise CredentialsValidateFailedError(f"Fine-tuned model {model} not found") # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -277,7 +314,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if model_mode == LLMMode.CHAT: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -286,7 +323,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -313,11 +350,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # get all remote models remote_models = client.models.list() - fine_tune_models = [model for model in remote_models if model.id.startswith('ft:')] + fine_tune_models = [model for model in remote_models if model.id.startswith("ft:")] ai_model_entities = [] for model in fine_tune_models: - base_model = model.id.split(':')[1] + base_model = model.id.split(":")[1] base_model_schema = None for predefined_model_name, predefined_model in predefined_models_map.items(): @@ -329,30 +366,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ai_model_entity = AIModelEntity( model=model.id, - label=I18nObject( - zh_Hans=model.id, - en_US=model.id - ), + label=I18nObject(zh_Hans=model.id, en_US=model.id), model_type=ModelType.LLM, features=base_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=base_model_schema.model_properties, parameter_rules=base_model_schema.parameter_rules, - pricing=PriceConfig( - input=0.003, - output=0.006, - unit=0.001, - currency='USD' - ) + pricing=PriceConfig(input=0.003, output=0.006, unit=0.001, currency="USD"), ) ai_model_entities.append(ai_model_entity) return ai_model_entities - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -374,23 +410,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - "include_usage": True - } - + extra_model_kwargs["stream_options"] = {"include_usage": True} + # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -398,8 +428,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm completion response @@ -412,9 +443,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -440,8 +469,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm completion stream response @@ -451,7 +481,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" prompt_tokens = 0 completion_tokens = 0 @@ -460,8 +490,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -474,14 +504,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text or "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -494,7 +522,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -504,7 +532,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -520,10 +548,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): yield final_chunk - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -562,26 +597,34 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if tools: # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - extra_model_kwargs['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - 'include_usage': True - } + extra_model_kwargs["stream_options"] = {"include_usage": True} # clear illegal prompt messages prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) + block_as_stream = False + if model.startswith("o1"): + if stream: + block_as_stream = True + stream = False + + if "stream_options" in extra_model_kwargs: + del extra_model_kwargs["stream_options"] + + if "stop" in extra_model_kwargs: + del extra_model_kwargs["stop"] + # chat model response = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], @@ -594,11 +637,56 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + if block_as_stream: + return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) + + return block_result + + def _handle_chat_block_as_stream_response( + self, + block_result: LLMResult, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[LLMResultChunk, None, None]: + """ + Handle llm chat response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :param stop: stop words + :return: llm response chunk generator + """ + text = block_result.message.content + text = cast(str, text) + + if stop: + text = self.enforce_stop_tokens(text, stop) + + yield LLMResultChunk( + model=block_result.model, + prompt_messages=prompt_messages, + system_fingerprint=block_result.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + finish_reason="stop", + usage=block_result.usage, + ), + ) + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -619,10 +707,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -648,9 +733,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -660,7 +750,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None prompt_tokens = 0 completion_tokens = 0 @@ -670,8 +760,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -685,8 +775,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -708,7 +801,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -719,12 +812,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): final_tool_calls.extend(tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -735,7 +825,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -745,7 +835,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -753,8 +843,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -764,9 +853,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -777,21 +866,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -801,14 +888,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -821,7 +905,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: cleaned prompt messages """ - checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09'] + checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] if model in checklist: # count how many user messages are there @@ -830,11 +914,30 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - item.data if item.type == PromptMessageContentType.TEXT else - '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else '' - for item in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + item.data + if item.type == PromptMessageContentType.TEXT + else "[IMAGE]" + if item.type == PromptMessageContentType.IMAGE + else "" + for item in prompt_message.content + ] + ) + + if model.startswith("o1"): + system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) + if system_message_count > 0: + new_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + prompt_message = UserPromptMessage( + content=prompt_message.content, + name=prompt_message.name, + ) + + new_prompt_messages.append(prompt_message) + prompt_messages = new_prompt_messages return prompt_messages @@ -851,19 +954,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -889,11 +986,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -902,8 +995,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -924,16 +1016,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if model.startswith('ft:'): - model = model.split(':')[1] + if model.startswith("ft:"): + model = model.split(":")[1] # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. - if model == "chatgpt-4o-latest": + if model == "chatgpt-4o-latest" or model.startswith("o1"): model = "gpt-4o" try: @@ -948,7 +1041,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tokens_per_message = 4 # if there's a name, the role is omitted tokens_per_name = -1 - elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): + elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"): tokens_per_message = 3 tokens_per_name = 1 else: @@ -969,10 +1062,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -1011,37 +1104,37 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode("type")) num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) @@ -1049,26 +1142,26 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - OpenAI supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + OpenAI supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ - if not model.startswith('ft:'): + if not model.startswith("ft:"): base_model = model else: # get base_model - base_model = model.split(':')[1] + base_model = model.split(":")[1] # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} if base_model not in model_map: - raise ValueError(f'Base model {base_model} not found') - + raise ValueError(f"Base model {base_model} not found") + base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] @@ -1077,16 +1170,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml new file mode 100644 index 0000000000..07a3bc9a7a --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml @@ -0,0 +1,33 @@ +model: o1-mini-2024-09-12 +label: + zh_Hans: o1-mini-2024-09-12 + en_US: o1-mini-2024-09-12 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 65563 + min: 1 + max: 65563 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '3.00' + output: '12.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml new file mode 100644 index 0000000000..3e83529201 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml @@ -0,0 +1,33 @@ +model: o1-mini +label: + zh_Hans: o1-mini + en_US: o1-mini +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 65563 + min: 1 + max: 65563 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '3.00' + output: '12.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-preview-2024-09-12.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-preview-2024-09-12.yaml new file mode 100644 index 0000000000..c9da96f611 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-preview-2024-09-12.yaml @@ -0,0 +1,33 @@ +model: o1-preview-2024-09-12 +label: + zh_Hans: o1-preview-2024-09-12 + en_US: o1-preview-2024-09-12 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 32768 + min: 1 + max: 32768 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '15.00' + output: '60.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-preview.yaml new file mode 100644 index 0000000000..c83874b765 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-preview.yaml @@ -0,0 +1,33 @@ +model: o1-preview +label: + zh_Hans: o1-preview + en_US: o1-preview +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 32768 + min: 1 + max: 32768 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '15.00' + output: '60.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index b1d0e57ad2..619044d808 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -14,9 +14,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): Model class for OpenAI text moderation model. """ - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -34,10 +32,10 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): # chars per chunk length = self._get_max_characters_per_chunk(model, credentials) - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] max_text_chunks = self._get_max_chunks(model, credentials) - chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] + chunks = [text_chunks[i : i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] for text_chunk in chunks: moderation_result = self._moderation_invoke(model=model, client=client, texts=text_chunk) @@ -65,7 +63,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): self._moderation_invoke( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index 66efd4797f..175d7db73c 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -9,7 +9,6 @@ logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: """ Validate provider credentials @@ -22,12 +21,9 @@ class OpenAIProvider(ModelProvider): # Use `gpt-3.5-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index efbdd054f9..18f97e45f3 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -12,9 +12,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -37,7 +35,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index e23a2edf87..535d8388bc 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -18,9 +18,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): Model class for OpenAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +37,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" # get model properties context_size = self._get_context_size(model, credentials) @@ -56,11 +56,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -69,10 +67,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,10 +83,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -101,17 +93,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -152,17 +136,13 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model @@ -179,10 +159,12 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens @@ -197,10 +179,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -211,7 +190,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index afa5d4b88a..a14c91639b 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -14,8 +14,9 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -28,14 +29,12 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) # if streaming: - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -50,14 +49,13 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -71,31 +69,38 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] - for index, future in enumerate(futures): - yield from future.result().__enter__().iter_bytes(1024) + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] + for future in futures: + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 51950ca377..1234e44f80 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,4 +1,3 @@ - import requests from core.model_runtime.errors.invoke import ( @@ -11,7 +10,7 @@ from core.model_runtime.errors.invoke import ( ) -class _CommonOAI_API_Compat: +class _CommonOaiApiCompat: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -35,10 +34,10 @@ class _CommonOAI_API_Compat: ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] - } \ No newline at end of file + requests.exceptions.ReadTimeout, # Timeout + ], + } diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 6279125f46..5a8a754f72 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -35,22 +35,28 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat from core.model_runtime.utils import helper logger = logging.getLogger(__name__) -class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): +class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): """ Model class for OpenAI large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -77,8 +83,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -92,100 +103,93 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard. + Validate model credentials using requests to ensure compatibility with all providers following + OpenAI's API standard. :param model: model name :param credentials: model credentials :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials['endpoint_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["endpoint_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - endpoint_url = urljoin(endpoint_url, 'chat/completions') + endpoint_url = urljoin(endpoint_url, "chat/completions") elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - endpoint_url = urljoin(endpoint_url, 'completions') + data["prompt"] = "ping" + endpoint_url = urljoin(endpoint_url, "completions") else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if (completion_type is LLMMode.CHAT and json_result.get('object','') == ''): - json_result['object'] = 'chat.completion' - elif (completion_type is LLMMode.COMPLETION and json_result.get('object','') == ''): - json_result['object'] = 'text_completion' + if completion_type is LLMMode.CHAT and json_result.get("object", "") == "": + json_result["object"] = "chat.completion" + elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "": + json_result["object"] = "text_completion" - if (completion_type is LLMMode.CHAT - and ('object' not in json_result or json_result['object'] != 'chat.completion')): + if completion_type is LLMMode.CHAT and ( + "object" not in json_result or json_result["object"] != "chat.completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'chat.completion\'') - elif (completion_type is LLMMode.COMPLETION - and ('object' not in json_result or json_result['object'] != 'text_completion')): + "Credentials validation failed: invalid response object, must be 'chat.completion'" + ) + elif completion_type is LLMMode.COMPLETION and ( + "object" not in json_result or json_result["object"] != "text_completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'text_completion\'') + "Credentials validation failed: invalid response object, must be 'text_completion'" + ) except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ features = [] - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type in ['function_call']: + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "function_call": features.append(ModelFeature.TOOL_CALL) - elif function_calling_type in ['tool_call']: + elif function_calling_type == "tool_call": features.append(ModelFeature.MULTI_TOOL_CALL) - stream_function_calling = credentials.get('stream_function_calling', 'supported') - if stream_function_calling == 'supported': + stream_function_calling = credentials.get("stream_function_calling", "supported") + if stream_function_calling == "supported": features.append(ModelFeature.STREAM_TOOL_CALL) - vision_support = credentials.get('vision_support', 'not_support') - if vision_support == 'support': + vision_support = credentials.get("vision_support", "not_support") + if vision_support == "support": features.append(ModelFeature.VISION) entity = AIModelEntity( @@ -195,43 +199,43 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), - ModelPropertyKey.MODE: credentials.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")), + ModelPropertyKey.MODE: credentials.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(credentials.get('temperature', 0.7)), + default=float(credentials.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(credentials.get('top_p', 1)), + default=float(credentials.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -239,31 +243,39 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): type=ParameterType.INT, default=512, min=1, - max=int(credentials.get('max_tokens_to_sample', 4096)), - ) + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), ], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), ) - if credentials['mode'] == 'chat': + if credentials["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif credentials['mode'] == 'completion': + elif credentials["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") return entity - # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + # validate_credentials method has been rewritten to use the requests library for compatibility with all providers + # following OpenAI's API standard. + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -277,52 +289,47 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - extra_headers = credentials.get('extra_headers') + extra_headers = credentials.get("extra_headers") if extra_headers is not None: headers = { - **headers, - **extra_headers, + **headers, + **extra_headers, } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + if not endpoint_url.endswith("/"): + endpoint_url += "/" - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'chat/completions') - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "chat/completions") + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - endpoint_url = urljoin(endpoint_url, 'completions') - data['prompt'] = prompt_messages[0].content + endpoint_url = urljoin(endpoint_url, "completions") + data["prompt"] = prompt_messages[0].content else: raise ValueError("Unsupported completion type for model configuration.") # annotate tools with names, descriptions, etc. - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -336,16 +343,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if response.status_code != 200: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") @@ -355,8 +356,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -366,11 +368,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -381,16 +384,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) # delimiter for stream response, need unicode_escape import codecs + delimiter = credentials.get("stream_mode_delimiter", "\n\n") delimiter = codecs.decode(delimiter, "unicode_escape") @@ -406,10 +405,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_call = AssistantPromptMessage.ToolCall( id=tool_call_id, type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name="", - arguments="" - ) + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), ) tools_calls.append(tool_call) @@ -434,10 +430,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): chunk = chunk.strip() if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() - if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]" + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" continue try: @@ -447,30 +443,31 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") assistant_message_tool_calls = None - if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': - assistant_message_tool_calls = delta.get('tool_calls', None) - elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': - assistant_message_tool_calls = [{ - 'id': 'tool_call_id', - 'type': 'function', - 'function': delta.get('function_call', {}) - }] + if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call": + assistant_message_tool_calls = delta.get("tool_calls", None) + elif ( + "function_call" in delta + and credentials.get("function_calling_type", "no_call") == "function_call" + ): + assistant_message_tool_calls = [ + {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})} + ] # assistant_message_function_call = delta.delta.function_call @@ -479,7 +476,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message @@ -490,9 +487,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # reset tool calls tool_calls = [] full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -507,7 +504,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 @@ -518,47 +515,42 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason ) - def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> LLMResult: - + def _handle_generate_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> LLMResult: response_json = response.json() - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) - output = response_json['choices'][0] + output = response_json["choices"][0] - response_content = '' + response_content = "" tool_calls = None - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") if completion_type is LLMMode.CHAT: - response_content = output.get('message', {})['content'] - if function_calling_type == 'tool_call': - tool_calls = output.get('message', {}).get('tool_calls') - elif function_calling_type == 'function_call': - tool_calls = output.get('message', {}).get('function_call') + response_content = output.get("message", {})["content"] + if function_calling_type == "tool_call": + tool_calls = output.get("message", {}).get("tool_calls") + elif function_calling_type == "function_call": + tool_calls = output.get("message", {}).get("function_call") elif completion_type is LLMMode.COMPLETION: - response_content = output['text'] + response_content = output["text"] assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) if tool_calls: - if function_calling_type == 'tool_call': + if function_calling_type == "tool_call": assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) - elif function_calling_type == 'function_call': + elif function_calling_type == "function_call": assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] usage = response_json.get("usage") @@ -597,19 +589,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -618,11 +604,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] - elif function_calling_type == 'function_call': + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] + elif function_calling_type == "function_call": function_call = message.tool_calls[0] message_dict["function_call"] = { "name": function_call.function.name, @@ -633,19 +618,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } - elif function_calling_type == 'function_call': - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} + elif function_calling_type == "function_call": + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -654,8 +631,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Approximate num tokens for model with gpt2 tokenizer. @@ -667,7 +645,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if isinstance(text, str): full_text = text else: - full_text = '' + full_text = "" for message_content in text: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) @@ -680,8 +658,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int: + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + credentials: dict = None, + ) -> int: """ Approximate num tokens with GPT2 tokenizer. """ @@ -700,10 +683,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -741,46 +724,44 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += self._get_num_tokens_by_gpt2('type') - num_tokens += self._get_num_tokens_by_gpt2('function') - num_tokens += self._get_num_tokens_by_gpt2('function') + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") # calculate num tokens for function object - num_tokens += self._get_num_tokens_by_gpt2('name') + num_tokens += self._get_num_tokens_by_gpt2("name") num_tokens += self._get_num_tokens_by_gpt2(tool.name) - num_tokens += self._get_num_tokens_by_gpt2('description') + num_tokens += self._get_num_tokens_by_gpt2("description") num_tokens += self._get_num_tokens_by_gpt2(tool.description) parameters = tool.parameters - num_tokens += self._get_num_tokens_by_gpt2('parameters') - if 'title' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('title') + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) - num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2("type") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) - if 'properties' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): num_tokens += self._get_num_tokens_by_gpt2(key) for field_key, field_value in value.items(): num_tokens += self._get_num_tokens_by_gpt2(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(enum_field) else: num_tokens += self._get_num_tokens_by_gpt2(field_key) num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) - if 'required' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -792,20 +773,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.get("function", {}).get("name", ""), - arguments=response_tool_call.get("function", {}).get("arguments", "") + arguments=response_tool_call.get("function", {}).get("arguments", ""), ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get("id", ""), - type=response_tool_call.get("type", ""), - function=function + id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -815,14 +793,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.get('name', ''), - arguments=response_function_call.get('arguments', '') + name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.get('id', ''), - type="function", - function=function + id=response_function_call.get("id", ""), type="function", function=function ) return tool_call diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 3445ebbaf7..ca6f185287 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class OAICompatProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 00702ba936..405096578c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -6,17 +6,15 @@ import requests from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): +class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel): """ Model class for OpenAI Compatible Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 363054b084..e83cfdf873 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -19,17 +19,17 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,27 +39,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - - # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } - api_key = credentials.get('api_key') + # Prepare headers and payload for the request + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -70,7 +68,6 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -78,7 +75,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -88,42 +85,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -145,45 +125,35 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -191,7 +161,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -199,20 +169,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -224,10 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -238,7 +204,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 8ea5819bde..34b4de7962 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -38,88 +38,115 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo class OpenLLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials for Baichuan model """ - if not credentials.get('server_url'): - raise CredentialsValidateFailedError('Invalid server URL') + if not credentials.get("server_url"): + raise CredentialsValidateFailedError("Invalid server URL") # ping instance = OpenLLMGenerate() try: instance.generate( - server_url=credentials['server_url'], - model_name=model, - prompt_messages=[ - OpenLLMGenerateMessage(content='ping\nAnswer: ', role='user') - ], + server_url=credentials["server_url"], + model_name=model, + prompt_messages=[OpenLLMGenerateMessage(content="ping\nAnswer: ", role="user")], model_parameters={ - 'max_tokens': 64, - 'temperature': 0.8, - 'top_p': 0.9, - 'top_k': 15, + "max_tokens": 64, + "temperature": 0.8, + "top_p": 0.9, + "top_k": 15, }, stream=False, - user='', + user="", stop=[], ) except InvalidAuthenticationError as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for OpenLLM model - it's a generate model, so we just join them by spe + Calculate num tokens for OpenLLM model + it's a generate model, so we just join them by spe """ - messages = ','.join([message.content for message in messages]) + messages = ",".join([message.content for message in messages]) return self._get_num_tokens_by_gpt2(messages) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( model_name=model, - server_url=credentials['server_url'], + server_url=credentials["server_url"], prompt_messages=[self._convert_prompt_message_to_openllm_message(message) for message in prompt_messages], model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_openllm_message(self, prompt_message: PromptMessage) -> OpenLLMGenerateMessage: """ - convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface + convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface """ if isinstance(prompt_message, UserPromptMessage): return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): - return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content) + return OpenLLMGenerateMessage( + role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content + ) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -130,27 +157,29 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[OpenLLMGenerateMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[OpenLLMGenerateMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) else: @@ -159,73 +188,55 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=message.content, tool_calls=[]), + finish_reason=message.stop_reason or None, ), ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', + name="top_k", type=ParameterType.INT, - use_template='top_k', + use_template="top_k", min=1, default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + label=I18nObject(zh_Hans="Top K", en_US="Top K"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ + model_properties={ ModelPropertyKey.MODE: LLMMode.COMPLETION.value, }, - parameter_rules=rules + parameter_rules=rules, ) return entity @@ -241,22 +252,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 1c3f084207..351dcced15 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -15,32 +15,38 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo class OpenLLMGenerateMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, + self, + server_url: str, + model_name: str, + stream: bool, + model_parameters: dict[str, Any], + stop: list[str], + prompt_messages: list[OpenLLMGenerateMessage], + user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: - raise InvalidAuthenticationError('Invalid server URL') + raise InvalidAuthenticationError("Invalid server URL") default_llm_config = { "max_new_tokens": 128, @@ -72,40 +78,37 @@ class OpenLLMGenerate: "frequency_penalty": 0, "use_beam_search": False, "ignore_eos": False, - "skip_special_tokens": True + "skip_special_tokens": True, } - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - default_llm_config['max_new_tokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + default_llm_config["max_new_tokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - default_llm_config['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + default_llm_config["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - default_llm_config['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + default_llm_config["top_p"] = model_parameters["top_p"] - if 'top_k' in model_parameters and type(model_parameters['top_k']) == int: - default_llm_config['top_k'] = model_parameters['top_k'] + if "top_k" in model_parameters and type(model_parameters["top_k"]) == int: + default_llm_config["top_k"] = model_parameters["top_k"] - if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool: - default_llm_config['use_cache'] = model_parameters['use_cache'] + if "use_cache" in model_parameters and type(model_parameters["use_cache"]) == bool: + default_llm_config["use_cache"] = model_parameters["use_cache"] - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + headers = {"Content-Type": "application/json", "accept": "application/json"} if stream: - url = f'{server_url}/v1/generate_stream' + url = f"{server_url}/v1/generate_stream" timeout = 10 else: - url = f'{server_url}/v1/generate' + url = f"{server_url}/v1/generate" timeout = 120 data = { - 'stop': stop if stop else [], - 'prompt': '\n'.join([message.content for message in prompt_messages]), - 'llm_config': default_llm_config, + "stop": stop or [], + "prompt": "\n".join([message.content for message in prompt_messages]), + "llm_config": default_llm_config, } try: @@ -113,10 +116,10 @@ class OpenLLMGenerate: except (ConnectionError, InvalidSchema, MissingSchema) as e: # cloud not connect to the server raise InvalidAuthenticationError(f"Invalid server URL: {e}") - + if not response.ok: resp = response.json() - msg = resp['msg'] + msg = resp["msg"] if response.status_code == 400: raise BadRequestError(msg) elif response.status_code == 404: @@ -125,69 +128,71 @@ class OpenLLMGenerate: raise InternalServerError(msg) else: raise InternalServerError(msg) - + if stream: return self._handle_chat_stream_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_chat_generate_response(self, response: Response) -> OpenLLMGenerateMessage: try: data = response.json() except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - message = data['outputs'][0] - text = message['text'] - token_ids = message['token_ids'] - prompt_token_ids = data['prompt_token_ids'] - stop_reason = message['finish_reason'] + message = data["outputs"][0] + text = message["text"] + token_ids = message["token_ids"] + prompt_token_ids = data["prompt_token_ids"] + stop_reason = message["finish_reason"] message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) message.stop_reason = stop_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': len(token_ids), - 'total_tokens': len(prompt_token_ids) + len(token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(token_ids), + "total_tokens": len(prompt_token_ids) + len(token_ids), } return message - def _handle_chat_stream_generate_response(self, response: Response) -> Generator[OpenLLMGenerateMessage, None, None]: + def _handle_chat_stream_generate_response( + self, response: Response + ) -> Generator[OpenLLMGenerateMessage, None, None]: completion_usage = 0 for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() - if line == '[DONE]': + if line == "[DONE]": return try: data = loads(line) except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") - - output = data['outputs'] + + output = data["outputs"] for choice in output: - text = choice['text'] - token_ids = choice['token_ids'] + text = choice["text"] + token_ids = choice["token_ids"] completion_usage += len(token_ids) message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) - if choice.get('finish_reason'): - finish_reason = choice['finish_reason'] - prompt_token_ids = data['prompt_token_ids'] + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + prompt_token_ids = data["prompt_token_ids"] message.stop_reason = finish_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': completion_usage, - 'total_tokens': completion_usage + len(prompt_token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": completion_usage, + "total_tokens": completion_usage + len(prompt_token_ids), } - - yield message \ No newline at end of file + + yield message diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py index d9d279e6ca..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 4dbd0678e7..00e583cc79 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ Model class for OpenLLM text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -35,16 +36,13 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] if not server_url: - raise CredentialsValidateFailedError('server_url is required') - - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + raise CredentialsValidateFailedError("server_url is required") - url = f'{server_url}/v1/embeddings' + headers = {"Content-Type": "application/json", "accept": "application/json"} + + url = f"{server_url}/v1/embeddings" data = texts try: @@ -54,7 +52,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(f"Invalid server URL: {e}") except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: if response.status_code == 400: raise InvokeBadRequestError(response.text) @@ -62,21 +60,17 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(response.text) elif response.status_code == 500: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json()[0] - embeddings = resp['embeddings'] - total_tokens = resp['num_tokens'] + embeddings = resp["embeddings"] + total_tokens = resp["num_tokens"] except KeyError as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -104,9 +98,9 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid server_url') + raise CredentialsValidateFailedError("Invalid server_url") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -119,23 +113,13 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -147,10 +131,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -161,7 +142,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml index 1737c50bb1..186c1cc663 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml @@ -26,7 +26,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml index 2d55cf8565..8c2989b300 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml @@ -41,7 +41,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml index 12015f6f64..ef19d4f6f0 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml @@ -41,7 +41,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml index cf2de0f73a..0be325f55b 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml index de0bad4136..3b1d95643d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml @@ -27,7 +27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml index 6945402c72..a8c97efdd6 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml @@ -27,7 +27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index e78ac4caf1..b6bb249a04 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -8,18 +8,22 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_credential(self, model: str, credentials: dict): - credentials['endpoint_url'] = "https://openrouter.ai/api/v1" - credentials['mode'] = self.get_model_mode(model).value - credentials['function_calling_type'] = 'tool_call' - return + credentials["endpoint_url"] = "https://openrouter.ai/api/v1" + credentials["mode"] = self.get_model_mode(model).value + credentials["function_calling_type"] = "tool_call" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -29,9 +33,17 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, credentials) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,8 +53,13 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().get_customizable_model_schema(model, credentials) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: self._update_credential(model, credentials) return super().get_num_tokens(model, credentials, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.py b/api/core/model_runtime/model_providers/openrouter/openrouter.py index 613f71deb1..2e59ab5059 100644 --- a/api/core/model_runtime/model_providers/openrouter/openrouter.py +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.py @@ -8,17 +8,13 @@ logger = logging.getLogger(__name__) class OpenRouterProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='openai/gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="openai/gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise ex \ No newline at end of file + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py index c9116bf685..89cac665aa 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py @@ -13,11 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -27,8 +33,7 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -46,8 +51,9 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -67,10 +73,10 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -101,10 +107,10 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://cloud.perfxlab.cn' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://cloud.perfxlab.cn" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py index 0854ef5185..450d22fb75 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py +++ b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class PerfXCloudProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class PerfXCloudProvider(ModelProvider): # Use `Qwen2_72B_Chat_GPTQ_Int4` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='Qwen2-72B-Instruct-GPTQ-Int4', - credentials=credentials - ) + model_instance.validate_credentials(model="Qwen2-72B-Instruct-GPTQ-Int4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 11d57e3749..b62a2d2aaf 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -19,17 +19,17 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,30 +39,28 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - - # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } - api_key = credentials.get('api_key') + # Prepare headers and payload for the request + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -73,7 +71,6 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -81,7 +78,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -91,42 +88,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -148,48 +128,38 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -197,7 +167,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -205,20 +175,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -230,10 +199,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -244,7 +210,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 29d8427d8e..915f6e0eef 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError class _CommonReplicate: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - ReplicateError, - ModelError - ] - } + return {InvokeBadRequestError: [ReplicateError, ModelError]} diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 31b81a829e..3641b35dc0 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -28,16 +28,22 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -48,39 +54,43 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): inputs = {**model_parameters} if prompt_messages[0].role == PromptMessageRole.SYSTEM: - if 'system_prompt' in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: - inputs['system_prompt'] = prompt_messages[0].content - inputs['prompt'] = prompt_messages[1].content + if "system_prompt" in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"]: + inputs["system_prompt"] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[1].content else: - inputs['prompt'] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[0].content - prediction = client.predictions.create( - version=model_info_version, input=inputs - ) + prediction = client.predictions.create(version=model_info_version, input=inputs) if stream: return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages) return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] if model.count("/") != 1: - raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' - 'format: {user_name}/{model_name}') + raise CredentialsValidateFailedError( + "Replicate Model Name must be provided, format: {user_name}/{model_name}" + ) try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -91,45 +101,44 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): self._check_text_generation_model(model_info_version, model, model_version, model_info.description) except ReplicateError as e: raise CredentialsValidateFailedError( - f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}") + f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}" + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @staticmethod def _check_text_generation_model(model_info_version, model_name, version, description): - if 'language model' in description.lower(): + if "language model" in description.lower(): return - if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: + if ( + "temperature" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_p" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_k" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + ): raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.") def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - model_type = LLMMode.CHAT if model.endswith('-chat') else LLMMode.COMPLETION + model_type = LLMMode.CHAT if model.endswith("-chat") else LLMMode.COMPLETION entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: model_type.value - }, - parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) + model_properties={ModelPropertyKey.MODE: model_type.value}, + parameter_rules=self._get_customizable_model_parameter_rules(model, credentials), ) return entity @classmethod def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -140,15 +149,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): parameter_rules = [] input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for key, value in input_properties: - if key not in ['system_prompt', 'prompt'] and 'stop' not in key: - value_type = value.get('type') + if key not in {"system_prompt", "prompt"} and "stop" not in key: + value_type = value.get("type") if not value_type: continue @@ -157,28 +164,28 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): rule = ParameterRule( name=key, - label={ - 'en_US': value['title'] - }, + label={"en_US": value["title"]}, type=param_type, help={ - 'en_US': value.get('description'), + "en_US": value.get("description"), }, required=False, - default=value.get('default'), - min=value.get('minimum'), - max=value.get('maximum') + default=value.get("default"), + min=value.get("minimum"), + max=value.get("maximum"), ) parameter_rules.append(rule) return parameter_rules - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prediction: Prediction, - stop: list[str], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> Generator: index = -1 current_completion: str = "" stop_condition_reached = False @@ -189,7 +196,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): for output in prediction.output_iterator(): current_completion += output - if not is_prediction_output_finished and prediction.status == 'succeeded': + if not is_prediction_output_finished and prediction.status == "succeeded": prediction_output_length = len(prediction.output) - 1 is_prediction_output_finished = True @@ -207,18 +214,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=output if output else '' - ) + assistant_prompt_message = AssistantPromptMessage(content=output or "") if index < prediction_output_length: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -229,15 +231,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) - def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> LLMResult: current_completion: str = "" stop_condition_reached = False for output in prediction.output_iterator(): @@ -255,9 +259,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): if stop_condition_reached: break - assistant_prompt_message = AssistantPromptMessage( - content=current_completion - ) + assistant_prompt_message = AssistantPromptMessage(content=current_completion) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -275,21 +277,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): @classmethod def _get_parameter_type(cls, param_type: str) -> str: - type_mapping = { - 'integer': 'int', - 'number': 'float', - 'boolean': 'boolean', - 'string': 'string' - } + type_mapping = {"integer": "int", "number": "float", "boolean": "boolean", "string": "string"} return type_mapping.get(param_type) def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/replicate/replicate.py b/api/core/model_runtime/model_providers/replicate/replicate.py index 3a5c9b84a0..ca137579c9 100644 --- a/api/core/model_runtime/model_providers/replicate/replicate.py +++ b/api/core/model_runtime/model_providers/replicate/replicate.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class ReplicateProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 0e4cdbf5bc..71b6fb99c4 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -13,32 +13,27 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) - - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - texts) + embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, texts) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -47,39 +42,35 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - ['Hello worlds!']) + self._generate_embeddings_by_text_input_key( + client, replicate_model_version, text_input_key, ["Hello worlds!"] + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 4096, - 'max_chunks': 1 - } + model_properties={"context_size": 4096, "max_chunks": 1}, ) return entity @@ -90,49 +81,45 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): # sort through the openapi schema to get the name of text, texts or inputs input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for input_property in input_properties: - if input_property[0] in ('text', 'texts', 'inputs'): + if input_property[0] in {"text", "texts", "inputs"}: text_input_key = input_property[0] return text_input_key - return '' + return "" @staticmethod - def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_model_version: str, - text_input_key: str, texts: list[str]) -> list[list[float]]: - - if text_input_key in ('text', 'inputs'): + def _generate_embeddings_by_text_input_key( + client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] + ) -> list[list[float]]: + if text_input_key in {"text", "inputs"}: embeddings = [] for text in texts: - result = client.run(replicate_model_version, input={ - text_input_key: text - }) - embeddings.append(result[0].get('embedding')) + result = client.run(replicate_model_version, input={text_input_key: text}) + embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif 'texts' == text_input_key: - result = client.run(replicate_model_version, input={ - 'texts': json.dumps(texts), - "batch_size": 4, - "convert_to_numpy": False, - "normalize_embeddings": True - }) + elif "texts" == text_input_key: + result = client.run( + replicate_model_version, + input={ + "texts": json.dumps(texts), + "batch_size": 4, + "convert_to_numpy": False, + "normalize_embeddings": True, + }, + ) return result else: - raise ValueError(f'embeddings input key is invalid: {text_input_key}') + raise ValueError(f"embeddings input key is invalid: {text_input_key}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -143,7 +130,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 3d4c5825af..2edd13d56d 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -44,10 +44,11 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) -def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], stop:list, stream=False): - """ + +def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False): + """ params: - predictor : Sagemaker Predictor + predictor : Sagemaker Predictor messages (List[Dict[str,Any]]): message list。 messages = [ {"role": "system", "content":"please answer in Chinese"}, @@ -55,19 +56,19 @@ def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], sto ] params (Dict[str,Any]): model parameters for LLM。 stream (bool): False by default。 - + response: result of inference if stream is False Iterator of Chunks if stream is True """ payload = { - "model" : params.get('model_name'), - "stop" : stop, + "model": params.get("model_name"), + "stop": stop, "messages": messages, - "stream" : stream, - "max_tokens" : params.get('max_new_tokens', params.get('max_tokens', 2048)), - "temperature" : params.get('temperature', 0.1), - "top_p" : params.get('top_p', 0.9), + "stream": stream, + "max_tokens": params.get("max_new_tokens", params.get("max_tokens", 2048)), + "temperature": params.get("temperature", 0.1), + "top_p": params.get("top_p", 0.9), } if not stream: @@ -77,36 +78,41 @@ def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], sto response_stream = predictor.predict_stream(payload) return response_stream + class SageMakerLargeLanguageModel(LargeLanguageModel): """ Model class for Cohere large language model. """ - sagemaker_client: Any = None - sagemaker_sess : Any = None - predictor : Any = None - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: bytes) -> LLMResult: + sagemaker_client: Any = None + sagemaker_sess: Any = None + predictor: Any = None + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: bytes, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ - resp_obj = json.loads(resp.decode('utf-8')) - resp_str = resp_obj.get('choices')[0].get('message').get('content') + resp_obj = json.loads(resp.decode("utf-8")) + resp_str = resp_obj.get("choices")[0].get("message").get("content") if len(resp_str) == 0: raise InvokeServerUnavailableError("Empty response") - assistant_prompt_message = AssistantPromptMessage( - content=resp_str, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=resp_str, tool_calls=[]) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -118,37 +124,43 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[bytes]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[bytes], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" buffer = "" for chunk_bytes in resp: - buffer += chunk_bytes.decode('utf-8') + buffer += chunk_bytes.decode("utf-8") last_idx = 0 - for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer): + for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer): try: data = json.loads(match.group(1).strip()) last_idx = match.span()[1] if "content" in data["choices"][0]["delta"]: chunk_content = data["choices"][0]["delta"]["content"] - assistant_prompt_message = AssistantPromptMessage( - content=chunk_content, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[]) - if data["choices"][0]['finish_reason'] is not None: - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + if data["choices"][0]["finish_reason"] is not None: + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -157,8 +169,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=assistant_prompt_message, - finish_reason=data["choices"][0]['finish_reason'], - usage=usage + finish_reason=data["choices"][0]["finish_reason"], + usage=usage, ), ) else: @@ -166,10 +178,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): model=model, prompt_messages=prompt_messages, system_fingerprint=None, - delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message - ), + delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message), ) full_response += chunk_content @@ -179,11 +188,17 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): buffer = buffer[last_idx:] - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -198,15 +213,17 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ if not self.sagemaker_client: - access_key = credentials.get('access_key') - secret_key = credentials.get('secret_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("access_key") + secret_key = credentials.get("secret_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -214,25 +231,26 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client) self.predictor = Predictor( - endpoint_name=credentials.get('sagemaker_endpoint'), + endpoint_name=credentials.get("sagemaker_endpoint"), sagemaker_session=sagemaker_session, serializer=serializers.JSONSerializer(), ) - - messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ] - response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream) + messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages] + response = inference( + predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream + ) if stream: if tools and len(tools) > 0: raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode") - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=response) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=response) + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ @@ -247,19 +265,13 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -269,7 +281,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -282,8 +294,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): return message_dict - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -299,10 +312,10 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -339,8 +352,13 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): return num_tokens - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -381,89 +399,63 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = LLMMode.value_of(credentials["mode"]).value features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index 6b7cfc210b..959dff6a21 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -1,5 +1,6 @@ import json import logging +import operator from typing import Any, Optional import boto3 @@ -20,34 +21,36 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel logger = logging.getLogger(__name__) + class SageMakerRerankModel(RerankModel): """ Model class for SageMaker rerank model. """ + sagemaker_client: Any = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -63,22 +66,21 @@ class SageMakerRerankModel(RerankModel): line = 0 try: if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -86,22 +88,20 @@ class SageMakerRerankModel(RerankModel): line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") candidate_docs = [] scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) for idx in range(len(scores)): - candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + candidate_docs.append({"content": docs[idx], "score": scores[idx]}) - sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) line = 3 rerank_documents = [] for idx, result in enumerate(candidate_docs): rerank_document = RerankDocument( - index=idx, - text=result.get('content'), - score=result.get('score', -100.0) + index=idx, text=result.get("content"), score=result.get("score", -100.0) ) if score_threshold is not None: @@ -110,13 +110,10 @@ class SageMakerRerankModel(RerankModel): else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -137,7 +134,7 @@ class SageMakerRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -153,38 +150,24 @@ class SageMakerRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py index 6f3e02489f..042155b152 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class SageMakerProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -17,27 +18,24 @@ class SageMakerProvider(ModelProvider): """ pass -def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str: - ''' - return s3_uri of this file - ''' - s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3' - s3_client.put_object( - Body=file.read(), - Bucket=bucket, - Key=s3_key, - ContentType='audio/mp3' - ) + +def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str: + """ + return s3_uri of this file + """ + s3_key = f"{s3_prefix}{uuid.uuid4()}.mp3" + s3_client.put_object(Body=file.read(), Bucket=bucket, Key=s3_key, ContentType="audio/mp3") return s3_key -def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str: + +def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str, expiration=600) -> str: object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix) try: - response = s3_client.generate_presigned_url('get_object', - Params={'Bucket': bucket_name, 'Key': object_key}, - ExpiresIn=expiration) + response = s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket_name, "Key": object_key}, ExpiresIn=expiration + ) except Exception as e: print(f"Error generating presigned URL: {e}") return None - return response \ No newline at end of file + return response diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 8b57f182fe..6aa8c9995f 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -19,16 +19,16 @@ from core.model_runtime.model_providers.sagemaker.sagemaker import generate_pres logger = logging.getLogger(__name__) + class SageMakerSpeech2TextModel(Speech2TextModel): """ Model class for Xinference speech to text model. """ - sagemaker_client: Any = None - s3_client : Any = None - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + sagemaker_client: Any = None + s3_client: Any = None + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -42,19 +42,20 @@ class SageMakerSpeech2TextModel(Speech2TextModel): try: if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) - self.s3_client = boto3.client("s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) @@ -62,25 +63,21 @@ class SageMakerSpeech2TextModel(Speech2TextModel): self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - s3_prefix='dify/speech2text/' - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - bucket = credentials.get('audio_s3_cache_bucket') + s3_prefix = "dify/speech2text/" + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + bucket = credentials.get("audio_s3_cache_bucket") s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) - payload = { - "audio_s3_presign_uri" : s3_presign_url - } + payload = {"audio_s3_presign_uri": s3_presign_url} response_model = self.sagemaker_client.invoke_endpoint( - EndpointName=sagemaker_endpoint, - Body=json.dumps(payload), - ContentType="application/json" + EndpointName=sagemaker_endpoint, Body=json.dumps(payload), ContentType="application/json" ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - asr_text = json_obj['text'] + asr_text = json_obj["text"] except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") return asr_text @@ -105,38 +102,24 @@ class SageMakerSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 4b2858b1a2..d55144f8a7 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -10,21 +10,22 @@ from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel BATCH_SIZE = 20 -CONTEXT_SIZE=8192 +CONTEXT_SIZE = 8192 logger = logging.getLogger(__name__) + def batch_generator(generator, batch_size): while True: batch = list(itertools.islice(generator, batch_size)) @@ -32,33 +33,28 @@ def batch_generator(generator, batch_size): break yield batch + class SageMakerEmbeddingModel(TextEmbeddingModel): """ Model class for Cohere text embedding model. """ + sagemaker_client: Any = None - def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]): + def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]): response_model = sm_client.invoke_endpoint( EndpointName=endpoint_name, - Body=json.dumps( - { - "inputs": content_list, - "parameters": {}, - "is_query" : False, - "instruction" : '' - } - ), + Body=json.dumps({"inputs": content_list, "parameters": {}, "is_query": False, "instruction": ""}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - embeddings = json_obj['embeddings'] + embeddings = json_obj["embeddings"] return embeddings - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -72,25 +68,27 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): try: line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") line = 3 - truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] + truncated_texts = [item[:CONTEXT_SIZE] for item in texts] batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE) all_embeddings = [] @@ -105,18 +103,14 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): usage = self._calc_response_usage( model=model, credentials=credentials, - tokens=0 # It's not SAAS API, usage is meaningless + tokens=0, # It's not SAAS API, usage is meaningless ) line = 6 - return TextEmbeddingResult( - embeddings=all_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -153,10 +147,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -167,7 +158,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -175,40 +166,28 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE, ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py index 315b31fd85..a22bd6dd6e 100644 --- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -22,89 +22,93 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel logger = logging.getLogger(__name__) + class TTSModelType(Enum): PresetVoice = "PresetVoice" CloneVoice = "CloneVoice" CloneVoice_CrossLingual = "CloneVoice_CrossLingual" InstructVoice = "InstructVoice" -class SageMakerText2SpeechModel(TTSModel): +class SageMakerText2SpeechModel(TTSModel): sagemaker_client: Any = None - s3_client : Any = None - comprehend_client : Any = None + s3_client: Any = None + comprehend_client: Any = None def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ pass - def _detect_lang_code(self, content:str, map_dict:dict=None): - map_dict = { - "zh" : "<|zh|>", - "en" : "<|en|>", - "ja" : "<|jp|>", - "zh-TW" : "<|yue|>", - "ko" : "<|ko|>" - } + def _detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} response = self.comprehend_client.detect_dominant_language(Text=content) - language_code = response['Languages'][0]['LanguageCode'] + language_code = response["Languages"][0]["LanguageCode"] - return map_dict.get(language_code, '<|zh|>') + return map_dict.get(language_code, "<|zh|>") - def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): if model_type == TTSModelType.PresetVoice.value and model_role: - return { "tts_text" : content_text, "role" : model_role } + return {"tts_text": content_text, "role": model_role} if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: - return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } - if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: lang_tag = self._detect_lang_code(content_text) - return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } - if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: - return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} raise RuntimeError(f"Invalid params for {model_type}") - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -117,61 +121,55 @@ class SageMakerText2SpeechModel(TTSModel): :return: text translated to audio file """ if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) - self.s3_client = boto3.client("s3", + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) + self.comprehend_client = boto3.client( + "comprehend", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - self.comprehend_client = boto3.client('comprehend') + self.comprehend_client = boto3.client("comprehend") - model_type = credentials.get('audio_model_type', 'PresetVoice') - prompt_text = credentials.get('prompt_text') - prompt_audio = credentials.get('prompt_audio') - instruct_text = credentials.get('instruct_text') - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - payload = self._build_tts_payload( - model_type, - content_text, - voice, - prompt_text, - prompt_audio, - instruct_text - ) + model_type = credentials.get("audio_model_type", "PresetVoice") + prompt_text = credentials.get("prompt_text") + prompt_audio = credentials.get("prompt_audio") + instruct_text = credentials.get("instruct_text") + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + payload = self._build_tts_payload(model_type, content_text, voice, prompt_text, prompt_audio, instruct_text) return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -187,23 +185,11 @@ class SageMakerText2SpeechModel(TTSModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _get_model_default_voice(self, model: str, credentials: dict) -> any: @@ -219,27 +205,27 @@ class SageMakerText2SpeechModel(TTSModel): return 5 def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = 'CosyVoice' + audio_model_name = "CosyVoice" for key, voices in self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] - def _invoke_sagemaker(self, payload:dict, endpoint:str): + def _invoke_sagemaker(self, payload: dict, endpoint: str): response_model = self.sagemaker_client.invoke_endpoint( EndpointName=endpoint, Body=json.dumps(payload), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) return json_obj - def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> any: + def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> any: """ _tts_invoke_streaming text2speech model @@ -250,38 +236,40 @@ class SageMakerText2SpeechModel(TTSModel): :return: text translated to audio file """ try: - lang_tag = '' + lang_tag = "" if model_type == TTSModelType.CloneVoice_CrossLingual.value: - lang_tag = payload.pop('lang_tag') - - word_limit = self._get_model_word_limit(model='', credentials={}) + lang_tag = payload.pop("lang_tag") + + word_limit = self._get_model_word_limit(model="", credentials={}) content_text = payload.get("tts_text") if len(content_text) > word_limit: split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit) - sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ] + sentences = [f"{lang_tag}{s}" for s in split_sentences if len(s)] len_sent = len(sentences) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent)) - payloads = [ copy.deepcopy(payload) for i in range(len_sent) ] + payloads = [copy.deepcopy(payload) for i in range(len_sent)] for idx in range(len_sent): payloads[idx]["tts_text"] = sentences[idx] - futures = [ executor.submit( - self._invoke_sagemaker, - payload=payload, - endpoint=sagemaker_endpoint, - ) - for payload in payloads] + futures = [ + executor.submit( + self._invoke_sagemaker, + payload=payload, + endpoint=sagemaker_endpoint, + ) + for payload in payloads + ] - for index, future in enumerate(futures): + for future in futures: resp = future.result() - audio_bytes = requests.get(resp.get('s3_presign_url')).content + audio_bytes = requests.get(resp.get("s3_presign_url")).content for i in range(0, len(audio_bytes), 1024): - yield audio_bytes[i:i + 1024] + yield audio_bytes[i : i + 1024] else: resp = self._invoke_sagemaker(payload, sagemaker_endpoint) - audio_bytes = requests.get(resp.get('s3_presign_url')).content + audio_bytes = requests.get(resp.get("s3_presign_url")).content for i in range(0, len(audio_bytes), 1024): - yield audio_bytes[i:i + 1024] + yield audio_bytes[i : i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index a9ce7b98c3..c1868b6ad0 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py index 6835915816..58b033d28a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -16,39 +16,38 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel class SiliconflowRerankModel(RerankModel): - - def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1') - if base_url.endswith('/'): - base_url = base_url[:-1] + base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1") + base_url = base_url.removesuffix("/") try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n, - "return_documents": True - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) @@ -57,7 +56,6 @@ class SiliconflowRerankModel(RerankModel): def validate_credentials(self, model: str, credentials: dict) -> None: try: - self._invoke( model=model, credentials=credentials, @@ -68,7 +66,7 @@ class SiliconflowRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,5 +81,5 @@ class SiliconflowRerankModel(RerankModel): InvokeServerUnavailableError: [httpx.RemoteProtocolError], InvokeRateLimitError: [], InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] - } \ No newline at end of file + InvokeBadRequestError: [httpx.RequestError], + } diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index dd0eea362a..e121ab8c7e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class SiliconflowProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class SiliconflowProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py index 6ad3cab587..8d1932863e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py @@ -8,9 +8,7 @@ class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): Model class for Siliconflow Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c58765cecb..6cdf4933b4 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -10,20 +10,21 @@ class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel): """ Model class for Siliconflow text embedding model. """ + def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, texts, user) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: self._add_custom_parameters(credentials) return super().get_num_tokens(model, credentials, texts) - + @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' \ No newline at end of file + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/spark/llm/_client.py b/api/core/model_runtime/model_providers/spark/llm/_client.py index d57766a87a..b99a657e71 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_client.py +++ b/api/core/model_runtime/model_providers/spark/llm/_client.py @@ -15,54 +15,35 @@ import websocket class SparkLLMClient: def __init__(self, model: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - domain = 'spark-api.xf-yun.com' - endpoint = 'chat' + domain = "spark-api.xf-yun.com" + endpoint = "chat" if api_domain: domain = api_domain model_api_configs = { - 'spark-lite': { - 'version': 'v1.1', - 'chat_domain': 'general' - }, - 'spark-pro': { - 'version': 'v3.1', - 'chat_domain': 'generalv3' - }, - 'spark-pro-128k': { - 'version': 'pro-128k', - 'chat_domain': 'pro-128k' - }, - 'spark-max': { - 'version': 'v3.5', - 'chat_domain': 'generalv3.5' - }, - 'spark-4.0-ultra': { - 'version': 'v4.0', - 'chat_domain': '4.0Ultra' - } + "spark-lite": {"version": "v1.1", "chat_domain": "general"}, + "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"}, + "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"}, + "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"}, + "spark-4.0-ultra": {"version": "v4.0", "chat_domain": "4.0Ultra"}, } - api_version = model_api_configs[model]['version'] + api_version = model_api_configs[model]["version"] - self.chat_domain = model_api_configs[model]['chat_domain'] + self.chat_domain = model_api_configs[model]["chat_domain"] - if model == 'spark-pro-128k': + if model == "spark-pro-128k": self.api_base = f"wss://{domain}/{endpoint}/{api_version}" else: self.api_base = f"wss://{domain}/{api_version}/{endpoint}" self.app_id = app_id self.ws_url = self.create_url( - urlparse(self.api_base).netloc, - urlparse(self.api_base).path, - self.api_base, - api_key, - api_secret + urlparse(self.api_base).netloc, urlparse(self.api_base).path, self.api_base, api_key, api_secret ) self.queue = queue.Queue() - self.blocking_message = '' + self.blocking_message = "" def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: # generate timestamp by RFC1123 @@ -74,33 +55,32 @@ class SparkLLMClient: signature_origin += "GET " + path + " HTTP/1.1" # encrypt using hmac-sha256 - signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() + signature_sha = hmac.new( + api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256 + ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") - authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' + authorization_origin = ( + f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line",' + f' signature="{signature_sha_base64}"' + ) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") - v = { - "authorization": authorization, - "date": date, - "host": host - } + v = {"authorization": authorization, "date": date, "host": host} # generate url - url = api_base + '?' + urlencode(v) + url = api_base + "?" + urlencode(v) return url - def run(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None, streaming: bool = False): + def run(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None, streaming: bool = False): websocket.enableTrace(False) ws = websocket.WebSocketApp( self.ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - on_open=self.on_open + on_open=self.on_open, ) ws.messages = messages ws.user_id = user_id @@ -109,86 +89,71 @@ class SparkLLMClient: ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_error(self, ws, error): - self.queue.put({ - 'status_code': error.status_code, - 'error': error.resp_body.decode('utf-8') - }) + self.queue.put({"status_code": error.status_code, "error": error.resp_body.decode("utf-8")}) ws.close() def on_close(self, ws, close_status_code, close_reason): - self.queue.put({'done': True}) + self.queue.put({"done": True}) def on_open(self, ws): - self.blocking_message = '' - data = json.dumps(self.gen_params( - messages=ws.messages, - user_id=ws.user_id, - model_kwargs=ws.model_kwargs - )) + self.blocking_message = "" + data = json.dumps(self.gen_params(messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs)) ws.send(data) def on_message(self, ws, message): data = json.loads(message) - code = data['header']['code'] + code = data["header"]["code"] if code != 0: - self.queue.put({ - 'status_code': 400, - 'error': f"Code: {code}, Error: {data['header']['message']}" - }) + self.queue.put({"status_code": 400, "error": f"Code: {code}, Error: {data['header']['message']}"}) ws.close() else: choices = data["payload"]["choices"] status = choices["status"] content = choices["text"][0]["content"] if ws.streaming: - self.queue.put({'data': content}) + self.queue.put({"data": content}) else: self.blocking_message += content if status == 2: if not ws.streaming: - self.queue.put({'data': self.blocking_message}) + self.queue.put({"data": self.blocking_message}) ws.close() - def gen_params(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None) -> dict: + def gen_params(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None) -> dict: data = { "header": { "app_id": self.app_id, # resolve this error message => $.header.uid' length must be less or equal than 32 - "uid": user_id[:32] if user_id else None + "uid": user_id[:32] if user_id else None, }, - "parameter": { - "chat": { - "domain": self.chat_domain - } - }, - "payload": { - "message": { - "text": messages - } - } + "parameter": {"chat": {"domain": self.chat_domain}}, + "payload": {"message": {"text": messages}}, } if model_kwargs: - data['parameter']['chat'].update(model_kwargs) + data["parameter"]["chat"].update(model_kwargs) return data def subscribe(self): while True: content = self.queue.get() - if 'error' in content: - if content['status_code'] == 401: - raise SparkError('[Spark] The credentials you provided are incorrect. ' - 'Please double-check and fill them in again.') - elif content['status_code'] == 403: - raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " - "Please try again after obtaining the necessary permissions.") + if "error" in content: + if content["status_code"] == 401: + raise SparkError( + "[Spark] The credentials you provided are incorrect. " + "Please double-check and fill them in again." + ) + elif content["status_code"] == 403: + raise SparkError( + "[Spark] Sorry, the credentials you provided are access denied. " + "Please try again after obtaining the necessary permissions." + ) else: raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") - if 'data' not in content: + if "data" not in content: break yield content diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 65beae517c..57193dc031 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -25,12 +25,17 @@ from ._client import SparkLLMClient class SparkLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -47,8 +52,13 @@ class SparkLargeLanguageModel(LargeLanguageModel): # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -80,15 +90,21 @@ class SparkLargeLanguageModel(LargeLanguageModel): model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -103,7 +119,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -113,21 +129,33 @@ class SparkLargeLanguageModel(LargeLanguageModel): **credentials_kwargs, ) - thread = threading.Thread(target=client.run, args=( - [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages], - user, - model_parameters, - stream - )) + thread = threading.Thread( + target=client.run, + args=( + [ + {"role": prompt_message.role.value, "content": prompt_message.content} + for prompt_message in prompt_messages + ], + user, + model_parameters, + stream, + ), + ) thread.start() if stream: return self._handle_generate_stream_response(thread, model, credentials, client, prompt_messages) return self._handle_generate_response(thread, model, credentials, client, prompt_messages) - - def _handle_generate_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_generate_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -140,7 +168,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): for content in client.subscribe(): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content @@ -148,9 +176,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): thread.join() # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + assistant_prompt_message = AssistantPromptMessage(content=completion) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -168,9 +194,15 @@ class SparkLargeLanguageModel(LargeLanguageModel): ) return result - - def _handle_generate_stream_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> Generator: + + def _handle_generate_stream_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -183,12 +215,12 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ for index, content in enumerate(client.subscribe()): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content assistant_prompt_message = AssistantPromptMessage( - content=delta if delta else '', + content=delta or "", ) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -199,11 +231,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) thread.join() @@ -216,9 +244,9 @@ class SparkLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "app_id": credentials['app_id'], - "api_secret": credentials['api_secret'], - "api_key": credentials['api_key'], + "app_id": credentials["app_id"], + "api_secret": credentials["api_secret"], + "api_key": credentials["api_key"], } return credentials_kwargs @@ -244,7 +272,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): raise ValueError(f"Got unknown type {message}") return message_text - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -254,10 +282,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -277,5 +302,5 @@ class SparkLargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/stepfun/llm/llm.py b/api/core/model_runtime/model_providers/stepfun/llm/llm.py index 6f6ffc8faa..dab666e4d0 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/llm.py +++ b/api/core/model_runtime/model_providers/stepfun/llm/llm.py @@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,51 +55,51 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 1024)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.stepfun.com/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.stepfun.com/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" - def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict: + def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ Convert PromptMessage to dict for OpenAI API format """ @@ -106,10 +112,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -117,7 +120,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): "type": "image_url", "image_url": { "url": message_content.data, - } + }, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -127,14 +130,16 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -160,21 +165,26 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -184,11 +194,12 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -199,12 +210,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -218,9 +224,9 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -242,9 +248,9 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -253,21 +259,21 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -275,19 +281,18 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -303,26 +308,21 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.py b/api/core/model_runtime/model_providers/stepfun/stepfun.py index 50b17392b5..e1c41a9153 100644 --- a/api/core/model_runtime/model_providers/stepfun/stepfun.py +++ b/api/core/model_runtime/model_providers/stepfun/stepfun.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class StepfunProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class StepfunProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='step-1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="step-1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py index b62b9860cb..c3c21793e8 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py @@ -1,6 +1,7 @@ import base64 import hashlib import hmac +import operator import time import requests @@ -69,8 +70,8 @@ class FlashRecognizer: """ response: request_id string - status Integer - message String + status Integer + message String audio_duration Integer flash_result Result Array @@ -81,16 +82,16 @@ class FlashRecognizer: Sentence: text String - start_time Integer - end_time Integer - speaker_id Integer + start_time Integer + end_time Integer + speaker_id Integer word_list Word Array Word: - word String - start_time Integer - end_time Integer - stable_flag: Integer + word String + start_time Integer + end_time Integer + stable_flag: Integer """ def __init__(self, appid, credential): @@ -100,13 +101,13 @@ class FlashRecognizer: def _format_sign_string(self, param): signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/" for t in param: - if 'appid' in t: + if "appid" in t: signstr += str(t[1]) break signstr += "?" for x in param: tmp = x - if 'appid' in x: + if "appid" in x: continue for t in tmp: signstr += str(t) @@ -121,14 +122,13 @@ class FlashRecognizer: return header def _sign(self, signstr, secret_key): - hmacstr = hmac.new(secret_key.encode('utf-8'), - signstr.encode('utf-8'), hashlib.sha1).digest() + hmacstr = hmac.new(secret_key.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest() s = base64.b64encode(hmacstr) - s = s.decode('utf-8') + s = s.decode("utf-8") return s def _build_req_with_signature(self, secret_key, params, header): - query = sorted(params.items(), key=lambda d: d[0]) + query = sorted(params.items(), key=operator.itemgetter(0)) signstr = self._format_sign_string(query) signature = self._sign(signstr, secret_key) header["Authorization"] = signature @@ -138,14 +138,22 @@ class FlashRecognizer: def _create_query_arr(self, req): return { - 'appid': self.appid, 'secretid': self.credential.secret_id, 'timestamp': str(int(time.time())), - 'engine_type': req.engine_type, 'voice_format': req.voice_format, - 'speaker_diarization': req.speaker_diarization, 'hotword_id': req.hotword_id, - 'customization_id': req.customization_id, 'filter_dirty': req.filter_dirty, - 'filter_modal': req.filter_modal, 'filter_punc': req.filter_punc, - 'convert_num_mode': req.convert_num_mode, 'word_info': req.word_info, - 'first_channel_only': req.first_channel_only, 'reinforce_hotword': req.reinforce_hotword, - 'sentence_max_length': req.sentence_max_length + "appid": self.appid, + "secretid": self.credential.secret_id, + "timestamp": str(int(time.time())), + "engine_type": req.engine_type, + "voice_format": req.voice_format, + "speaker_diarization": req.speaker_diarization, + "hotword_id": req.hotword_id, + "customization_id": req.customization_id, + "filter_dirty": req.filter_dirty, + "filter_modal": req.filter_modal, + "filter_punc": req.filter_punc, + "convert_num_mode": req.convert_num_mode, + "word_info": req.word_info, + "first_channel_only": req.first_channel_only, + "reinforce_hotword": req.reinforce_hotword, + "sentence_max_length": req.sentence_max_length, } def recognize(self, req, data): diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py index 00ec5aa9c8..5b427663ca 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py @@ -18,9 +18,7 @@ from core.model_runtime.model_providers.tencent.speech2text.flash_recognizer imp class TencentSpeech2TextModel(Speech2TextModel): - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -43,7 +41,7 @@ class TencentSpeech2TextModel(Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,10 +81,6 @@ class TencentSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - requests.exceptions.ConnectionError - ], - InvokeAuthorizationError: [ - CredentialsValidateFailedError - ] + InvokeConnectionError: [requests.exceptions.ConnectionError], + InvokeAuthorizationError: [CredentialsValidateFailedError], } diff --git a/api/core/model_runtime/model_providers/tencent/tencent.py b/api/core/model_runtime/model_providers/tencent/tencent.py index dd9f90bb47..79c6f577b8 100644 --- a/api/core/model_runtime/model_providers/tencent/tencent.py +++ b/api/core/model_runtime/model_providers/tencent/tencent.py @@ -18,12 +18,9 @@ class TencentProvider(ModelProvider): """ try: model_instance = self.get_model_instance(ModelType.SPEECH2TEXT) - model_instance.validate_credentials( - model='tencent', - credentials=credentials - ) + model_instance.validate_credentials(model="tencent", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index bb802d4071..b96d43979e 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -22,16 +22,21 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.together.xyz/v1" + credentials["endpoint_url"] = "https://api.together.xyz/v1" return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,12 +46,22 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, cred_with_endpoint) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) @@ -61,45 +76,45 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")), - ModelPropertyKey.MODE: cred_with_endpoint.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get("context_size", "4096")), + ModelPropertyKey.MODE: cred_with_endpoint.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('temperature', 0.7)), + default=float(cred_with_endpoint.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('top_p', 1)), + default=float(cred_with_endpoint.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=TOP_K, label=I18nObject(en_US="Top K"), type=ParameterType.INT, - default=int(cred_with_endpoint.get('top_k', 50)), + default=int(cred_with_endpoint.get("top_k", 50)), min=-2147483647, max=2147483647, - precision=0 + precision=0, ), ParameterRule( name=REPETITION_PENALTY, label=I18nObject(en_US="Repetition Penalty"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('repetition_penalty', 1)), + default=float(cred_with_endpoint.get("repetition_penalty", 1)), min=-3.4, max=3.4, - precision=1 + precision=1, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -107,46 +122,49 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): type=ParameterType.INT, default=512, min=1, - max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)), + max=int(cred_with_endpoint.get("max_tokens_to_sample", 4096)), ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ], pricing=PriceConfig( - input=Decimal(cred_with_endpoint.get('input_price', 0)), - output=Decimal(cred_with_endpoint.get('output_price', 0)), - unit=Decimal(cred_with_endpoint.get('unit', 0)), - currency=cred_with_endpoint.get('currency', "USD") + input=Decimal(cred_with_endpoint.get("input_price", 0)), + output=Decimal(cred_with_endpoint.get("output_price", 0)), + unit=Decimal(cred_with_endpoint.get("unit", 0)), + currency=cred_with_endpoint.get("currency", "USD"), ), ) - if cred_with_endpoint['mode'] == 'chat': + if cred_with_endpoint["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif cred_with_endpoint['mode'] == 'completion': + elif cred_with_endpoint["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}") return entity - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) - - diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py index ffce4794e7..aa4100a7c9 100644 --- a/api/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class TogetherAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index fab18b41fd..8a50c7aa05 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -21,7 +21,7 @@ class _CommonTongyi: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: credentials_kwargs = { - "dashscope_api_key": credentials['dashscope_api_key'], + "dashscope_api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -51,5 +51,5 @@ class _CommonTongyi: InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 6667d40440..1d4eba6668 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -4,6 +4,7 @@ import tempfile import uuid from collections.abc import Generator from http import HTTPStatus +from pathlib import Path from typing import Optional, Union, cast from dashscope import Generation, MultiModalConversation, get_tokenizer @@ -45,11 +46,17 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class TongyiLargeLanguageModel(LargeLanguageModel): tokenizers = {} - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -65,8 +72,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ # invoke model without code wrapper return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -76,10 +89,10 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: """ - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') - if model == 'farui-plus': - model = 'qwen-farui-plus' + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: + model = model.replace("-chat", "") + if model == "farui-plus": + model = "qwen-farui-plus" if model in self.tokenizers: tokenizer = self.tokenizers[model] @@ -110,16 +123,22 @@ class TongyiLargeLanguageModel(LargeLanguageModel): model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -138,18 +157,18 @@ class TongyiLargeLanguageModel(LargeLanguageModel): mode = self.get_model_mode(model, credentials) - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: + model = model.replace("-chat", "") extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = self._convert_tools(tools) + extra_model_kwargs["tools"] = self._convert_tools(tools) if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop params = { - 'model': model, + "model": model, **model_parameters, **credentials_kwargs, **extra_model_kwargs, @@ -157,23 +176,22 @@ class TongyiLargeLanguageModel(LargeLanguageModel): model_schema = self.get_model_schema(model, credentials) if ModelFeature.VISION in (model_schema.features or []): - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) response = MultiModalConversation.call(**params, stream=stream) else: # nothing different between chat model and completion model in tongyi - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) - response = Generation.call(**params, - result_format='message', - stream=stream) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + response = Generation.call(**params, result_format="message", stream=stream) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -183,10 +201,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - if response.status_code != 200 and response.status_code != HTTPStatus.OK: - raise ServiceUnavailableError( - response.message - ) + if response.status_code not in {200, HTTPStatus.OK}: + raise ServiceUnavailableError(response.message) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=response.output.choices[0].message.content, @@ -205,9 +221,13 @@ class TongyiLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, - responses: Generator[GenerationResponse, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + responses: Generator[GenerationResponse, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -217,10 +237,10 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" tool_calls = [] for index, response in enumerate(responses): - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError( f"Failed to invoke model {model}, status code: {response.status_code}, " f"message: {response.message}" @@ -228,22 +248,22 @@ class TongyiLargeLanguageModel(LargeLanguageModel): resp_finish_reason = response.output.choices[0].finish_reason - if resp_finish_reason is not None and resp_finish_reason != 'null': + if resp_finish_reason is not None and resp_finish_reason != "null": resp_content = response.output.choices[0].message.content assistant_prompt_message = AssistantPromptMessage( - content='', + content="", ) - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] elif resp_content: # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message - assistant_prompt_message.content = resp_content.replace(full_text, '', 1) + assistant_prompt_message.content = resp_content.replace(full_text, "", 1) full_text = resp_content @@ -251,12 +271,11 @@ class TongyiLargeLanguageModel(LargeLanguageModel): message_tool_calls = [] for tool_call_obj in tool_calls: message_tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_obj['function']['name'], - type='function', + id=tool_call_obj["function"]["name"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call_obj['function']['name'], - arguments=tool_call_obj['function']['arguments'] - ) + name=tool_call_obj["function"]["name"], arguments=tool_call_obj["function"]["arguments"] + ), ) message_tool_calls.append(message_tool_call) @@ -270,26 +289,23 @@ class TongyiLargeLanguageModel(LargeLanguageModel): model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=resp_finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=resp_finish_reason, usage=usage + ), ) else: resp_content = response.output.choices[0].message.content if not resp_content: - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] continue # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=resp_content.replace(full_text, '', 1), + content=resp_content.replace(full_text, "", 1), ) full_text = resp_content @@ -297,10 +313,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -311,7 +324,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "api_key": credentials['dashscope_api_key'], + "api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -338,9 +351,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): break elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") @@ -356,16 +367,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage], - rich_content: bool = False) -> list[dict]: + def _convert_prompt_messages_to_tongyi_messages( + self, prompt_messages: list[PromptMessage], rich_content: bool = False + ) -> list[dict]: """ Convert prompt messages to tongyi messages @@ -375,24 +384,28 @@ class TongyiLargeLanguageModel(LargeLanguageModel): tongyi_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): - tongyi_messages.append({ - 'role': 'system', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "system", + "content": prompt_message.content if not rich_content else [{"text": prompt_message.content}], + } + ) elif isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, str): - tongyi_messages.append({ - 'role': 'user', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "user", + "content": prompt_message.content + if not rich_content + else [{"text": prompt_message.content}], + } + ) else: sub_messages = [] for message_content in prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -402,35 +415,25 @@ class TongyiLargeLanguageModel(LargeLanguageModel): # convert image base64 data to file in /tmp image_url = self._save_base64_image_to_file(message_content.data) - sub_message_dict = { - "image": image_url - } + sub_message_dict = {"image": image_url} sub_messages.append(sub_message_dict) # resort sub_messages to ensure text is always at last - sub_messages = sorted(sub_messages, key=lambda x: 'text' in x) + sub_messages = sorted(sub_messages, key=lambda x: "text" in x) - tongyi_messages.append({ - 'role': 'user', - 'content': sub_messages - }) + tongyi_messages.append({"role": "user", "content": sub_messages}) elif isinstance(prompt_message, AssistantPromptMessage): content = prompt_message.content if not content: - content = ' ' - message = { - 'role': 'assistant', - 'content': content if not rich_content else [{"text": content}] - } + content = " " + message = {"role": "assistant", "content": content if not rich_content else [{"text": content}]} if prompt_message.tool_calls: - message['tool_calls'] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] + message["tool_calls"] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] tongyi_messages.append(message) elif isinstance(prompt_message, ToolPromptMessage): - tongyi_messages.append({ - "role": "tool", - "content": prompt_message.content, - "name": prompt_message.tool_call_id - }) + tongyi_messages.append( + {"role": "tool", "content": prompt_message.content, "name": prompt_message.tool_call_id} + ) else: raise ValueError(f"Got unknown type {prompt_message}") @@ -445,15 +448,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :return: image file path """ # get mime type and encoded string - mime_type, encoded_string = base64_image.split(',')[0].split(';')[0].split(':')[1], base64_image.split(',')[1] + mime_type, encoded_string = base64_image.split(",")[0].split(";")[0].split(":")[1], base64_image.split(",")[1] # save image to file temp_dir = tempfile.gettempdir() file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}") - with open(file_path, "wb") as image_file: - image_file.write(base64.b64decode(encoded_string)) + Path(file_path).write_bytes(base64.b64decode(encoded_string)) return f"file://{file_path}" @@ -463,19 +465,18 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ tool_definitions = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] properties_definitions = {} for p_key, p_val in properties.items(): - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: [{', '.join(p_val['enum'])}]" properties_definitions[p_key] = { - 'description': desc, - 'type': p_val['type'], + "description": desc, + "type": p_val["type"], } tool_definition = { @@ -484,8 +485,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel): "name": tool.name, "description": tool.description, "parameters": properties_definitions, - "required": required_properties - } + "required": required_properties, + }, } tool_definitions.append(tool_definition) @@ -517,5 +518,5 @@ class TongyiLargeLanguageModel(LargeLanguageModel): InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 97dcb72f7c..5783d2e383 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -46,7 +46,6 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -71,12 +70,8 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=used_tokens - ) - return TextEmbeddingResult( - embeddings=batched_embeddings, usage=usage, model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -108,16 +103,12 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents( - credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] - ) + self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents( - credentials_kwargs: dict, model: str, texts: list[str] - ) -> tuple[list[list[float]], int]: + def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: @@ -145,7 +136,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): raise ValueError("Embedding data is missing in the response.") else: raise ValueError("Response output is missing or does not contain embeddings.") - + if response.usage and "total_tokens" in response.usage: embedding_used_tokens += response.usage["total_tokens"] else: @@ -153,9 +144,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def _calc_response_usage( - self, model: str, credentials: dict, tokens: int - ) -> EmbeddingUsage: + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.py b/api/core/model_runtime/model_providers/tongyi/tongyi.py index d5e25e6ecf..a084512de9 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.py +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.py @@ -20,12 +20,9 @@ class TongyiProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `qwen-turbo` model for validate, - model_instance.validate_credentials( - model='qwen-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="qwen-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index 664b02cd92..48a38897a8 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -18,8 +18,9 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): Model class for Tongyi Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -31,14 +32,12 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -53,14 +52,13 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -82,15 +80,21 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): else: sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl)) for sentence in sentences: - SpeechSynthesizer.call(model=v, sample_rate=16000, - api_key=api_key, - text=sentence.strip(), - callback=cb, - format=at, word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) + SpeechSynthesizer.call( + model=v, + sample_rate=16000, + api_key=api_key, + text=sentence.strip(), + callback=cb, + format=at, + word_timestamp_enabled=True, + phoneme_timestamp_enabled=True, + ) - threading.Thread(target=invoke_remote, args=( - content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start() + threading.Thread( + target=invoke_remote, + args=(content_text, voice, credentials.get("dashscope_api_key"), callback, audio_type, word_limit), + ).start() while True: audio = audio_queue.get() @@ -112,16 +116,18 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param audio_type: audio file type :return: text translated to audio file """ - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - api_key=credentials.get('dashscope_api_key'), - text=sentence.strip(), - format=audio_type) + response = dashscope.audio.tts.SpeechSynthesizer.call( + model=voice, + sample_rate=48000, + api_key=credentials.get("dashscope_api_key"), + text=sentence.strip(), + format=audio_type, + ) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() class Callback(ResultCallback): - def __init__(self, queue: Queue): self._queue = queue diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py index 95272a41c2..cf7e3f14be 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py @@ -33,198 +33,223 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class TritonInferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={}, stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={}, + stream=False, + ) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during connection: {str(ex)}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause TritonInference LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause TritonInference LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) - + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): - text += f'User: {item.content}' + text += f"User: {item.content}" elif isinstance(item, SystemPromptMessage): - text += f'System: {item.content}' + text += f"System: {item.content}" elif isinstance(item, AssistantPromptMessage): - text += f'Assistant: {item.content}' + text += f"Assistant: {item.content}" else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=int(credentials.get('context_length', 2048)), - default=min(512, int(credentials.get('context_length', 2048))), - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + max=int(credentials.get("context_length", 2048)), + default=min(512, int(credentials.get("context_length", 2048))), + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), parameter_rules=rules, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_length", 2048)), }, ) return entity - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - - if 'stream' in credentials and not bool(credentials['stream']) and stream: - raise ValueError(f'stream is not supported by model {model}') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + + if "stream" in credentials and not bool(credentials["stream"]) and stream: + raise ValueError(f"stream is not supported by model {model}") try: parameters = {} - if 'temperature' in model_parameters: - parameters['temperature'] = model_parameters['temperature'] - if 'top_p' in model_parameters: - parameters['top_p'] = model_parameters['top_p'] - if 'top_k' in model_parameters: - parameters['top_k'] = model_parameters['top_k'] - if 'presence_penalty' in model_parameters: - parameters['presence_penalty'] = model_parameters['presence_penalty'] - if 'frequency_penalty' in model_parameters: - parameters['frequency_penalty'] = model_parameters['frequency_penalty'] + if "temperature" in model_parameters: + parameters["temperature"] = model_parameters["temperature"] + if "top_p" in model_parameters: + parameters["top_p"] = model_parameters["top_p"] + if "top_k" in model_parameters: + parameters["top_k"] = model_parameters["top_k"] + if "presence_penalty" in model_parameters: + parameters["presence_penalty"] = model_parameters["presence_penalty"] + if "frequency_penalty" in model_parameters: + parameters["frequency_penalty"] = model_parameters["frequency_penalty"] - response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={ - 'text_input': self._convert_prompt_message_to_text(prompt_messages), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'parameters': { - 'stream': False, - **parameters + response = post( + str(URL(credentials["server_url"]) / "v2" / "models" / model / "generate"), + json={ + "text_input": self._convert_prompt_message_to_text(prompt_messages), + "max_tokens": model_parameters.get("max_tokens", 512), + "parameters": {"stream": False, **parameters}, }, - }, timeout=(10, 120)) + timeout=(10, 120), + ) response.raise_for_status() if response.status_code != 200: - raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}') - + raise InvokeBadRequestError(f"Invoke failed with status code {response.status_code}, {response.text}") + if stream: - return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) - return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) except Exception as ex: - raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}') - - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> LLMResult: + raise InvokeConnectionError(f"An error occurred during connection: {str(ex)}") + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) usage.completion_tokens = self._get_num_tokens_by_gpt2(text) return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text - ), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> Generator: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -233,13 +258,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=text - ), - usage=usage - ) + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text), usage=usage), ) @property @@ -253,15 +272,9 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - ], - InvokeRateLimitError: [ - ], - InvokeAuthorizationError: [ - ], - InvokeBadRequestError: [ - ValueError - ] - } \ No newline at end of file + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [ValueError], + } diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py index 06846825ab..d85f7c82e7 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py @@ -4,6 +4,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class XinferenceAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 13b73181e9..47ebaccd84 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,4 +1,3 @@ - from collections.abc import Mapping import openai @@ -20,13 +19,13 @@ class _CommonUpstage: Transform credentials to kwargs for model instance :param credentials: - :return: + :return: """ credentials_kwargs = { - "api_key": credentials['upstage_api_key'], + "api_key": credentials["upstage_api_key"], "base_url": "https://api.upstage.ai/v1/solar", "timeout": Timeout(315.0, read=300.0, write=20.0, connect=10.0), - "max_retries": 1 + "max_retries": 1, } return credentials_kwargs @@ -53,5 +52,3 @@ class _CommonUpstage: openai.APIError, ], } - - diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index d1ed4619d6..a18ee90624 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -34,17 +34,25 @@ if you are not sure about the structure. {{instructions}} -""" +""" # noqa: E501 + class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ - Model class for Upstage large language model. + Model class for Upstage large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -67,15 +75,25 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, - model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] self._transform_chat_json_prompts( model=model, @@ -86,9 +104,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -98,15 +116,23 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ - Transform json prompts + Transform json prompts """ if stop is None: stop = [] @@ -117,20 +143,29 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): prompt_messages[0] = SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=UPSTAGE_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: - prompt_messages.insert(0, SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=UPSTAGE_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -155,30 +190,31 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): client = OpenAI(**credentials_kwargs) client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=0, - max_tokens=10, - stream=False + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False ) except Exception as e: raise CredentialsValidateFailedError(str(e)) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) extra_model_kwargs = {} if tools: - extra_model_kwargs["functions"] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: extra_model_kwargs["stop"] = stop @@ -198,10 +234,15 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -222,10 +263,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -251,9 +289,14 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -263,7 +306,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None prompt_tokens = 0 completion_tokens = 0 @@ -273,8 +316,8 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -288,8 +331,11 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -311,7 +357,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -322,12 +368,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): final_tool_calls.extend(tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -338,7 +381,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -348,7 +391,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -356,8 +399,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -367,9 +409,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -380,21 +422,19 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -404,14 +444,11 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -429,19 +466,13 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -467,11 +498,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -483,16 +510,17 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Calculate num tokens for solar with Huggingface Solar tokenizer. - Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer + Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer """ tokenizer = self._get_tokenizer() - tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> - tokens_prefix = 1 # <|startoftext|> - tokens_suffix = 3 # <|im_start|>assistant\n + tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> + tokens_prefix = 1 # <|startoftext|> + tokens_suffix = 3 # <|im_start|>assistant\n num_tokens = 0 num_tokens += tokens_prefix @@ -502,10 +530,10 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "tool_calls": @@ -538,37 +566,37 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += len(tokenizer.encode('type')) - num_tokens += len(tokenizer.encode('function')) + num_tokens += len(tokenizer.encode("type")) + num_tokens += len(tokenizer.encode("function")) # calculate num tokens for function object - num_tokens += len(tokenizer.encode('name')) + num_tokens += len(tokenizer.encode("name")) num_tokens += len(tokenizer.encode(tool.name)) - num_tokens += len(tokenizer.encode('description')) + num_tokens += len(tokenizer.encode("description")) num_tokens += len(tokenizer.encode(tool.description)) parameters = tool.parameters - num_tokens += len(tokenizer.encode('parameters')) - if 'title' in parameters: - num_tokens += len(tokenizer.encode('title')) + num_tokens += len(tokenizer.encode("parameters")) + if "title" in parameters: + num_tokens += len(tokenizer.encode("title")) num_tokens += len(tokenizer.encode(parameters.get("title"))) - num_tokens += len(tokenizer.encode('type')) + num_tokens += len(tokenizer.encode("type")) num_tokens += len(tokenizer.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(tokenizer.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(tokenizer.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(tokenizer.encode(key)) for field_key, field_value in value.items(): num_tokens += len(tokenizer.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(tokenizer.encode(enum_field)) else: num_tokens += len(tokenizer.encode(field_key)) num_tokens += len(tokenizer.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(tokenizer.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(tokenizer.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(tokenizer.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 05ae8665d6..edd4a36d98 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -18,6 +18,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): """ Model class for Upstage text embedding model. """ + def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") @@ -53,9 +54,9 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): for i, text in enumerate(texts): token = tokenizer.encode(text, add_special_tokens=False).tokens for j in range(0, len(token), context_size): - tokens += [token[j:j+context_size]] + tokens += [token[j : j + context_size]] indices += [i] - + batched_embeddings = [] _iter = range(0, len(tokens), max_chunks) @@ -63,20 +64,20 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, client=client, - texts=tokens[i:i+max_chunks], + texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch - + results: list[list[list[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i])) - + for i in range(len(texts)): _result = results[i] if len(_result) == 0: @@ -91,15 +92,11 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() - - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: tokenizer = self._get_tokenizer() """ @@ -122,7 +119,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): total_num_tokens += len(tokenized_text) return total_num_tokens - + def validate_credentials(self, model: str, credentials: Mapping) -> None: """ Validate model credentials @@ -137,16 +134,13 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model :param model: model name @@ -155,17 +149,19 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): :param extra_model_kwargs: extra model kwargs :return: embeddings and used tokens """ - response = client.embeddings.create( - model=model, - input=texts, - **extra_model_kwargs - ) + response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs) + + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": + return ( + [ + list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) + for embedding in response.data + ], + response.usage.total_tokens, + ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': - return ([list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) for embedding in response.data], response.usage.total_tokens) - return [data.embedding for data in response.data], response.usage.total_tokens - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -176,10 +172,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): :return: usage """ input_price_info = self.get_price( - model=model, - credentials=credentials, - tokens=tokens, - price_type=PriceType.INPUT + model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT ) usage = EmbeddingUsage( @@ -189,7 +182,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/upstage/upstage.py b/api/core/model_runtime/model_providers/upstage/upstage.py index 56c91c0061..e45d4aae19 100644 --- a/api/core/model_runtime/model_providers/upstage/upstage.py +++ b/api/core/model_runtime/model_providers/upstage/upstage.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class UpstageProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,14 +18,10 @@ class UpstageProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model="solar-1-mini-chat", - credentials=credentials - ) + model_instance.validate_credentials(model="solar-1-mini-chat", credentials=credentials) except CredentialsValidateFailedError as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e except Exception as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e - diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index af6ec3937c..da69b7cdf3 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -5,7 +5,6 @@ import logging from collections.abc import Generator from typing import Optional, Union, cast -import google.api_core.exceptions as exceptions import google.auth.transport.requests import vertexai.generative_models as glm from anthropic import AnthropicVertex, Stream @@ -17,6 +16,7 @@ from anthropic.types import ( MessageStopEvent, MessageStreamEvent, ) +from google.api_core import exceptions from google.cloud import aiplatform from google.oauth2 import service_account from PIL import Image @@ -49,12 +49,17 @@ logger = logging.getLogger(__name__) class VertexAiLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -74,8 +79,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # invoke Gemini model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate_anthropic( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke Anthropic large language model @@ -92,7 +105,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) project_id = credentials["vertex_project_id"] SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] - token = '' + token = "" # get access token from service account credential if service_account_info: @@ -101,41 +114,35 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): credentials.refresh(request) token = credentials.token - # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region - if 'opus' or 'claude-3-5-sonnet' in model: - location = 'us-east5' + # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available + # in us-central1 region + if "opus" in model or "claude-3-5-sonnet" in model: + location = "us-east5" else: - location = 'us-central1' - + location = "us-central1" + # use access token to authenticate if token: - client = AnthropicVertex( - region=location, - project_id=project_id, - access_token=token - ) - # When access token is empty, try to use the Google Cloud VM's built-in service account or the GOOGLE_APPLICATION_CREDENTIALS environment variable + client = AnthropicVertex(region=location, project_id=project_id, access_token=token) + # When access token is empty, try to use the Google Cloud VM's built-in service account + # or the GOOGLE_APPLICATION_CREDENTIALS environment variable else: client = AnthropicVertex( - region=location, + region=location, project_id=project_id, ) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system response = client.messages.create( - model=model, - messages=prompt_message_dicts, - stream=stream, - **model_parameters, - **extra_model_kwargs + model=model, messages=prompt_message_dicts, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -143,8 +150,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return self._handle_claude_response(model, credentials, response, prompt_messages) - def _handle_claude_response(self, model: str, credentials: dict, response: Message, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_claude_response( + self, model: str, credentials: dict, response: Message, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -156,9 +164,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.content[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.content[0].text) # calculate num tokens if response.usage: @@ -175,16 +181,18 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_claude_stream_response( + self, + model: str, + credentials: dict, + response: Stream[MessageStreamEvent], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -196,7 +204,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -217,18 +225,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='' - ), + message=AssistantPromptMessage(content=""), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text or "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text or "", ) index = chunk.index yield LLMResultChunk( @@ -237,12 +243,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) except Exception as ex: raise InvokeError(str(ex)) - def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_claude_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -262,10 +270,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -281,7 +286,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -295,13 +300,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): first_loop = True for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() if first_loop: - system=message.content - first_loop=False + system = message.content + first_loop = False else: - system+="\n" - system+=message.content + system += "\n" + system += message.content prompt_message_dicts = [] for message in prompt_messages: @@ -323,10 +328,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -336,7 +338,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -344,17 +346,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) @@ -370,8 +370,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -384,7 +389,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -394,13 +399,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -416,14 +418,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -435,20 +439,25 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -462,7 +471,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop @@ -494,26 +503,21 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): else: history.append(content) - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } - google_model = glm.GenerativeModel( - model_name=model, - system_instruction=system_instruction - ) + google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, - generation_config=glm.GenerationConfig( - **config_kwargs - ), + generation_config=glm.GenerationConfig(**config_kwargs), stream=stream, safety_settings=safety_settings, - tools=self._convert_tools_to_glm_tool(tools) if tools else None + tools=self._convert_tools_to_glm_tool(tools) if tools else None, ) if stream: @@ -521,8 +525,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -533,9 +538,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.candidates[0].content.parts[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.candidates[0].content.parts[0].text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -554,8 +557,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -568,9 +572,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): index = -1 for chunk in response: for part in chunk.candidates[0].content.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -579,35 +581,31 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - - if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason: + + if not hasattr(chunk, "finish_reason") or not chunk.finish_reason: # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -615,8 +613,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=chunk.candidates[0].finish_reason, - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -631,17 +629,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") @@ -658,7 +652,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): glm_content = glm.Content(role="user", parts=[]) - if (isinstance(message.content, str)): + if isinstance(message.content, str): glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)]) else: parts = [] @@ -666,8 +660,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): if c.type == PromptMessageContentType.TEXT: parts.append(glm.Part.from_text(c.data)) else: - metadata, data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] parts.append(glm.Part.from_data(mime_type=mime_type, data=data)) glm_content = glm.Content(role="user", parts=parts) return glm_content @@ -675,22 +669,33 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): if message.content: glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)]) if message.tool_calls: - glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))]) + glm_content = glm.Content( + role="model", + parts=[ + glm.Part.from_function_response( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ], + ) return glm_content elif isinstance(message, ToolPromptMessage): - glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))]) + glm_content = glm.Content( + role="function", + parts=[ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], + ) return glm_content else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -702,25 +707,20 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :return: Invoke emd = gml.GenerativeModel(model) error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -736,5 +736,5 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 2404ba5894..519373a7f3 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -29,9 +29,9 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): Model class for Vertex AI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,23 +51,12 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): client = VertexTextEmbeddingModel.from_pretrained(model) - embeddings_batch, embedding_used_tokens = self._embedding_invoke( - client=client, - texts=texts - ) + embeddings_batch, embedding_used_tokens = self._embedding_invoke(client=client, texts=texts) # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=embedding_used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=embedding_used_tokens) - return TextEmbeddingResult( - embeddings=embeddings_batch, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings_batch, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -115,15 +104,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): client = VertexTextEmbeddingModel.from_pretrained(model) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'] - ) + self._embedding_invoke(model=model, client=client, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore + def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore """ Invoke embedding model @@ -154,10 +139,7 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -168,14 +150,14 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -183,15 +165,15 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py index 3cbfb088d1..466a86fd36 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py @@ -20,12 +20,9 @@ class VertexAiProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-1.0-pro-002` model for validate, - model_instance.validate_credentials( - model='gemini-1.0-pro-002', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-1.0-pro-002", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index a4d89dabcb..cfe21e4b9f 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -69,31 +69,26 @@ class ArkClientV3: def from_credentials(cls, credentials): """Initialize the client using the credentials provided.""" args = { - "base_url": credentials['api_endpoint_host'], - "region": credentials['volc_region'], + "base_url": credentials["api_endpoint_host"], + "region": credentials["volc_region"], } if credentials.get("auth_method") == "api_key": args = { **args, - "api_key": credentials['volc_api_key'], + "api_key": credentials["volc_api_key"], } else: args = { **args, - "ak": credentials['volc_access_key_id'], - "sk": credentials['volc_secret_access_key'], + "ak": credentials["volc_access_key_id"], + "sk": credentials["volc_secret_access_key"], } if cls.is_compatible_with_legacy(credentials): - args = { - **args, - "base_url": DEFAULT_V3_ENDPOINT - } + args = {**args, "base_url": DEFAULT_V3_ENDPOINT} - client = ArkClientV3( - **args - ) - client.endpoint_id = credentials['endpoint_id'] + client = ArkClientV3(**args) + client.endpoint_id = credentials["endpoint_id"] return client @staticmethod @@ -107,54 +102,48 @@ class ArkClientV3: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - content.append(ChatCompletionContentPartTextParam( - text=message_content.text, - type='text', - )) + content.append( + ChatCompletionContentPartTextParam( + text=message_content.text, + type="text", + ) + ) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append(ChatCompletionContentPartImageParam( - image_url=ImageURL( - url=image_data, - detail=message_content.detail.value, - ), - type='image_url', - )) - message_dict = ChatCompletionUserMessageParam( - role='user', - content=content - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type="image_url", + ) + ) + message_dict = ChatCompletionUserMessageParam(role="user", content=content) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = ChatCompletionAssistantMessageParam( content=message.content, - role='assistant', - tool_calls=None if not message.tool_calls else [ + role="assistant", + tool_calls=None + if not message.tool_calls + else [ ChatCompletionMessageToolCallParam( id=call.id, - function=Function( - name=call.function.name, - arguments=call.function.arguments - ), - type='function' - ) for call in message.tool_calls - ] + function=Function(name=call.function.name, arguments=call.function.arguments), + type="function", + ) + for call in message.tool_calls + ], ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = ChatCompletionSystemMessageParam( - content=message.content, - role='system' - ) + message_dict = ChatCompletionSystemMessageParam(content=message.content, role="system") elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = ChatCompletionToolMessageParam( - content=message.content, - role='tool', - tool_call_id=message.tool_call_id + content=message.content, role="tool", tool_call_id=message.tool_call_id ) else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -164,23 +153,25 @@ class ArkClientV3: @staticmethod def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: return ChatCompletionToolParam( - type='function', + type="function", function=FunctionDefinition( name=message.name, description=message.description, parameters=message.parameters, - ) + ), ) - def chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> ChatCompletion: + def chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: """Block chat""" return self.ark.chat.completions.create( model=self.endpoint_id, @@ -194,15 +185,17 @@ class ArkClientV3: temperature=temperature, ) - def stream_chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> Generator[ChatCompletionChunk]: + def stream_chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: """Stream chat""" chunks = self.ark.chat.completions.create( stream=True, @@ -215,11 +208,9 @@ class ArkClientV3: presence_penalty=presence_penalty, top_p=top_p, temperature=temperature, + stream_options={"include_usage": True}, ) - for chunk in chunks: - if not chunk.choices: - continue - yield chunk + yield from chunks def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse: return self.ark.embeddings.create(model=self.endpoint_id, input=texts) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py index 1978c11680..266f1216f8 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService class MaaSClient(MaasService): @@ -25,12 +25,12 @@ class MaaSClient(MaasService): self.endpoint_id = endpoint_id @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] + def from_credential(cls, credentials: dict) -> "MaaSClient": + host = credentials["api_endpoint_host"] + region = credentials["volc_region"] + ak = credentials["volc_access_key_id"] + sk = credentials["volc_secret_access_key"] + endpoint_id = credentials["endpoint_id"] client = cls(host, region) client.set_endpoint_id(endpoint_id) @@ -40,8 +40,8 @@ class MaaSClient(MaasService): def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + "parameters": params, + "messages": [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], **extra_model_kwargs, } if not stream: @@ -55,9 +55,7 @@ class MaaSClient(MaasService): ) def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } + req = {"input": texts} return super().embeddings(self.endpoint_id, req) @staticmethod @@ -65,49 +63,40 @@ class MaaSClient(MaasService): if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + message_dict = {"role": ChatRole.USER, "content": message.content} else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + raise ValueError("Content object type only support image_url") elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + { + "type": "image_url", + "image_url": { + "url": "", + "image_bytes": image_data, + "detail": message_content.detail, + }, } - }) + ) - message_dict = {'role': ChatRole.USER, 'content': content} + message_dict = {"role": ChatRole.USER, "content": content} elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} + message_dict = {"role": ChatRole.ASSISTANT, "content": message.content} if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls + message_dict["tool_calls"] = [ + {"name": call.function.name, "arguments": call.function.arguments} for call in message.tool_calls ] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = {"role": ChatRole.SYSTEM, "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = {"role": ChatRole.FUNCTION, "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -117,7 +106,7 @@ class MaaSClient(MaasService): def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: try: resp = fn() - except MaasException as e: + except MaasError as e: raise wrap_error(e) return resp @@ -130,5 +119,5 @@ class MaaSClient(MaasService): "name": tool.name, "description": tool.description, "parameters": tool.parameters, - } + }, } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 21ffaf1258..91dbe21a61 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -1,144 +1,144 @@ -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError -class ClientSDKRequestError(MaasException): +class ClientSDKRequestError(MaasError): pass -class SignatureDoesNotMatch(MaasException): +class SignatureDoesNotMatchError(MaasError): pass -class RequestTimeout(MaasException): +class RequestTimeoutError(MaasError): pass -class ServiceConnectionTimeout(MaasException): +class ServiceConnectionTimeoutError(MaasError): pass -class MissingAuthenticationHeader(MaasException): +class MissingAuthenticationHeaderError(MaasError): pass -class AuthenticationHeaderIsInvalid(MaasException): +class AuthenticationHeaderIsInvalidError(MaasError): pass -class InternalServiceError(MaasException): +class InternalServiceError(MaasError): pass -class MissingParameter(MaasException): +class MissingParameterError(MaasError): pass -class InvalidParameter(MaasException): +class InvalidParameterError(MaasError): pass -class AuthenticationExpire(MaasException): +class AuthenticationExpireError(MaasError): pass -class EndpointIsInvalid(MaasException): +class EndpointIsInvalidError(MaasError): pass -class EndpointIsNotEnable(MaasException): +class EndpointIsNotEnableError(MaasError): pass -class ModelNotSupportStreamMode(MaasException): +class ModelNotSupportStreamModeError(MaasError): pass -class ReqTextExistRisk(MaasException): +class ReqTextExistRiskError(MaasError): pass -class RespTextExistRisk(MaasException): +class RespTextExistRiskError(MaasError): pass -class EndpointRateLimitExceeded(MaasException): +class EndpointRateLimitExceededError(MaasError): pass -class ServiceConnectionRefused(MaasException): +class ServiceConnectionRefusedError(MaasError): pass -class ServiceConnectionClosed(MaasException): +class ServiceConnectionClosedError(MaasError): pass -class UnauthorizedUserForEndpoint(MaasException): +class UnauthorizedUserForEndpointError(MaasError): pass -class InvalidEndpointWithNoURL(MaasException): +class InvalidEndpointWithNoURLError(MaasError): pass -class EndpointAccountRpmRateLimitExceeded(MaasException): +class EndpointAccountRpmRateLimitExceededError(MaasError): pass -class EndpointAccountTpmRateLimitExceeded(MaasException): +class EndpointAccountTpmRateLimitExceededError(MaasError): pass -class ServiceResourceWaitQueueFull(MaasException): +class ServiceResourceWaitQueueFullError(MaasError): pass -class EndpointIsPending(MaasException): +class EndpointIsPendingError(MaasError): pass -class ServiceNotOpen(MaasException): +class ServiceNotOpenError(MaasError): pass AuthErrors = { - 'SignatureDoesNotMatch': SignatureDoesNotMatch, - 'MissingAuthenticationHeader': MissingAuthenticationHeader, - 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid, - 'AuthenticationExpire': AuthenticationExpire, - 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint, + "SignatureDoesNotMatch": SignatureDoesNotMatchError, + "MissingAuthenticationHeader": MissingAuthenticationHeaderError, + "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError, + "AuthenticationExpire": AuthenticationExpireError, + "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError, } BadRequestErrors = { - 'MissingParameter': MissingParameter, - 'InvalidParameter': InvalidParameter, - 'EndpointIsInvalid': EndpointIsInvalid, - 'EndpointIsNotEnable': EndpointIsNotEnable, - 'ModelNotSupportStreamMode': ModelNotSupportStreamMode, - 'ReqTextExistRisk': ReqTextExistRisk, - 'RespTextExistRisk': RespTextExistRisk, - 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL, - 'ServiceNotOpen': ServiceNotOpen, + "MissingParameter": MissingParameterError, + "InvalidParameter": InvalidParameterError, + "EndpointIsInvalid": EndpointIsInvalidError, + "EndpointIsNotEnable": EndpointIsNotEnableError, + "ModelNotSupportStreamMode": ModelNotSupportStreamModeError, + "ReqTextExistRisk": ReqTextExistRiskError, + "RespTextExistRisk": RespTextExistRiskError, + "InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError, + "ServiceNotOpen": ServiceNotOpenError, } RateLimitErrors = { - 'EndpointRateLimitExceeded': EndpointRateLimitExceeded, - 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded, - 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded, + "EndpointRateLimitExceeded": EndpointRateLimitExceededError, + "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError, + "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError, } ServerUnavailableErrors = { - 'InternalServiceError': InternalServiceError, - 'EndpointIsPending': EndpointIsPending, - 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull, + "InternalServiceError": InternalServiceError, + "EndpointIsPending": EndpointIsPendingError, + "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError, } ConnectionErrors = { - 'ClientSDKRequestError': ClientSDKRequestError, - 'RequestTimeout': RequestTimeout, - 'ServiceConnectionTimeout': ServiceConnectionTimeout, - 'ServiceConnectionRefused': ServiceConnectionRefused, - 'ServiceConnectionClosed': ServiceConnectionClosed, + "ClientSDKRequestError": ClientSDKRequestError, + "RequestTimeout": RequestTimeoutError, + "ServiceConnectionTimeout": ServiceConnectionTimeoutError, + "ServiceConnectionRefused": ServiceConnectionRefusedError, + "ServiceConnectionClosed": ServiceConnectionClosedError, } ErrorCodeMap = { @@ -150,7 +150,7 @@ ErrorCodeMap = { } -def wrap_error(e: MaasException) -> Exception: +def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py index 64f342f16e..8b3eb157be 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -1,4 +1,4 @@ from .common import ChatRole -from .maas import MaasException, MaasService +from .maas import MaasError, MaasService -__all__ = ['MaasService', 'ChatRole', 'MaasException'] +__all__ = ["MaasService", "ChatRole", "MaasError"] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 053432a089..c22bf8e76d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -1,5 +1,6 @@ # coding : utf-8 import datetime +from itertools import starmap import pytz @@ -8,12 +9,12 @@ from .util import Util class MetaData: def __init__(self): - self.algorithm = '' - self.credential_scope = '' - self.signed_headers = '' - self.date = '' - self.region = '' - self.service = '' + self.algorithm = "" + self.credential_scope = "" + self.signed_headers = "" + self.date = "" + self.region = "" + self.service = "" def set_date(self, date): self.date = date @@ -36,23 +37,23 @@ class MetaData: class SignResult: def __init__(self): - self.xdate = '' - self.xCredential = '' - self.xAlgorithm = '' - self.xSignedHeaders = '' - self.xSignedQueries = '' - self.xSignature = '' - self.xContextSha256 = '' - self.xSecurityToken = '' + self.xdate = "" + self.xCredential = "" + self.xAlgorithm = "" + self.xSignedHeaders = "" + self.xSignedQueries = "" + self.xSignature = "" + self.xContextSha256 = "" + self.xSecurityToken = "" - self.authorization = '' + self.authorization = "" def __str__(self): - return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()]) + return "\n".join(list(starmap("{}:{}".format, self.__dict__.items()))) class Credentials: - def __init__(self, ak, sk, service, region, session_token=''): + def __init__(self, ak, sk, service, region, session_token=""): self.ak = ak self.sk = sk self.service = service @@ -72,73 +73,87 @@ class Credentials: class Signer: @staticmethod def sign(request, credentials): - if request.path == '': - request.path = '/' - if request.method != 'GET' and not ('Content-Type' in request.headers): - request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' + if request.path == "": + request.path = "/" + if request.method != "GET" and "Content-Type" not in request.headers: + request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8" format_date = Signer.get_current_format_date() - request.headers['X-Date'] = format_date - if credentials.session_token != '': - request.headers['X-Security-Token'] = credentials.session_token + request.headers["X-Date"] = format_date + if credentials.session_token != "": + request.headers["X-Security-Token"] = credentials.session_token md = MetaData() - md.set_algorithm('HMAC-SHA256') + md.set_algorithm("HMAC-SHA256") md.set_service(credentials.service) md.set_region(credentials.region) md.set_date(format_date[:8]) hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) - md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request'])) + md.set_credential_scope("/".join([md.date, md.region, md.service, "request"])) - signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) + signing_str = "\n".join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) - request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials) - return + request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) @staticmethod def hashed_canonical_request_v4(request, meta): body_hash = Util.sha256(request.body) - request.headers['X-Content-Sha256'] = body_hash + request.headers["X-Content-Sha256"] = body_hash signed_headers = {} for key in request.headers: - if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): + if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"): signed_headers[key.lower()] = request.headers[key] - if 'host' in signed_headers: - v = signed_headers['host'] - if v.find(':') != -1: - split = v.split(':') + if "host" in signed_headers: + v = signed_headers["host"] + if v.find(":") != -1: + split = v.split(":") port = split[1] - if str(port) == '80' or str(port) == '443': - signed_headers['host'] = split[0] + if str(port) == "80" or str(port) == "443": + signed_headers["host"] = split[0] - signed_str = '' + signed_str = "" for key in sorted(signed_headers.keys()): - signed_str += key + ':' + signed_headers[key] + '\n' + signed_str += key + ":" + signed_headers[key] + "\n" - meta.set_signed_headers(';'.join(sorted(signed_headers.keys()))) + meta.set_signed_headers(";".join(sorted(signed_headers.keys()))) - canonical_request = '\n'.join( - [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str, - meta.signed_headers, body_hash]) + canonical_request = "\n".join( + [ + request.method, + Util.norm_uri(request.path), + Util.norm_query(request.query), + signed_str, + meta.signed_headers, + body_hash, + ] + ) return Util.sha256(canonical_request) @staticmethod def get_signing_secret_key_v4(sk, date, region, service): - date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date) + date = Util.hmac_sha256(bytes(sk, encoding="utf-8"), date) region = Util.hmac_sha256(date, region) service = Util.hmac_sha256(region, service) - return Util.hmac_sha256(service, 'request') + return Util.hmac_sha256(service, "request") @staticmethod def build_auth_header_v4(signature, meta, credentials): - credential = credentials.ak + '/' + meta.credential_scope - return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature + credential = credentials.ak + "/" + meta.credential_scope + return ( + meta.algorithm + + " Credential=" + + credential + + ", SignedHeaders=" + + meta.signed_headers + + ", Signature=" + + signature + ) @staticmethod def get_current_format_date(): - return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ") + return datetime.datetime.now(tz=pytz.timezone("UTC")).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py index 7271ae63fd..33c41f3eb3 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py @@ -6,7 +6,7 @@ import requests from .auth import Signer -VERSION = 'v1.0.137' +VERSION = "v1.0.137" class Service: @@ -31,7 +31,7 @@ class Service: self.service_info.scheme = scheme def get(self, api, params, doseq=0): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] @@ -40,52 +40,61 @@ class Service: Signer.sign(r, self.service_info.credentials) url = r.build(doseq) - resp = self.session.get(url, headers=r.headers, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.get( + url, headers=r.headers, timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout) + ) if resp.status_code == 200: return resp.text else: raise Exception(resp.text) def post(self, api, params, form): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/x-www-form-urlencoded' + r.headers["Content-Type"] = "application/x-www-form-urlencoded" r.form = self.merge(api_info.form, form) r.body = urlencode(r.form, True) Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.form, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.form, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return resp.text else: raise Exception(resp.text) def json(self, api, params, body): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/json' + r.headers["Content-Type"] = "application/json" r.body = body Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.body, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return json.dumps(resp.json()) else: raise Exception(resp.text.encode("utf-8")) def put(self, url, file_path, headers): - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: resp = self.session.put(url, headers=headers, data=f) if resp.status_code == 200: return True, resp.text.encode("utf-8") @@ -105,7 +114,7 @@ class Service: params[key] = str(params[key]) elif type(params[key]) == list: if not doseq: - params[key] = ','.join(params[key]) + params[key] = ",".join(params[key]) connection_timeout = self.service_info.connection_timeout socket_timeout = self.service_info.socket_timeout @@ -117,8 +126,8 @@ class Service: r.set_socket_timeout(socket_timeout) headers = self.merge(api_info.header, self.service_info.header) - headers['Host'] = self.service_info.host - headers['User-Agent'] = 'volc-sdk-python/' + VERSION + headers["Host"] = self.service_info.host + headers["User-Agent"] = "volc-sdk-python/" + VERSION r.set_headers(headers) query = self.merge(api_info.query, params) @@ -143,13 +152,13 @@ class Service: class Request: def __init__(self): - self.schema = '' - self.method = '' - self.host = '' - self.path = '' + self.schema = "" + self.method = "" + self.host = "" + self.path = "" self.headers = OrderedDict() self.query = OrderedDict() - self.body = '' + self.body = "" self.form = {} self.connection_timeout = 0 self.socket_timeout = 0 @@ -182,11 +191,11 @@ class Request: self.socket_timeout = socket_timeout def build(self, doseq=0): - return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq) + return self.schema + "://" + self.host + self.path + "?" + urlencode(self.query, doseq) class ServiceInfo: - def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'): + def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme="http"): self.host = host self.header = header self.credentials = credentials @@ -204,4 +213,4 @@ class ApiInfo: self.header = header def __str__(self): - return 'method: ' + self.method + ', path: ' + self.path + return "method: " + self.method + ", path: " + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py index 7eb5fdfa91..178d63714e 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py @@ -1,5 +1,6 @@ import hashlib import hmac +import operator from functools import reduce from urllib.parse import quote @@ -7,28 +8,28 @@ from urllib.parse import quote class Util: @staticmethod def norm_uri(path): - return quote(path).replace('%2F', '/').replace('+', '%20') + return quote(path).replace("%2F", "/").replace("+", "%20") @staticmethod def norm_query(params): - query = '' + query = "" for key in sorted(params.keys()): if type(params[key]) == list: for k in params[key]: - query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(k, safe="-_.~") + "&" else: - query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(params[key], safe="-_.~") + "&" query = query[:-1] - return query.replace('+', '%20') + return query.replace("+", "%20") @staticmethod def hmac_sha256(key, content): - return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest() + return hmac.new(key, bytes(content, encoding="utf-8"), hashlib.sha256).digest() @staticmethod def sha256(content): if isinstance(content, str) is True: - return hashlib.sha256(content.encode('utf-8')).hexdigest() + return hashlib.sha256(content.encode("utf-8")).hexdigest() else: return hashlib.sha256(content).hexdigest() @@ -36,8 +37,8 @@ class Util: def to_hex(content): lst = [] for ch in content: - hv = hex(ch).replace('0x', '') + hv = hex(ch).replace("0x", "") if len(hv) == 1: - hv = '0' + hv + hv = "0" + hv lst.append(hv) - return reduce(lambda x, y: x + y, lst) + return reduce(operator.add, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py index 8b14d026d9..3825fd6574 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py @@ -43,9 +43,7 @@ def json_to_object(json_str, req_id=None): def gen_req_id(): - return datetime.now().strftime("%Y%m%d%H%M%S") + format( - random.randint(0, 2 ** 64 - 1), "020X" - ) + return datetime.now().strftime("%Y%m%d%H%M%S") + format(random.randint(0, 2**64 - 1), "020X") class SSEDecoder: @@ -53,13 +51,13 @@ class SSEDecoder: self.source = source def _read(self): - data = b'' + data = b"" for chunk in self.source: for line in chunk.splitlines(True): data += line - if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): yield data - data = b'' + data = b"" if data: yield data @@ -67,13 +65,13 @@ class SSEDecoder: for chunk in self._read(): for line in chunk.splitlines(): # skip comment - if line.startswith(b':'): + if line.startswith(b":"): continue - if b':' in line: - field, value = line.split(b':', 1) + if b":" in line: + field, value = line.split(b":", 1) else: - field, value = line, b'' + field, value = line, b"" - if field == b'data' and len(value) > 0: + if field == b"data" and len(value) > 0: yield value diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py index 3cbe9d9f09..a3836685f1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py @@ -9,9 +9,7 @@ from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object class MaasService(Service): def __init__(self, host, region, connection_timeout=60, socket_timeout=60): - service_info = self.get_service_info( - host, region, connection_timeout, socket_timeout - ) + service_info = self.get_service_info(host, region, connection_timeout, socket_timeout) self._apikey = None api_info = self.get_api_info() super().__init__(service_info, api_info) @@ -35,9 +33,7 @@ class MaasService(Service): def get_api_info(): api_info = { "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), - "embeddings": ApiInfo( - "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {} - ), + "embeddings": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}), } return api_info @@ -52,9 +48,7 @@ class MaasService(Service): try: req["stream"] = True - res = self._call( - endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True - ) + res = self._call(endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True) decoder = SSEDecoder(res) @@ -64,13 +58,12 @@ class MaasService(Service): return try: - res = json_to_object( - str(data, encoding="utf-8"), req_id=req_id) + res = json_to_object(str(data, encoding="utf-8"), req_id=req_id) except Exception: raise if res.error is not None and res.error.code_n != 0: - raise MaasException( + raise MaasError( res.error.code_n, res.error.code, res.error.message, @@ -79,7 +72,7 @@ class MaasService(Service): yield res return iter_fn() - except MaasException: + except MaasError: raise except Exception as e: raise new_client_sdk_request_error(str(e)) @@ -95,29 +88,28 @@ class MaasService(Service): apikey = self._apikey try: - res = self._call(endpoint_id, api, req_id, params, - json.dumps(req).encode("utf-8"), apikey) + res = self._call(endpoint_id, api, req_id, params, json.dumps(req).encode("utf-8"), apikey) resp = dict_to_object(res.json()) if resp and isinstance(resp, dict): resp["req_id"] = req_id return resp - except MaasException as e: + except MaasError as e: raise e except Exception as e: raise new_client_sdk_request_error(str(e), req_id) def _validate(self, api, req_id): credentials_exist = ( - self.service_info.credentials is not None and - self.service_info.credentials.sk is not None and - self.service_info.credentials.ak is not None + self.service_info.credentials is not None + and self.service_info.credentials.sk is not None + and self.service_info.credentials.ak is not None ) if not self._apikey and not credentials_exist: raise new_client_sdk_request_error("no valid credential", req_id) - if not (api in self.api_info): + if api not in self.api_info: raise new_client_sdk_request_error("no such api", req_id) def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False): @@ -150,22 +142,19 @@ class MaasService(Service): raw = res.text.encode() res.close() try: - resp = json_to_object( - str(raw, encoding="utf-8"), req_id=req_id) + resp = json_to_object(str(raw, encoding="utf-8"), req_id=req_id) except Exception: raise new_client_sdk_request_error(raw, req_id) if resp.error: - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, req_id - ) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id) else: raise new_client_sdk_request_error(resp, req_id) return res -class MaasException(Exception): +class MaasError(Exception): def __init__(self, code_n, code, message, req_id): self.code_n = code_n self.code = code @@ -173,15 +162,17 @@ class MaasException(Exception): self.req_id = req_id def __str__(self): - return ("Detailed exception information is listed below.\n" + - "req_id: {}\n" + - "code_n: {}\n" + - "code: {}\n" + - "message: {}").format(self.req_id, self.code_n, self.code, self.message) + return ( + "Detailed exception information is listed below.\n" + + "req_id: {}\n" + + "code_n: {}\n" + + "code: {}\n" + + "message: {}" + ).format(self.req_id, self.code_n, self.code, self.message) def new_client_sdk_request_error(raw, req_id=""): - return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) + return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) class BinaryResponseContent: @@ -189,25 +180,19 @@ class BinaryResponseContent: self.response = response self.request_id = request_id - def stream_to_file( - self, - file: str - ) -> None: + def stream_to_file(self, file: str) -> None: is_first = True - error_bytes = b'' + error_bytes = b"" with open(file, mode="wb") as f: for data in self.response: - if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)): + if len(error_bytes) > 0 or (is_first and '"error":' in str(data)): error_bytes += data else: f.write(data) if len(error_bytes) > 0: - resp = json_to_object( - str(error_bytes, encoding="utf-8"), req_id=self.request_id) - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, self.request_id - ) + resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) def iter_bytes(self) -> Iterator[bytes]: yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 996c66e604..dec6c9d789 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) @@ -49,10 +49,17 @@ logger = logging.getLogger(__name__) class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: if ArkClientV3.is_legacy(credentials): return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -71,27 +78,36 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): try: client.chat( { - 'max_new_tokens': 16, - 'temperature': 0.7, - 'top_p': 0.9, - 'top_k': 15, + "max_new_tokens": 16, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 15, }, - [UserPromptMessage(content='ping\nAnswer: ')], + [UserPromptMessage(content="ping\nAnswer: ")], ) - except MaasException as e: + except MaasError as e: raise CredentialsValidateFailedError(e.message) @staticmethod def _validate_credentials_v3(credentials: dict) -> None: client = ArkClientV3.from_credentials(credentials) try: - client.chat(max_tokens=16, temperature=0.7, top_p=0.9, - messages=[UserPromptMessage(content='ping\nAnswer: ')], ) + client.chat( + max_tokens=16, + temperature=0.7, + top_p=0.9, + messages=[UserPromptMessage(content="ping\nAnswer: ")], + ) except Exception as e: raise CredentialsValidateFailedError(e) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if ArkClientV3.is_legacy(credentials): return self._get_num_tokens_v2(prompt_messages) return self._get_num_tokens_v3(prompt_messages) @@ -100,8 +116,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + messages_dict = [MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -113,8 +128,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - ArkClientV3.convert_prompt_message(m) for m in messages] + messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -122,118 +136,126 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): return num_tokens - def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v2( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) req_params = get_v2_req_params(credentials, model_parameters, stop) extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [ - MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools - ] - resp = MaaSClient.wrap_exception( - lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) + extra_model_kwargs["tools"] = [MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools] + resp = MaaSClient.wrap_exception(lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) def _handle_stream_chat_response() -> Generator: for index, r in enumerate(resp): - choices = r['choices'] + choices = r["choices"] if not choices: continue choice = choices[0] - message = choice['message'] + message = choice["message"] usage = None - if r.get('usage'): - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=r['usage']['prompt_tokens'], - completion_tokens=r['usage']['completion_tokens'] - ) + if r.get("usage"): + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=r["usage"]["prompt_tokens"], + completion_tokens=r["usage"]["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] - ), + message=AssistantPromptMessage(content=message["content"] or "", tool_calls=[]), usage=usage, - finish_reason=choice.get('finish_reason'), + finish_reason=choice.get("finish_reason"), ), ) def _handle_chat_response() -> LLMResult: - choices = resp['choices'] + choices = resp["choices"] if not choices: raise ValueError("No choices found") choice = choices[0] - message = choice['message'] + message = choice["message"] # parse tool calls tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: + if message["tool_calls"]: + for call in message["tool_calls"]: tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], + id=call["function"]["name"], + type=call["type"], function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + name=call["function"]["name"], arguments=call["function"]["arguments"] + ), ) tool_calls.append(tool_call) - usage = resp['usage'] + usage = resp["usage"] return LLMResult( model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', + content=message["content"] or "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage["prompt_tokens"], + completion_tokens=usage["completion_tokens"], + ), ) if not stream: return _handle_chat_response() return _handle_stream_chat_response() - def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v3( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = ArkClientV3.from_credentials(credentials) req_params = get_v3_req_params(credentials, model_parameters, stop) if tools: - req_params['tools'] = tools + req_params["tools"] = tools def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: for chunk in chunks: - if not chunk.choices: - continue - choice = chunk.choices[0] - yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=choice.index, + index=0, message=AssistantPromptMessage( - content=choice.delta.content, - tool_calls=[] + content=chunk.choices[0].delta.content if chunk.choices else "", tool_calls=[] ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens - ) if chunk.usage else None, - finish_reason=choice.finish_reason, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + ) + if chunk.usage + else None, + finish_reason=chunk.choices[0].finish_reason if chunk.choices else None, ), ) @@ -248,9 +270,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): id=call.id, type=call.type, function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call.function.name, - arguments=call.function.arguments - ) + name=call.function.name, arguments=call.function.arguments + ), ) tool_calls.append(tool_call) @@ -259,13 +280,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=message.content if message.content else "", + content=message.content or "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ), ) if not stream: @@ -277,72 +300,56 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ model_config = get_model_config(credentials) rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', - type=ParameterType.INT, - min=1, - default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + name="top_k", type=ParameterType.INT, min=1, default=1, label=I18nObject(zh_Hans="Top K", en_US="Top K") ), ParameterRule( - name='presence_penalty', + name="presence_penalty", type=ParameterType.FLOAT, - use_template='presence_penalty', + use_template="presence_penalty", label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", type=ParameterType.FLOAT, - use_template='frequency_penalty', + use_template="frequency_penalty", label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=model_config.properties.max_tokens, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ] @@ -352,9 +359,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index a882f68a36..d8be14b024 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -16,138 +16,127 @@ class ModelConfig(BaseModel): configs: dict[str, ModelConfig] = { - 'Doubao-pro-4k': ModelConfig( + "Doubao-pro-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-4k': ModelConfig( + "Doubao-lite-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-32k': ModelConfig( + "Doubao-pro-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-32k': ModelConfig( + "Doubao-lite-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-128k': ModelConfig( + "Doubao-pro-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Doubao-lite-128k": ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Skylark2-pro-4k': ModelConfig( - properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Skylark2-pro-4k": ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Llama3-8B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-8B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Llama3-70B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-70B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Moonshot-v1-8k': ModelConfig( + "Moonshot-v1-8k": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-32k': ModelConfig( + "Moonshot-v1-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-128k': ModelConfig( + "Moonshot-v1-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B': ModelConfig( + "GLM3-130B": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B-Fin': ModelConfig( + "GLM3-130B-Fin": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], + ), + "Mistral-7B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[] ), - 'Mistral-7B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), - features=[] - ) } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = configs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_tokens=int(credentials.get('max_tokens', 0)), - mode=LLMMode.value_of(credentials.get('mode', 'chat')), + context_size=int(credentials.get("context_size", 0)), + max_tokens=int(credentials.get("max_tokens", 0)), + mode=LLMMode.value_of(credentials.get("mode", "chat")), ), - features=[] + features=[], ) return model_configs -def get_v2_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_prompt_tokens'] = model_configs.properties.context_size - req_params['max_new_tokens'] = model_configs.properties.max_tokens + req_params["max_prompt_tokens"] = model_configs.properties.context_size + req_params["max_new_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_new_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('top_k'): - req_params['top_k'] = model_parameters.get('top_k') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_new_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("top_k"): + req_params["top_k"] = model_parameters.get("top_k") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params -def get_v3_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_tokens'] = model_configs.properties.max_tokens + req_params["max_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 74cf26247c..ce4f0c3ab1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -11,20 +11,18 @@ class ModelConfig(BaseModel): ModelConfigs = { - 'Doubao-embedding': ModelConfig( - properties=ModelProperties(context_size=4096, max_chunks=32) - ), + "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = ModelConfigs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_chunks=int(credentials.get('max_chunks', 0)), + context_size=int(credentials.get("context_size", 0)), + max_chunks=int(credentials.get("max_chunks", 0)), ) ) return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index d54aeeb0b1..9cba2cb879 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) @@ -40,9 +40,9 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): Model class for VolcengineMaaS text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -57,37 +57,27 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): return self._generate_v3(model, credentials, texts, user) - def _generate_v2(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v2( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp['usage']['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp["usage"]["total_tokens"]) - result = TextEmbeddingResult( - model=model, - embeddings=[v['embedding'] for v in resp['data']], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v["embedding"] for v in resp["data"]], usage=usage) return result - def _generate_v3(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v3( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = ArkClientV3.from_credentials(credentials) resp = client.embeddings(texts) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp.usage.total_tokens) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp.usage.total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=[v.embedding for v in resp.data], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v.embedding for v in resp.data], usage=usage) return result @@ -120,13 +110,13 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) - except MaasException as e: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except MaasError as e: raise CredentialsValidateFailedError(e.message) def _validate_credentials_v3(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: raise CredentialsValidateFailedError(e) @@ -150,12 +140,12 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ model_config = get_model_config(credentials) model_properties = { ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, - ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks, } entity = AIModelEntity( model=model, @@ -165,10 +155,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): model_properties=model_properties, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -184,10 +174,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -198,7 +185,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index 017856bdde..d72d1bd83a 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -11,7 +11,7 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import ( RateLimitReachedError, ) -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens: dict[str, "BaiduAccessToken"] = {} baidu_access_tokens_lock = Lock() @@ -22,49 +22,46 @@ class BaiduAccessToken: def __init__(self, api_key: str) -> None: self.api_key = api_key - self.access_token = '' + self.access_token = "" self.expires = datetime.now() + timedelta(days=3) @staticmethod def _get_access_token(api_key: str, secret_key: str) -> str: """ - request access token from Baidu + request access token from Baidu """ try: response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, + url=f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}", + headers={"Content-Type": "application/json", "Accept": "application/json"}, ) except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + raise InvalidAuthenticationError(f"Failed to get access token from Baidu: {e}") resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': + if "error" in resp: + if resp["error"] == "invalid_client": raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': + elif resp["error"] == "unknown_error": raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': + elif resp["error"] == "invalid_request": raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': + elif resp["error"] == "rate_limit_exceeded": raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') else: raise Exception(f'Unknown error: {resp["error_description"]}') - return resp['access_token'] + return resp["access_token"] @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + def get_access_token(api_key: str, secret_key: str) -> "BaiduAccessToken": """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) + LLM from Baidu requires access token to invoke the API. + however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. (avoid memory leak) - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. """ # loop up cache, remove expired access token @@ -98,49 +95,49 @@ class BaiduAccessToken: class _CommonWenxin: api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', - 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', - 'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en', - 'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh', - 'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k', + "ernie-bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-bot-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-3.5-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-3.5-8k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205", + "ernie-3.5-8k-1222": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222", + "ernie-3.5-4k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-3.5-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k", + "ernie-4.0-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-4.0-8k-latest": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-speed-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed", + "ernie-speed-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k", + "ernie-speed-appbuilder": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas", + "ernie-lite-8k-0922": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-lite-8k-0308": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k", + "ernie-character-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-character-8k-0321": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-4.0-turbo-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview", + "yi_34b_chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat", + "embedding-v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1", + "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", + "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", + "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", } function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', - 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview', - 'yi_34b_chat' + "ernie-bot", + "ernie-bot-8k", + "ernie-3.5-8k", + "ernie-3.5-8k-0205", + "ernie-3.5-8k-1222", + "ernie-3.5-4k-0205", + "ernie-3.5-128k", + "ernie-4.0-8k", + "ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview", + "yi_34b_chat", ] - api_key: str = '' - secret_key: str = '' + api_key: str = "" + secret_key: str = "" def __init__(self, api_key: str, secret_key: str): self.api_key = api_key @@ -148,10 +145,7 @@ class _CommonWenxin: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - credentials_kwargs = { - "api_key": credentials['api_key'], - "secret_key": credentials['secret_key'] - } + credentials_kwargs = {"api_key": credentials["api_key"], "secret_key": credentials["secret_key"]} return credentials_kwargs def _handle_error(self, code: int, msg: str): @@ -187,13 +181,13 @@ class _CommonWenxin: 336105: BadRequestError, 336200: InternalServerError, 336303: BadRequestError, - 337006: BadRequestError + 337006: BadRequestError, } if code in error_map: raise error_map[code](msg) else: - raise InternalServerError(f'Unknown error: {msg}') + raise InternalServerError(f"Unknown error: {msg}") def _get_access_token(self) -> str: token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 8109949b1d..07b970f810 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -15,33 +15,39 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import ( class ErnieMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - FUNCTION = 'function' - SYSTEM = 'system' + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + SYSTEM = "system" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - def __init__(self, content: str, role: str = 'user') -> None: + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role + class ErnieBotModel(_CommonWenxin): - - def generate(self, model: str, stream: bool, messages: list[ErnieMessage], - parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ - stop: list[str], user: str) \ - -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: - + def generate( + self, + model: str, + stream: bool, + messages: list[ErnieMessage], + parameters: dict[str, Any], + timeout: int, + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters self._check_parameters(model, parameters, tools, stop) @@ -49,22 +55,23 @@ class ErnieBotModel(_CommonWenxin): access_token = self._get_access_token() # generate request body - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" # clone messages messages_cloned = self._copy_messages(messages=messages) # build body - body = self._build_request_body(model, messages=messages_cloned, stream=stream, - parameters=parameters, tools=tools, stop=stop, user=user) + body = self._build_request_body( + model, messages=messages_cloned, stream=stream, parameters=parameters, tools=tools, stop=stop, user=user + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url=url, data=dumps(body), headers=headers, stream=stream) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") if stream: return self._handle_chat_stream_generate_response(resp) @@ -73,10 +80,11 @@ class ErnieBotModel(_CommonWenxin): def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str]) -> None: + def _check_parameters( + self, model: str, parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str] + ) -> None: if model not in self.api_bases: - raise BadRequestError(f'Invalid model: {model}') + raise BadRequestError(f"Invalid model: {model}") # if model not in self.function_calling_supports and tools is not None and len(tools) > 0: # raise BadRequestError(f'Model {model} does not support calling function.') @@ -85,86 +93,106 @@ class ErnieBotModel(_CommonWenxin): # so, we just disable function calling for now. if tools is not None and len(tools) > 0: - raise BadRequestError('function calling is not supported yet.') + raise BadRequestError("function calling is not supported yet.") if stop is not None: if len(stop) > 4: - raise BadRequestError('stop list should not exceed 4 items.') + raise BadRequestError("stop list should not exceed 4 items.") for s in stop: if len(s) > 20: - raise BadRequestError('stop item should not exceed 20 characters.') + raise BadRequestError("stop item should not exceed 20 characters.") - def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: + def _build_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], tools: list[PromptMessageTool], - stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_function_calling_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role == 'function': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role == "function": + raise BadRequestError("The first message should be user message.") """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_chat_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) == 0: - raise BadRequestError('The number of messages should not be zero.') + raise BadRequestError("The number of messages should not be zero.") # check if the first element is system, shift it - system_message = '' - if messages[0].role == 'system': + system_message = "" + if messages[0].role == "system": message = messages.pop(0) system_message = message.content if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role != 'user': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role != "user": + raise BadRequestError("The first message should be user message.") body = { - 'messages': [message.to_dict() for message in messages], - 'stream': stream, - 'stop': stop, - 'user_id': user, - **parameters + "messages": [message.to_dict() for message in messages], + "stream": stream, + "stop": stop, + "user_id": user, + **parameters, } - if 'max_tokens' in parameters and type(parameters['max_tokens']) == int: - body['max_output_tokens'] = parameters['max_tokens'] + if "max_tokens" in parameters and type(parameters["max_tokens"]) == int: + body["max_output_tokens"] = parameters["max_tokens"] - if 'presence_penalty' in parameters and type(parameters['presence_penalty']) == float: - body['penalty_score'] = parameters['presence_penalty'] + if "presence_penalty" in parameters and type(parameters["presence_penalty"]) == float: + body["penalty_score"] = parameters["presence_penalty"] if system_message: - body['system'] = system_message + body["system"] = system_message return body def _handle_chat_generate_response(self, response: Response) -> ErnieMessage: data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - result = data['result'] - usage = data['usage'] + result = data["result"] + usage = data["usage"] - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } return message @@ -173,19 +201,19 @@ class ErnieBotModel(_CommonWenxin): for line in response.iter_lines(): if len(line) == 0: continue - line = line.decode('utf-8') - if line[0] == '{': + line = line.decode("utf-8") + if line[0] == "{": try: data = loads(line) - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - if line.startswith('data:'): + if line.startswith("data:"): line = line[5:].strip() else: continue @@ -195,23 +223,23 @@ class ErnieBotModel(_CommonWenxin): try: data = loads(line) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - result = data['result'] - is_end = data['is_end'] + result = data["result"] + is_end = data["is_end"] if is_end: - usage = data['usage'] - finish_reason = data.get('finish_reason', None) - message = ErnieMessage(content=result, role='assistant') + usage = data["usage"] + finish_reason = data.get("finish_reason", None) + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } message.stop_reason = finish_reason yield message else: - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") yield message diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 140606298c..f7c160b6b4 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -28,44 +28,84 @@ if you are not sure about the structure. You should also complete the text started with ``` but not tell ``` directly. -""" +""" # noqa: E501 + class ErnieBotLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: - response_format = model_parameters['response_format'] + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: + response_format = model_parameters["response_format"] stop = stop or [] - self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) - model_parameters.pop('response_format') + self._transform_json_prompts( + model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format + ) + model_parameters.pop("response_format") if stream: return self._code_block_mode_stream_processor( model=model, prompt_messages=prompt_messages, - input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + input_generator=self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), ) - + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts to model prompts """ @@ -74,34 +114,44 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last message prompt_messages[-1].content += "\n```JSON\n{\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content="```JSON\n{\n" - )) + prompt_messages.append(UserPromptMessage(content="```JSON\n{\n")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -113,10 +163,10 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -126,36 +176,53 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: instance = ErnieBotModel( - api_key=credentials['api_key'], - secret_key=credentials['secret_key'], + api_key=credentials["api_key"], + secret_key=credentials["secret_key"], ) - user = user if user else 'ErnieBotDefault' + user = user or "ErnieBotDefault" # convert prompt messages to baichuan messages messages = [ ErnieMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages + content=message.content + if isinstance(message.content, str) + else "".join([content.data for content in message.content]), + role=message.role.value, + ) + for message in prompt_messages ] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60, tools=tools, stop=stop, user=user) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + stop=stop, + user=user, + ) if stream: return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) @@ -180,43 +247,49 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "system", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ErnieMessage) -> LLMResult: + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: ErnieMessage + ) -> LLMResult: # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=response.content, tool_calls=[]), usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[ErnieMessage, None, None]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[ErnieMessage, None, None], + ) -> Generator: for message in response: if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) else: @@ -225,11 +298,8 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=message.content, tool_calls=[]), + finish_reason=message.stop_reason or None, ), ) diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 10ac1a1861..4d6f6dccd0 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -29,38 +29,38 @@ class TextEmbedding: class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): access_token = self._get_access_token() - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url, data=dumps(body), headers=headers) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") return self._handle_embed_response(model, resp) def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: if len(texts) == 0: - raise BadRequestError('The number of texts should not be zero.') + raise BadRequestError("The number of texts should not be zero.") body = { - 'input': texts, - 'user_id': user, + "input": texts, + "user_id": user, } return body def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - embeddings = [v['embedding'] for v in data['data']] - _usage = data['usage'] - tokens = _usage['prompt_tokens'] - total_tokens = _usage['total_tokens'] + embeddings = [v["embedding"] for v in data["data"]] + _usage = data["usage"] + tokens = _usage["prompt_tokens"] + total_tokens = _usage["total_tokens"] return embeddings, tokens, total_tokens @@ -69,22 +69,23 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: return WenxinTextEmbedding(api_key, secret_key) - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ - Invoke text embedding model + Invoke text embedding model - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :return: embeddings result - """ + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) - user = user if user else 'ErnieBotDefault' + user = user or "ErnieBotDefault" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -94,7 +95,6 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): used_total_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -110,9 +110,8 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): _iter = range(0, len(inputs), max_chunks) for i in _iter: embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( - model, - inputs[i: i + max_chunks], - user) + model, inputs[i : i + max_chunks], user + ) used_tokens += _used_tokens used_total_tokens += _total_used_tokens batched_embeddings += embeddings_batch @@ -142,12 +141,12 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): return total_num_tokens def validate_credentials(self, model: str, credentials: Mapping) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -164,10 +163,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -178,7 +174,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.py b/api/core/model_runtime/model_providers/wenxin/wenxin.py index 04845d06bc..895af20bc8 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class WenxinProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class WenxinProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `ernie-bot` model for validate, - model_instance.validate_credentials( - model='ernie-bot', - credentials=credentials - ) + model_instance.validate_credentials(model="ernie-bot", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py index 0fbd0f55ec..bd074e0477 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -18,40 +18,37 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): + +class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 4760e8f118..4fadda5df5 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -65,88 +65,109 @@ from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ - if 'temperature' in model_parameters: - if model_parameters['temperature'] < 0.01: - model_parameters['temperature'] = 0.01 - elif model_parameters['temperature'] > 1.0: - model_parameters['temperature'] = 0.99 + if "temperature" in model_parameters: + if model_parameters["temperature"] < 0.01: + model_parameters["temperature"] = 0.01 + elif model_parameters["temperature"] > 1.0: + model_parameters["temperature"] = 0.99 return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), - ) + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), + ), ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials - credentials should be like: - { - 'model_type': 'text-generation', - 'server_url': 'server url', - 'model_uid': 'model uid', - } + credentials should be like: + { + 'model_type': 'text-generation', + 'server_url': 'server url', + 'model_uid': 'model uid', + } """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'completion_type' not in credentials: - if 'chat' in extra_param.model_ability: - credentials['completion_type'] = 'chat' - elif 'generate' in extra_param.model_ability: - credentials['completion_type'] = 'completion' + if "completion_type" not in credentials: + if "chat" in extra_param.model_ability: + credentials["completion_type"] = "chat" + elif "generate" in extra_param.model_ability: + credentials["completion_type"] = "completion" else: raise ValueError( - f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') + f"xinference model ability {extra_param.model_ability} is not supported," + f" check if you have the right model type" + ) if extra_param.support_function_call: - credentials['support_function_call'] = True + credentials["support_function_call"] = True if extra_param.support_vision: - credentials['support_vision'] = True + credentials["support_vision"] = True if extra_param.context_length: - credentials['context_length'] = extra_param.context_length + credentials["context_length"] = extra_param.context_length except RuntimeError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except KeyError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except Exception as e: raise e - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -162,10 +183,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -217,30 +238,30 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -248,18 +269,14 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: - if isinstance(item, UserPromptMessage): - text += item.content - elif isinstance(item, SystemPromptMessage): - text += item.content - elif isinstance(item, AssistantPromptMessage): + if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage): text += item.content else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: @@ -275,19 +292,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -297,7 +308,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -312,151 +323,145 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY, use_template=DefaultParameterName.PRESENCE_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they ' - 'appear in the text so far, increasing the model\'s likelihood to talk about new topics.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they " + "appear in the text so far, increasing the model's likelihood to talk about new topics.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚," + "从而增加模型谈论新话题的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY, use_template=DefaultParameterName.FREQUENCY_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their ' - 'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the ' - 'same line verbatim.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on their " + "existing frequency in the text so far, decreasing the model's likelihood to repeat the " + "same line verbatim.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚," + "从而降低模型逐字重复相同内容的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 - ) + precision=2, + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') else: extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'chat' in extra_args.model_ability: + if "chat" in extra_args.model_ability: completion_type = LLMMode.CHAT.value - elif 'generate' in extra_args.model_ability: + elif "generate" in extra_args.model_ability: completion_type = LLMMode.COMPLETION.value else: - raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + raise ValueError(f"xinference model ability {extra_args.model_ability} is not supported") features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") - api_key = credentials.get('api_key') or "abc" + api_key = credentials.get("api_key") or "abc" client = OpenAI( base_url=f'{credentials["server_url"]}/v1', @@ -466,34 +471,29 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) xinference_client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_model = xinference_client.get_model(credentials['model_uid']) + xinference_model = xinference_client.get_model(credentials["model_uid"]) generate_config = { - 'temperature': model_parameters.get('temperature', 1.0), - 'top_p': model_parameters.get('top_p', 0.7), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'presence_penalty': model_parameters.get('presence_penalty', 0.0), - 'frequency_penalty': model_parameters.get('frequency_penalty', 0.0), + "temperature": model_parameters.get("temperature", 1.0), + "top_p": model_parameters.get("top_p", 0.7), + "max_tokens": model_parameters.get("max_tokens", 512), + "presence_penalty": model_parameters.get("presence_penalty", 0.0), + "frequency_penalty": model_parameters.get("frequency_penalty", 0.0), } if stop: - generate_config['stop'] = stop + generate_config["stop"] = stop if tools and len(tools) > 0: - generate_config['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] - vision = credentials.get('support_vision', False) + generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] + vision = credentials.get("support_vision", False) if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], stream=stream, user=user, @@ -501,34 +501,34 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) if stream: if tools and len(tools) > 0: - raise InvokeBadRequestError('xinference tool calls does not support stream mode') - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + raise InvokeBadRequestError("xinference tool calls does not support stream mode") + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) elif isinstance(xinference_model, RESTfulGenerateModelHandle): resp = client.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], prompt=self._convert_prompt_message_to_text(prompt_messages), stream=stream, user=user, **generate_config, ) if stream: - return self._handle_completion_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_completion_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + return self._handle_completion_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_completion_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) else: - raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + raise NotImplementedError(f"xinference model handle type {type(xinference_model)} is not supported") - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -539,21 +539,19 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -563,23 +561,25 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: ChatCompletion) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -588,22 +588,22 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # convert tool call to assistant message tool call tool_calls = assistant_message.tool_calls - assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) + assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or []) function_call = assistant_message.function_call if function_call: assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=assistant_prompt_message_tool_calls + content=assistant_message.content, tool_calls=assistant_prompt_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -615,13 +615,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[ChatCompletionChunk]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -629,7 +634,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -646,32 +651,31 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content or "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: @@ -687,11 +691,16 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): full_response += delta.delta.content - def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Completion) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion, + ) -> LLMResult: """ - handle normal completion generate response + handle normal completion generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -699,14 +708,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): assistant_message = resp.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[assistant_prompt_message], tools=[], is_completion_model=True ) @@ -724,13 +728,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[Completion]) -> Generator: + def _handle_completion_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion], + ) -> Generator: """ - handle stream completion generate response + handle stream completion generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -739,40 +748,33 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: - if delta.text is None or delta.text == '': + if delta.text is None or delta.text == "": continue yield LLMResultChunk( @@ -807,15 +809,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index d809537479..8f18bc42d2 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -22,10 +22,16 @@ class XinferenceRerankModel(RerankModel): Model class for Xinference rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -39,24 +45,15 @@ class XinferenceRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=[] - ) + return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): - server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + server_url = server_url.removesuffix("/") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} - params = { - 'documents': docs, - 'query': query, - 'top_n': top_n, - 'return_documents': True - } + params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True} try: handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) response = handle.rerank(**params) @@ -69,27 +66,24 @@ class XinferenceRerankModel(RerankModel): response = handle.rerank(**params) rerank_documents = [] - for idx, result in enumerate(response['results']): + for idx, result in enumerate(response["results"]): # format document - index = result['index'] - page_content = result['document'] if isinstance(result['document'], str) else result['document']['text'] + index = result["index"] + page_content = result["document"] if isinstance(result["document"], str) else result["document"]["text"] rerank_document = RerankDocument( index=index, text=page_content, - score=result['relevance_score'], + score=result["relevance_score"], ) # score threshold check if score_threshold is not None: - if result['relevance_score'] >= score_threshold: + if result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -100,34 +94,34 @@ class XinferenceRerankModel(RerankModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulRerankModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a rerank model') + "please check model type, the model you want to invoke is not a rerank model" + ) self.invoke( model=model, credentials=credentials, query="Whose kasumi", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -143,53 +137,38 @@ class XinferenceRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): - def rerank( - self, - documents: list[str], - query: str, - top_n: Optional[int] = None, - max_chunks_per_doc: Optional[int] = None, - return_documents: Optional[bool] = None, - **kwargs + self, + documents: list[str], + query: str, + top_n: Optional[int] = None, + max_chunks_per_doc: Optional[int] = None, + return_documents: Optional[bool] = None, + **kwargs, ): url = f"{self._base_url}/v1/rerank" request_body = { @@ -205,8 +184,6 @@ class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): response = requests.post(url, json=request_body, headers=self.auth_headers) if response.status_code != 200: - raise InvokeServerUnavailableError( - f"Failed to rerank documents, detail: {response.json()['detail']}" - ) + raise InvokeServerUnavailableError(f"Failed to rerank documents, detail: {response.json()['detail']}") response_data = response.json() return response_data diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 62b77f22e5..a6c5b8a0a5 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -21,9 +21,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): Model class for Xinference speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -44,27 +42,27 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulAudioModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a audio model') + "please check model type, the model you want to invoke is not a audio model" + ) audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self.invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -80,23 +78,11 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _speech2text_invoke( @@ -114,29 +100,28 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :param model: model name :param credentials: model credentials - :param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpe g,mpga, m4a, ogg, wav, or webm. + :param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, + mpga, m4a, ogg, wav, or webm. :param language: The language of the input audio. Supplying the input language in ISO-639-1 :param prompt: An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language. - :param response_format: The format of the transcript output, in one of these options: json, text, srt, verbose _json, or vtt. - :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit. + :param response_format: The format of the transcript output, in one of these options: json, text, srt, + verbose_json, or vtt. + :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more + random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model will use + log probability to automatically increase the temperature until certain thresholds are hit. :return: text for given audio file """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): - server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + server_url = server_url.removesuffix("/") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers) response = handle.transcriptions( - audio=file, - language=language, - prompt=prompt, - response_format=response_format, - temperature=temperature + audio=file, language=language, prompt=prompt, response_format=response_format, temperature=temperature ) except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -145,17 +130,15 @@ class XinferenceSpeech2TextModel(Speech2TextModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 3a8d704c25..8043af1d6c 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ Model class for Xinference text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -41,12 +42,11 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): - server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + server_url = server_url.removesuffix("/") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) @@ -70,13 +70,11 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): embedding: List[float] """ - usage = embeddings['usage'] - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = embeddings["usage"] + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[embedding['embedding'] for embedding in embeddings['data']], - usage=usage + model=model, embeddings=[embedding["embedding"] for embedding in embeddings["data"]], usage=usage ) return result @@ -105,12 +103,12 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=server_url, model_uid=model_uid, @@ -118,9 +116,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ) if extra_args.max_tokens: - credentials['max_tokens'] = extra_args.max_tokens - if server_url.endswith('/'): - server_url = server_url[:-1] + credentials["max_tokens"] = extra_args.max_tokens + server_url = server_url.removesuffix("/") client = Client( base_url=server_url, @@ -133,32 +130,24 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(e) if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + raise InvokeBadRequestError( + "please check model type, the model you want to invoke is not a text embedding model" + ) - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError as e: - raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + raise CredentialsValidateFailedError(f"Failed to validate credentials for model {model}: {e}") except RuntimeError as e: raise CredentialsValidateFailedError(e) @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -172,10 +161,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -186,28 +172,26 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.MAX_CHUNKS: 1, - ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + ModelPropertyKey.CONTEXT_SIZE: "max_tokens" in credentials and credentials["max_tokens"] or 512, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index 8cc99fef7c..10538b5788 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -19,92 +19,90 @@ from core.model_runtime.model_providers.xinference.xinference_helper import Xinf class XinferenceText2SpeechModel(TTSModel): - def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'ChatTTS': { - 'all': [ - {'name': 'Alloy', 'value': 'alloy'}, - {'name': 'Echo', 'value': 'echo'}, - {'name': 'Fable', 'value': 'fable'}, - {'name': 'Onyx', 'value': 'onyx'}, - {'name': 'Nova', 'value': 'nova'}, - {'name': 'Shimmer', 'value': 'shimmer'}, + "ChatTTS": { + "all": [ + {"name": "Alloy", "value": "alloy"}, + {"name": "Echo", "value": "echo"}, + {"name": "Fable", "value": "fable"}, + {"name": "Onyx", "value": "onyx"}, + {"name": "Nova", "value": "nova"}, + {"name": "Shimmer", "value": "shimmer"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ try: - if ("/" in credentials['model_uid'] or - "?" in credentials['model_uid'] or - "#" in credentials['model_uid']): + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'text-to-audio' not in extra_param.model_ability: + if "text-to-audio" not in extra_param.model_ability: raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a text-to-audio model') + "please check model type, the model you want to invoke is not a text-to-audio model" + ) if extra_param.model_family and extra_param.model_family in self.model_voices: - credentials['audio_model_name'] = extra_param.model_family + credentials["audio_model_name"] = extra_param.model_family else: - credentials['audio_model_name'] = '__default' + credentials["audio_model_name"] = "__default" self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -120,18 +118,16 @@ class XinferenceText2SpeechModel(TTSModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -147,40 +143,28 @@ class XinferenceText2SpeechModel(TTSModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = credentials.get('audio_model_name', '__default') + audio_model_name = credentials.get("audio_model_name", "__default") for key, voices in self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] else: all_voices = [] for lang, lang_voices in voices.items(): all_voices.extend(lang_voices) return all_voices - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] def _get_model_default_voice(self, model: str, credentials: dict) -> any: return "" @@ -194,8 +178,7 @@ class XinferenceText2SpeechModel(TTSModel): def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -205,48 +188,41 @@ class XinferenceText2SpeechModel(TTSModel): :param voice: model timbre :return: text translated to audio file """ - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") try: - api_key = credentials.get('api_key') - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + api_key = credentials.get("api_key") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} handle = RESTfulAudioModelHandle( - credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + credentials["model_uid"], credentials["server_url"], auth_headers=auth_headers ) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit( - handle.speech, - input=sentences[i], - voice=voice, - response_format="mp3", - speed=1.0, - stream=False - ) - for i in range(len(sentences))] + futures = [ + executor.submit( + handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=False + ) + for i in range(len(sentences)) + ] - for index, future in enumerate(futures): + for future in futures: response = future.result() for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] else: response = handle.speech( - input=content_text.strip(), - voice=voice, - response_format="mp3", - speed=1.0, - stream=False + input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=False ) for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 151166f165..619ee1492a 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -18,9 +18,17 @@ class XinferenceModelExtraParameter: support_vision: bool = False model_family: Optional[str] - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int, - model_family: Optional[str]) -> None: + def __init__( + self, + model_format: str, + model_handle_type: str, + model_ability: list[str], + support_function_call: bool, + support_vision: bool, + max_tokens: int, + context_length: int, + model_family: Optional[str], + ) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -30,9 +38,11 @@ class XinferenceModelExtraParameter: self.context_length = context_length self.model_family = model_family + cache = {} cache_lock = Lock() + class XinferenceHelper: @staticmethod def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: @@ -40,16 +50,16 @@ class XinferenceHelper: with cache_lock: if model_uid not in cache: cache[model_uid] = { - 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) + "expires": time() + 300, + "value": XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key), } - return cache[model_uid]['value'] + return cache[model_uid]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -58,55 +68,59 @@ class XinferenceHelper: @staticmethod def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ - get xinference model extra parameter like model_format and model_handle_type + get xinference model extra parameter like model_format and model_handle_type """ if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): - raise RuntimeError('model_uid is empty') + raise RuntimeError("model_uid is empty") - url = str(URL(server_url) / 'v1' / 'models' / model_uid) + url = str(URL(server_url) / "v1" / "models" / model_uid) - # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 + # this method is surrounded by a lock, and default requests may hang forever, + # so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) - headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get xinference model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: - raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') + raise RuntimeError( + f"get xinference model extra parameter failed, status code: {response.status_code}," + f" response: {response.text}" + ) response_json = response.json() - model_format = response_json.get('model_format', 'ggmlv3') - model_ability = response_json.get('model_ability', []) - model_family = response_json.get('model_family', None) + model_format = response_json.get("model_format", "ggmlv3") + model_ability = response_json.get("model_ability", []) + model_family = response_json.get("model_family", None) - if response_json.get('model_type') == 'embedding': - model_handle_type = 'embedding' - elif response_json.get('model_type') == 'audio': - model_handle_type = 'audio' - if model_family and model_family in ['ChatTTS', 'CosyVoice', 'FishAudio']: - model_ability.append('text-to-audio') + if response_json.get("model_type") == "embedding": + model_handle_type = "embedding" + elif response_json.get("model_type") == "audio": + model_handle_type = "audio" + if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}: + model_ability.append("text-to-audio") else: - model_ability.append('audio-to-text') - elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: - model_handle_type = 'chatglm' - elif 'generate' in model_ability: - model_handle_type = 'generate' - elif 'chat' in model_ability: - model_handle_type = 'chat' + model_ability.append("audio-to-text") + elif model_format == "ggmlv3" and "chatglm" in response_json["model_name"]: + model_handle_type = "chatglm" + elif "generate" in model_ability: + model_handle_type = "generate" + elif "chat" in model_ability: + model_handle_type = "chat" else: - raise NotImplementedError('xinference model handle type is not supported') + raise NotImplementedError("xinference model handle type is not supported") - support_function_call = 'tools' in model_ability - support_vision = 'vision' in model_ability - max_tokens = response_json.get('max_tokens', 512) + support_function_call = "tools" in model_ability + support_vision = "vision" in model_ability + max_tokens = response_json.get("max_tokens", 512) - context_length = response_json.get('context_length', 2048) + context_length = response_json.get("context_length", 2048) return XinferenceModelExtraParameter( model_format=model_format, @@ -116,5 +130,5 @@ class XinferenceHelper: support_vision=support_vision, max_tokens=max_tokens, context_length=context_length, - model_family=model_family + model_family=model_family, ) diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index d33f38333b..5ab7fd126e 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -14,11 +14,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class YiLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) # yi-vl-plus not support system prompt yet. @@ -27,7 +33,9 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): for message in prompt_messages: if not isinstance(message, SystemPromptMessage): prompt_message_except_system.append(message) - return super()._invoke(model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream) + return super()._invoke( + model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream + ) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -36,8 +44,7 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -55,8 +62,9 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -76,10 +84,10 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -110,10 +118,10 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.lingyiwanwu.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.lingyiwanwu.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/yi/yi.py b/api/core/model_runtime/model_providers/yi/yi.py index 691c7aa371..9599acb22b 100644 --- a/api/core/model_runtime/model_providers/yi/yi.py +++ b/api/core/model_runtime/model_providers/yi/yi.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class YiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class YiProvider(ModelProvider): # Use `yi-34b-chat-0205` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='yi-34b-chat-0205', - credentials=credentials - ) + model_instance.validate_credentials(model="yi-34b-chat-0205", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhinao/llm/llm.py b/api/core/model_runtime/model_providers/zhinao/llm/llm.py index 6930a5ed01..befc3de021 100644 --- a/api/core/model_runtime/model_providers/zhinao/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhinao/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.360.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.360.cn/v1" diff --git a/api/core/model_runtime/model_providers/zhinao/zhinao.py b/api/core/model_runtime/model_providers/zhinao/zhinao.py index 44b36c9f51..2a263292f9 100644 --- a/api/core/model_runtime/model_providers/zhinao/zhinao.py +++ b/api/core/model_runtime/model_providers/zhinao/zhinao.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class ZhinaoProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class ZhinaoProvider(ModelProvider): # Use `360gpt-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='360gpt-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="360gpt-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 3412d8100f..fa95232f71 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -17,8 +17,7 @@ class _CommonZhipuaiAI: :return: """ credentials_kwargs = { - "api_key": credentials['api_key'] if 'api_key' in credentials else - credentials.get("zhipuai_api_key"), + "api_key": credentials["api_key"] if "api_key" in credentials else credentials.get("zhipuai_api_key"), } return credentials_kwargs @@ -38,5 +37,5 @@ class _CommonZhipuaiAI: InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index b2cdc7ad7a..ea331701ab 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -31,16 +31,21 @@ And you should always end the block with a "```" to indicate the end of the JSON {{instructions}} -```JSON""" +```JSON""" # noqa: E501 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -62,9 +67,9 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, # stream: bool = True, user: str | None = None) \ # -> None: # """ @@ -94,8 +99,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # content="```JSON\n" # )) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -130,16 +140,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): "temperature": 0.5, }, tools=[], - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials_kwargs: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials_kwargs: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -154,15 +170,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): """ extra_model_kwargs = {} # request to glm-4v-plus with stop words will always response "finish_reason":"network_error" - if stop and model!= 'glm-4v-plus': - extra_model_kwargs['stop'] = stop + if stop and model != "glm-4v-plus": + extra_model_kwargs["stop"] = stop - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) if len(prompt_messages) == 0: - raise ValueError('At least one message is required') + raise ValueError("At least one message is required") if prompt_messages[0].role == PromptMessageRole.SYSTEM: if not prompt_messages[0].content: @@ -172,13 +186,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: copy_prompt_message = prompt_message.copy() - if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model not in ('glm-4v', 'glm-4v-plus'): + if model not in {"glm-4v", "glm-4v-plus"}: # not support list message continue - # get image and + # get image and if not isinstance(copy_prompt_message, UserPromptMessage): # not support system message continue @@ -188,13 +202,14 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # not support image message continue - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ - copy_prompt_message.role == PromptMessageRole.USER: + if ( + new_prompt_messages + and new_prompt_messages[-1].role == PromptMessageRole.USER + and copy_prompt_message.role == PromptMessageRole.USER + ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: - if copy_prompt_message.role == PromptMessageRole.USER: - new_prompt_messages.append(copy_prompt_message) - elif copy_prompt_message.role == PromptMessageRole.TOOL: + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}: new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) @@ -208,77 +223,66 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v' or model == 'glm-4v-plus': + if model in {"glm-4v", "glm-4v-plus"}: params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: - params = { - 'model': model, - 'messages': [], - **model_parameters - } + params = {"model": model, "messages": [], **model_parameters} # glm model - if not model.startswith('chatglm'): - + if not model.startswith("chatglm"): for prompt_message in new_prompt_messages: if prompt_message.role == PromptMessageRole.TOOL: - params['messages'].append({ - 'role': 'tool', - 'content': prompt_message.content, - 'tool_call_id': prompt_message.tool_call_id - }) + params["messages"].append( + { + "role": "tool", + "content": prompt_message.content, + "tool_call_id": prompt_message.tool_call_id, + } + ) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content, - 'tool_calls': [ - { - 'id': tool_call.id, - 'type': tool_call.type, - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments + params["messages"].append( + { + "role": "assistant", + "content": prompt_message.content, + "tool_calls": [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls - ] - }) + for tool_call in prompt_message.tool_calls + ], + } + ) else: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content - }) + params["messages"].append({"role": "assistant", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) else: # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if prompt_message.role == PromptMessageRole.SYSTEM or \ - prompt_message.role == PromptMessageRole.TOOL or \ - prompt_message.role == PromptMessageRole.USER: - if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': - params['messages'][-1]['content'] += "\n\n" + prompt_message.content + if prompt_message.role in { + PromptMessageRole.SYSTEM, + PromptMessageRole.TOOL, + PromptMessageRole.USER, + }: + if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": + params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: - params['messages'].append({ - 'role': 'user', - 'content': prompt_message.content - }) + params["messages"].append({"role": "user", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) if tools and len(tools) > 0: - params['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] + params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] if stream: response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs) @@ -287,47 +291,41 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): response = client.chat.completions.create(**params, **extra_model_kwargs) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) - def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], - model_parameters: dict): + def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict): messages = [ - { - 'role': message.role.value, - 'content': self._construct_glm_4v_messages(message.content) - } + {"role": message.role.value, "content": self._construct_glm_4v_messages(message.content)} for message in prompt_messages ] - params = { - 'model': model, - 'messages': messages, - **model_parameters - } + params = {"model": model, "messages": messages, **model_parameters} return params def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: if isinstance(prompt_message, str): - return [{'type': 'text', 'text': prompt_message}] + return [{"type": "text", "text": prompt_message}] return [ - {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}} - if item.type == PromptMessageContentType.IMAGE else - {'type': 'text', 'text': item.data} - + {"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}} + if item.type == PromptMessageContentType.IMAGE + else {"type": "text", "text": item.data} for item in prompt_message ] def _remove_image_header(self, image: str) -> str: - if image.startswith('data:image'): - return image.split(',')[1] + if image.startswith("data:image"): + return image.split(",")[1] return image - def _handle_generate_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -336,12 +334,12 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - text = '' + text = "" assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -349,11 +347,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) - text += choice.message.content or '' + text += choice.message.content or "" prompt_usage = response.usage.prompt_tokens completion_usage = response.usage.completion_tokens @@ -365,20 +363,20 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): result = LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text, - tool_calls=assistant_tool_calls - ), + message=AssistantPromptMessage(content=text, tool_calls=assistant_tool_calls), usage=usage, ) return result - def _handle_generate_stream_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - responses: Generator[ChatCompletionChunk, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + responses: Generator[ChatCompletionChunk, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -387,19 +385,19 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_assistant_content = '' + full_assistant_content = "" for chunk in responses: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -407,17 +405,16 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_tool_calls + content=delta.delta.content or "", tool_calls=assistant_tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" if delta.finish_reason is not None and chunk.usage is not None: completion_tokens = chunk.usage.completion_tokens @@ -429,24 +426,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - finish_reason=delta.finish_reason - ) + index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -464,27 +459,23 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) if tools and len(tools) > 0: text += "\n\nTools:" diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 0f9fecfc72..ee20954381 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -14,9 +14,9 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Model class for ZhipuAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -27,16 +27,14 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) return TextEmbeddingResult( embeddings=embeddings, usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + model=model, ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -50,7 +48,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): """ if len(texts) == 0: return 0 - + total_num_tokens = 0 for text in texts: total_num_tokens += self._get_num_tokens_by_gpt2(text) @@ -68,15 +66,13 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) # call embedding model self.embed_documents( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -100,7 +96,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): embedding_used_tokens += response.usage.total_tokens return [list(map(float, e)) for e in embeddings], embedding_used_tokens - + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. @@ -111,8 +107,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Embeddings for the text. """ return self.embed_documents([text])[0] - - def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -122,10 +118,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -136,7 +129,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py index c517d2dba5..e75aad6eb0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py @@ -19,12 +19,9 @@ class ZhipuaiProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='glm-4', - credentials=credentials - ) + model_instance.validate_credentials(model="glm-4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index 4dcd03f551..bf9b093cb3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,4 +1,3 @@ - from .__version__ import __version__ from ._client import ZhipuAI from .core._errors import ( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py index eb0ad332ca..659f38d7ff 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py @@ -1,2 +1 @@ - -__version__ = 'v2.0.1' \ No newline at end of file +__version__ = "v2.0.1" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 6588d1dd68..df9e506095 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -20,14 +20,14 @@ class ZhipuAI(HttpClient): api_key: str def __init__( - self, - *, - api_key: str | None = None, - base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - http_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, + http_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if api_key is None: raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables") @@ -38,6 +38,7 @@ class ZhipuAI(HttpClient): if base_url is None: base_url = "https://open.bigmodel.cn/api/paas/v4" from .__version__ import __version__ + super().__init__( version=__version__, base_url=base_url, @@ -58,9 +59,7 @@ class ZhipuAI(HttpClient): return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} def __del__(self) -> None: - if (not hasattr(self, "_has_custom_http_client") - or not hasattr(self, "close") - or not hasattr(self, "_client")): + if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"): # if the '__init__' method raised an error, self would not have client attr return diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index dab6dac5fe..1f80119739 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -17,25 +17,24 @@ class AsyncCompletions(BaseAPI): def __init__(self, client: ZhipuAI) -> None: super().__init__(client) - def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], list[list[int]], None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AsyncTaskStatus: _cast_type = AsyncTaskStatus @@ -57,9 +56,7 @@ class AsyncCompletions(BaseAPI): "tools": tools, "tool_choice": tool_choice, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) @@ -71,16 +68,11 @@ class AsyncCompletions(BaseAPI): disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion,AsyncTaskStatus] + _cast_type = Union[AsyncCompletion, AsyncTaskStatus] if disable_strict_validation: _cast_type = object return self._get( path=f"/async-result/{id}", cast_type=_cast_type, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout - ) + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), ) - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index 5c4ed4d1ba..ec29f33864 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -20,24 +20,24 @@ class Completions(BaseAPI): super().__init__(client) def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Completion | StreamResponse[ChatCompletionChunk]: _cast_type = Completion _stream_cls = StreamResponse[ChatCompletionChunk] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index 35d54592fd..2308a20451 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -18,16 +18,16 @@ class Embeddings(BaseAPI): super().__init__(client) def create( - self, - *, - input: Union[str, list[str], list[int], list[list[int]]], - model: Union[str], - encoding_format: str | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + input: Union[str, list[str], list[int], list[list[int]]], + model: Union[str], + encoding_format: str | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EmbeddingsResponded: _cast_type = EmbeddingsResponded if disable_strict_validation: @@ -41,9 +41,7 @@ class Embeddings(BaseAPI): "user": user, "sensitive_word_check": sensitive_word_check, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 5deb8d08f3..f2ac74bffa 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -17,17 +17,16 @@ __all__ = ["Files"] class Files(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - file: FileTypes, - purpose: str, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + file: FileTypes, + purpose: str, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FileObject: if not is_file_content(file): prefix = f"Expected file input `{file!r}`" @@ -44,21 +43,19 @@ class Files(BaseAPI): "purpose": purpose, }, files=files, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FileObject, ) def list( - self, - *, - purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - after: str | NotGiven = NOT_GIVEN, - order: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + purpose: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + after: str | NotGiven = NOT_GIVEN, + order: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFileObject: return self._get( "/files", diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py index dc54a9ca45..dc30bd33ed 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -13,4 +13,3 @@ class FineTuning(BaseAPI): def __init__(self, client: "ZhipuAI") -> None: super().__init__(client) self.jobs = Jobs(client) - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index b860de192a..3d2e9208a1 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -16,21 +16,20 @@ __all__ = ["Jobs"] class Jobs(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - model: str, - training_file: str, - hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, - suffix: Optional[str] | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - validation_file: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + training_file: str, + hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, + suffix: Optional[str] | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + validation_file: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._post( "/fine_tuning/jobs", @@ -42,34 +41,30 @@ class Jobs(BaseAPI): "validation_file": validation_file, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def retrieve( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + fine_tuning_job_id: str, + *, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFineTuningJob: return self._get( "/fine_tuning/jobs", @@ -93,7 +88,6 @@ class Jobs(BaseAPI): extra_headers: Headers | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJobEvent: - return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}/events", cast_type=FineTuningJobEvent, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 8eae1216d0..2692b093af 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -18,21 +18,21 @@ class Images(BaseAPI): super().__init__(client) def generations( - self, - *, - prompt: str, - model: str | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - quality: Optional[str] | NotGiven = NOT_GIVEN, - response_format: Optional[str] | NotGiven = NOT_GIVEN, - size: Optional[str] | NotGiven = NOT_GIVEN, - style: Optional[str] | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + prompt: str, + model: str | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + quality: Optional[str] | NotGiven = NOT_GIVEN, + response_format: Optional[str] | NotGiven = NOT_GIVEN, + size: Optional[str] | NotGiven = NOT_GIVEN, + style: Optional[str] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ImagesResponded: _cast_type = ImagesResponded if disable_strict_validation: @@ -50,11 +50,7 @@ class Images(BaseAPI): "user": user, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py index b7cf6bb7fd..7a91f9b796 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py @@ -75,7 +75,8 @@ Headers = Mapping[str, Union[str, Omit]] ResponseT = TypeVar( "ResponseT", - bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", + bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol," + " BinaryResponseContent]", ) # for user input files diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py index a2a438b8f3..1027c1bc5b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py @@ -17,7 +17,10 @@ __all__ = [ class ZhipuAIError(Exception): - def __init__(self, message: str, ) -> None: + def __init__( + self, + message: str, + ) -> None: super().__init__(message) @@ -31,24 +34,19 @@ class APIStatusError(Exception): self.status_code = response.status_code -class APIRequestFailedError(APIStatusError): - ... +class APIRequestFailedError(APIStatusError): ... -class APIAuthenticationError(APIStatusError): - ... +class APIAuthenticationError(APIStatusError): ... -class APIReachLimitError(APIStatusError): - ... +class APIReachLimitError(APIStatusError): ... -class APIInternalError(APIStatusError): - ... +class APIInternalError(APIStatusError): ... -class APIServerFlowExceedError(APIStatusError): - ... +class APIServerFlowExceedError(APIStatusError): ... class APIResponseError(Exception): @@ -67,16 +65,11 @@ class APIResponseValidationError(APIResponseError): status_code: int response: httpx.Response - def __init__( - self, - response: httpx.Response, - json_data: object | None, *, - message: str | None = None - ) -> None: + def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None: super().__init__( message=message or "Data returned by API invalid for expected schema.", request=response.request, - json_data=json_data + json_data=json_data, ) self.response = response self.status_code = response.status_code diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 65401f6c1c..5f7f6d04f2 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -30,6 +30,8 @@ def _merge_map(map1: Mapping, map2: Mapping) -> Mapping: return {key: val for key, val in merged.items() if val is not None} +from itertools import starmap + from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) @@ -48,13 +50,13 @@ class HttpClient: _default_stream_cls: type[StreamResponse[Any]] | None = None def __init__( - self, - *, - version: str, - base_url: URL, - timeout: Union[float, Timeout, None], - custom_httpx_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, + self, + *, + version: str, + base_url: URL, + timeout: Union[float, Timeout, None], + custom_httpx_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if timeout is None or isinstance(timeout, NotGiven): if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: @@ -76,7 +78,6 @@ class HttpClient: self._custom_headers = custom_headers or {} def _prepare_url(self, url: str) -> URL: - sub_url = URL(url) if sub_url.is_relative_url: request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") @@ -86,16 +87,15 @@ class HttpClient: @property def _default_headers(self): - return \ - { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", - "ZhipuAI-SDK-Ver": self._version, - "source_type": "zhipu-sdk-python", - "x-request-sdk": "zhipu-sdk-python", - **self._auth_headers, - **self._custom_headers, - } + return { + "Accept": "application/json", + "Content-Type": "application/json; charset=UTF-8", + "ZhipuAI-SDK-Ver": self._version, + "source_type": "zhipu-sdk-python", + "x-request-sdk": "zhipu-sdk-python", + **self._auth_headers, + **self._custom_headers, + } @property def _auth_headers(self): @@ -109,10 +109,7 @@ class HttpClient: return httpx_headers - def _prepare_request( - self, - request_param: ClientRequestParam - ) -> httpx.Request: + def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request: kwargs: dict[str, Any] = {} json_data = request_param.json_data headers = self._prepare_headers(request_param) @@ -164,8 +161,7 @@ class HttpClient: return [(key, str_data)] def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - - items = flatten([self._object_to_formdata(k, v) for k, v in data.items()]) + items = flatten(list(starmap(self._object_to_formdata, data.items()))) serialized: dict[str, object] = {} for key, value in items: @@ -175,30 +171,25 @@ class HttpClient: return serialized def _parse_response( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - enable_stream: bool, - request_param: ClientRequestParam, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + enable_stream: bool, + request_param: ClientRequestParam, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> HttpResponse: - http_response = HttpResponse( - raw_response=response, - cast_type=cast_type, - client=self, - enable_stream=enable_stream, - stream_cls=stream_cls + raw_response=response, cast_type=cast_type, client=self, enable_stream=enable_stream, stream_cls=stream_cls ) return http_response.parse() def _process_response_data( - self, - *, - data: object, - cast_type: type[ResponseT], - response: httpx.Response, + self, + *, + data: object, + cast_type: type[ResponseT], + response: httpx.Response, ) -> ResponseT: if data is None: return cast(ResponseT, None) @@ -225,12 +216,12 @@ class HttpClient: @retry(stop=stop_after_attempt(ZHIPUAI_DEFAULT_MAX_RETRIES)) def request( - self, - *, - cast_type: type[ResponseT], - params: ClientRequestParam, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + params: ClientRequestParam, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: request = self._prepare_request(params) @@ -259,81 +250,79 @@ class HttpClient: ) def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - enable_stream: bool = False, + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + enable_stream: bool = False, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="get", url=path, **options) - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream - ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream) def post( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, - **options) - - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream, - stream_cls=stream_cls + opts = ClientRequestParam.construct( + method="post", json_data=body, files=make_httpx_files(files), url=path, **options ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream, stream_cls=stream_cls) + def patch( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT: opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def put( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), - **options) + opts = ClientRequestParam.construct( + method="put", url=path, json_data=body, files=make_httpx_files(files), **options + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def delete( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def _make_status_error(self, response) -> APIStatusError: @@ -355,11 +344,11 @@ class HttpClient: def make_user_request_input( - max_retries: int | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - extra_headers: Headers = None, - extra_body: Body | None = None, - query: Query | None = None, + max_retries: int | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + extra_headers: Headers = None, + extra_body: Body | None = None, + query: Query | None = None, ) -> UserRequestInput: options: UserRequestInput = {} @@ -368,7 +357,7 @@ def make_user_request_input( if max_retries is not None: options["max_retries"] = max_retries if not isinstance(timeout, NotGiven): - options['timeout'] = timeout + options["timeout"] = timeout if query is not None: options["params"] = query if extra_body is not None: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index a3f49ba846..ac459151fc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -35,17 +35,14 @@ class ClientRequestParam: @classmethod def construct( # type: ignore - cls, - _fields_set: set[str] | None = None, - **values: Unpack[UserRequestInput], - ) -> ClientRequestParam : - kwargs: dict[str, Any] = { - key: remove_notgiven_indict(value) for key, value in values.items() - } + cls, + _fields_set: set[str] | None = None, + **values: Unpack[UserRequestInput], + ) -> ClientRequestParam: + kwargs: dict[str, Any] = {key: remove_notgiven_indict(value) for key, value in values.items()} client = cls() client.__dict__.update(kwargs) return client model_construct = construct - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 2f831b6fc9..56e60a7934 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -26,13 +26,13 @@ class HttpResponse(Generic[R]): http_response: httpx.Response def __init__( - self, - *, - raw_response: httpx.Response, - cast_type: type[R], - client: HttpClient, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + raw_response: httpx.Response, + cast_type: type[R], + client: HttpClient, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: self._cast_type = cast_type self._client = client @@ -52,8 +52,8 @@ class HttpResponse(Generic[R]): self._stream_cls( cast_type=cast(type, get_args(self._stream_cls)[0]), response=self.http_response, - client=self._client - ) + client=self._client, + ), ) return self._parsed cast_type = self._cast_type diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 66afbfd107..ec2745d059 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -16,16 +16,15 @@ if TYPE_CHECKING: class StreamResponse(Generic[ResponseT]): - response: httpx.Response _cast_type: type[ResponseT] def __init__( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - client: HttpClient, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + client: HttpClient, ) -> None: self.response = response self._cast_type = cast_type @@ -39,7 +38,6 @@ class StreamResponse(Generic[ResponseT]): yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: - sse_line_parser = SSELineParser() iterator = sse_line_parser.iter_lines(self.response.iter_lines()) @@ -63,11 +61,7 @@ class StreamResponse(Generic[ResponseT]): class Event: def __init__( - self, - event: str | None = None, - data: str | None = None, - id: str | None = None, - retry: int | None = None + self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None ): self._event = event self._data = data @@ -76,21 +70,28 @@ class Event: def __repr__(self): data_len = len(self._data) if self._data else 0 - return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + return ( + f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + ) @property - def event(self): return self._event + def event(self): + return self._event @property - def data(self): return self._data + def data(self): + return self._data - def json_data(self): return json.loads(self._data) + def json_data(self): + return json.loads(self._data) @property - def id(self): return self._id + def id(self): + return self._id @property - def retry(self): return self._retry + def retry(self): + return self._retry class SSELineParser: @@ -107,19 +108,11 @@ class SSELineParser: def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: for line in lines: - line = line.rstrip('\n') + line = line.rstrip("\n") if not line: - if self._event is None and \ - not self._data and \ - self._id is None and \ - self._retry is None: + if self._event is None and not self._data and self._id is None and self._retry is None: continue - sse_event = Event( - event=self._event, - data='\n'.join(self._data), - id=self._id, - retry=self._retry - ) + sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry) self._event = None self._data = [] self._id = None @@ -134,8 +127,7 @@ class SSELineParser: field, _p, value = line.partition(":") - if value.startswith(' '): - value = value[1:] + value = value.removeprefix(" ") if field == "data": self._data.append(value) elif field == "event": diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index f22f32d251..a0645b0916 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -20,4 +20,4 @@ class AsyncCompletion(BaseModel): model: Optional[str] = None task_status: str choices: list[CompletionChoice] - usage: CompletionUsage \ No newline at end of file + usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index b2a847c50c..4b3a929a2b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -41,5 +41,3 @@ class Completion(BaseModel): request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 917bda7576..75f76fe969 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -6,7 +6,6 @@ __all__ = ["FileObject"] class FileObject(BaseModel): - id: Optional[str] = None bytes: Optional[int] = None created_at: Optional[int] = None @@ -18,7 +17,6 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): - object: Optional[str] = None data: list[FileObject] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py index af0991892e..416f516ef7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py @@ -1,5 +1,4 @@ from __future__ import annotations -from .fine_tuning_job import FineTuningJob as FineTuningJob -from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob -from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent +from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob +from .fine_tuning_job_event import FineTuningJobEvent diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 71c00eaff0..1d3930286b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -2,7 +2,7 @@ from typing import Optional, Union from pydantic import BaseModel -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] +__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] class Error(BaseModel): diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index fe705d6943..029ec1a581 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -4,9 +4,9 @@ from core.model_runtime.entities.provider_entities import CredentialFormSchema, class CommonValidator: - def _validate_and_filter_credential_form_schemas(self, - credential_form_schemas: list[CredentialFormSchema], - credentials: dict) -> dict: + def _validate_and_filter_credential_form_schemas( + self, credential_form_schemas: list[CredentialFormSchema], credentials: dict + ) -> dict: need_validate_credential_form_schema_map = {} for credential_form_schema in credential_form_schemas: if not credential_form_schema.show_on: @@ -36,8 +36,9 @@ class CommonValidator: return validated_credentials - def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ - -> Optional[str]: + def _validate_credential_form_schema( + self, credential_form_schema: CredentialFormSchema, credentials: dict + ) -> Optional[str]: """ Validate credential form schema @@ -49,7 +50,7 @@ class CommonValidator: if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: - raise ValueError(f'Variable {credential_form_schema.variable} is required') + raise ValueError(f"Variable {credential_form_schema.variable} is required") else: # Get the value of default if credential_form_schema.default: @@ -65,23 +66,26 @@ class CommonValidator: # If max_length=0, no validation is performed if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: - raise ValueError(f'Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}') + raise ValueError( + f"Variable {credential_form_schema.variable} length should not" + f" greater than {credential_form_schema.max_length}" + ) # check the type of value if not isinstance(value, str): - raise ValueError(f'Variable {credential_form_schema.variable} should be string') + raise ValueError(f"Variable {credential_form_schema.variable} should be string") - if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: + if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + raise ValueError(f"Variable {credential_form_schema.variable} is not in options") if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ['true', 'false']: - raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + if value.lower() not in {"true", "false"}: + raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - value = True if value.lower() == 'true' else False + value = True if value.lower() == "true" else False return value diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py index c4786fad5d..7d1644d134 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -4,7 +4,6 @@ from core.model_runtime.schema_validators.common_validator import CommonValidato class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): self.model_type = model_type self.model_credential_schema = model_credential_schema diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py index c945016534..6dff2428ca 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -3,7 +3,6 @@ from core.model_runtime.schema_validators.common_validator import CommonValidato class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): self.provider_credential_schema = provider_credential_schema diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 5078f00bfa..ec1bad5698 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,11 +18,10 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump( - model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any -) -> Any: +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) + # Taken from Pydantic v1 as is def isoformat(o: Union[datetime.date, datetime.time]) -> str: return o.isoformat() @@ -82,11 +81,9 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]] + type_encoder_map: dict[Any, Callable[[Any], Any]], ) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( - tuple - ) + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) for type_, encoder in type_encoder_map.items(): encoders_by_class_tuples[encoder] += (type_,) return encoders_by_class_tuples @@ -149,17 +146,13 @@ def jsonable_encoder( if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): - return format(obj, 'f') + return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} allowed_keys = set(obj.keys()) for key, value in obj.items(): if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) + (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (value is not None or not exclude_none) and key in allowed_keys ): diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index c68a554471..2067092d80 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -3,7 +3,7 @@ from pydantic import BaseModel def dump_model(model: BaseModel) -> dict: - if hasattr(pydantic, 'model_dump'): + if hasattr(pydantic, "model_dump"): return pydantic.model_dump(model) else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index f96e2a1c21..094ad78636 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -44,32 +44,29 @@ class ApiModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - params = ModerationInputParams( - app_id=self.app_id, - inputs=inputs, - query=query - ) + if self.config["inputs_config"]["enabled"]: + params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) return ModerationInputsResult(**result) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - params = ModerationOutputParams( - app_id=self.app_id, - text=text - ) + if self.config["outputs_config"]["enabled"]: + params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) return ModerationOutputsResult(**result) - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) @@ -80,9 +77,10 @@ class ApiModeration(Moderation): @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) return extension diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 757dd2ab46..60898d5547 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -8,8 +8,8 @@ from core.extension.extensible import Extensible, ExtensionModule class ModerationAction(Enum): - DIRECT_OUTPUT = 'direct_output' - OVERRIDDEN = 'overridden' + DIRECT_OUTPUT = "direct_output" + OVERRIDDEN = "overridden" class ModerationInputsResult(BaseModel): @@ -31,6 +31,7 @@ class Moderation(Extensible, ABC): """ The base class of moderation. """ + module: ExtensionModule = ExtensionModule.MODERATION def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: @@ -75,7 +76,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: + def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): @@ -110,5 +111,5 @@ class Moderation(Extensible, ABC): raise ValueError("outputs_config.preset_response must be less than 100 characters") -class ModerationException(Exception): +class ModerationError(Exception): pass diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 46dfacbc9e..46d3963bd0 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -2,7 +2,7 @@ import logging from typing import Optional from core.app.app_config.entities import AppConfig -from core.moderation.base import ModerationAction, ModerationException +from core.moderation.base import ModerationAction, ModerationError from core.moderation.factory import ModerationFactory from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask @@ -13,13 +13,14 @@ logger = logging.getLogger(__name__) class InputModeration: def check( - self, app_id: str, + self, + app_id: str, tenant_id: str, app_config: AppConfig, inputs: dict, query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. @@ -39,10 +40,7 @@ class InputModeration: moderation_type = sensitive_word_avoidance_config.type moderation_factory = ModerationFactory( - name=moderation_type, - app_id=app_id, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config.config + name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config ) with measure_time() as timer: @@ -55,7 +53,7 @@ class InputModeration: message_id=message_id, moderation_result=moderation_result, inputs=inputs, - timer=timer + timer=timer, ) ) @@ -63,7 +61,7 @@ class InputModeration: return False, inputs, query if moderation_result.action == ModerationAction.DIRECT_OUTPUT: - raise ModerationException(moderation_result.preset_response) + raise ModerationError(moderation_result.preset_response) elif moderation_result.action == ModerationAction.OVERRIDDEN: inputs = moderation_result.inputs query = moderation_result.query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index ca562ad987..dc6a7ec564 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -25,41 +25,38 @@ class KeywordsModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] flagged = self._is_violated(inputs, keywords_list) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: + if self.config["outputs_config"]["enabled"]: # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] - flagged = self._is_violated({'text': text}, keywords_list) - preset_response = self.config['outputs_config']['preset_response'] + flagged = self._is_violated({"text": text}, keywords_list) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: - for value in inputs.values(): - if self._check_keywords_in_value(keywords_list, value): - return True + return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) - return False - - def _check_keywords_in_value(self, keywords_list, value): - for keyword in keywords_list: - if keyword.lower() in value.lower(): - return True - return False + def _check_keywords_in_value(self, keywords_list, value) -> bool: + return any(keyword.lower() in value.lower() for keyword in keywords_list) diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index fee51007eb..6465de23b9 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -21,37 +21,36 @@ class OpenAIModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query flagged = self._is_violated(inputs) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - flagged = self._is_violated({'text': text}) - preset_response = self.config['outputs_config']['preset_response'] + if self.config["outputs_config"]["enabled"]: + flagged = self._is_violated({"text": text}) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict): - text = '\n'.join(str(inputs.values())) + text = "\n".join(str(inputs.values())) model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider="openai", - model_type=ModelType.MODERATION, - model="text-moderation-stable" + tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable" ) - openai_moderation = model_instance.invoke_moderation( - text=text - ) + openai_moderation = model_instance.invoke_moderation(text=text) return openai_moderation diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 69e28770c3..d8d794be18 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -29,7 +29,7 @@ class OutputModeration(BaseModel): thread: Optional[threading.Thread] = None thread_running: bool = True - buffer: str = '' + buffer: str = "" is_final_chunk: bool = False final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -50,11 +50,7 @@ class OutputModeration(BaseModel): self.buffer = completion self.is_final_chunk = True - result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=completion - ) + result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) if not result or not result.flagged: return completion @@ -65,21 +61,19 @@ class OutputModeration(BaseModel): final_output = result.text if public_event: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) return final_output def start_thread(self) -> threading.Thread: buffer_size = dify_config.MODERATION_BUFFER_SIZE - thread = threading.Thread(target=self.worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE - }) + thread = threading.Thread( + target=self.worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, + }, + ) thread.start() @@ -104,9 +98,7 @@ class OutputModeration(BaseModel): current_length = buffer_length result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=moderation_buffer + tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer ) if not result or not result.flagged: @@ -116,16 +108,11 @@ class OutputModeration(BaseModel): final_output = result.preset_response self.final_output = final_output else: - final_output = result.text + self.buffer[len(moderation_buffer):] + final_output = result.text + self.buffer[len(moderation_buffer) :] # trigger replace event if self.thread_running: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) if result.action == ModerationAction.DIRECT_OUTPUT: break @@ -133,10 +120,7 @@ class OutputModeration(BaseModel): def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: try: moderation_factory = ModerationFactory( - name=self.rule.type, - app_id=app_id, - tenant_id=tenant_id, - config=self.rule.config + name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config ) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index c7af8e2963..f7b882fc71 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -23,4 +23,4 @@ class BaseTraceInstance(ABC): Abstract method to trace activities. Subclasses must implement specific tracing logic for activities. """ - ... \ No newline at end of file + ... diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 221e6239ab..5c79867571 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -4,14 +4,15 @@ from pydantic import BaseModel, ValidationInfo, field_validator class TracingProviderEnum(Enum): - LANGFUSE = 'langfuse' - LANGSMITH = 'langsmith' + LANGFUSE = "langfuse" + LANGSMITH = "langsmith" class BaseTracingConfig(BaseModel): """ Base model class for tracing """ + ... @@ -19,16 +20,18 @@ class LangfuseConfig(BaseTracingConfig): """ Model class for Langfuse tracing config. """ + public_key: str secret_key: str - host: str = 'https://api.langfuse.com' + host: str = "https://api.langfuse.com" @field_validator("host") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.langfuse.com' - if not v.startswith('https://') and not v.startswith('http://'): - raise ValueError('host must start with https:// or http://') + v = "https://api.langfuse.com" + if not v.startswith("https://") and not v.startswith("http://"): + raise ValueError("host must start with https:// or http://") return v @@ -37,15 +40,17 @@ class LangSmithConfig(BaseTracingConfig): """ Model class for Langsmith tracing config. """ + api_key: str project: str - endpoint: str = 'https://api.smith.langchain.com' + endpoint: str = "https://api.smith.langchain.com" @field_validator("endpoint") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.smith.langchain.com' - if not v.startswith('https://'): - raise ValueError('endpoint must start with https://') + v = "https://api.smith.langchain.com" + if not v.startswith("https://"): + raise ValueError("endpoint must start with https://") return v diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index a1443f0691..f27a0af6e0 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel): metadata: dict[str, Any] @field_validator("inputs", "outputs") + @classmethod def ensure_type(cls, v): if v is None: return None @@ -23,6 +24,7 @@ class BaseTraceInfo(BaseModel): else: return "" + class WorkflowTraceInfo(BaseTraceInfo): workflow_data: Any conversation_id: Optional[str] = None @@ -98,23 +100,24 @@ class GenerateNameTraceInfo(BaseTraceInfo): conversation_id: Optional[str] = None tenant_id: str + trace_info_info_map = { - 'WorkflowTraceInfo': WorkflowTraceInfo, - 'MessageTraceInfo': MessageTraceInfo, - 'ModerationTraceInfo': ModerationTraceInfo, - 'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo, - 'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo, - 'ToolTraceInfo': ToolTraceInfo, - 'GenerateNameTraceInfo': GenerateNameTraceInfo, + "WorkflowTraceInfo": WorkflowTraceInfo, + "MessageTraceInfo": MessageTraceInfo, + "ModerationTraceInfo": ModerationTraceInfo, + "SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo, + "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, + "ToolTraceInfo": ToolTraceInfo, + "GenerateNameTraceInfo": GenerateNameTraceInfo, } class TraceTaskName(str, Enum): - CONVERSATION_TRACE = 'conversation' - WORKFLOW_TRACE = 'workflow' - MESSAGE_TRACE = 'message' - MODERATION_TRACE = 'moderation' - SUGGESTED_QUESTION_TRACE = 'suggested_question' - DATASET_RETRIEVAL_TRACE = 'dataset_retrieval' - TOOL_TRACE = 'tool' - GENERATE_NAME_TRACE = 'generate_conversation_name' + CONVERSATION_TRACE = "conversation" + WORKFLOW_TRACE = "workflow" + MESSAGE_TRACE = "message" + MODERATION_TRACE = "moderation" + SUGGESTED_QUESTION_TRACE = "suggested_question" + DATASET_RETRIEVAL_TRACE = "dataset_retrieval" + TOOL_TRACE = "tool" + GENERATE_NAME_TRACE = "generate_conversation_name" diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index af7661f0af..447b799f1f 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -196,6 +198,7 @@ class GenerationUsage(BaseModel): totalCost: Optional[float] = None @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel): model_config = ConfigDict(protected_namespaces=()) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index a0f3ac7f86..6aefbec9aa 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -65,7 +65,7 @@ class LangFuseDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id + trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id user_id = trace_info.metadata.get("user_id") if trace_info.message_id: trace_id = trace_info.message_id @@ -84,7 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_trace(langfuse_trace_data=trace_data) workflow_span_data = LangfuseSpan( - id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id), + id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id), name=TraceTaskName.WORKFLOW_TRACE.value, input=trace_info.workflow_run_inputs, output=trace_info.workflow_run_outputs, @@ -93,7 +93,7 @@ class LangFuseDataTrace(BaseTraceInstance): end_time=trace_info.end_time, metadata=trace_info.metadata, level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR, - status_message=trace_info.error if trace_info.error else "", + status_message=trace_info.error or "", ) self.add_span(langfuse_span_data=workflow_span_data) else: @@ -143,7 +143,7 @@ class LangFuseDataTrace(BaseTraceInstance): else: inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} - created_at = node_execution.created_at if node_execution.created_at else datetime.now() + created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) @@ -172,10 +172,8 @@ class LangFuseDataTrace(BaseTraceInstance): end_time=finished_at, metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), - status_message=trace_info.error if trace_info.error else "", - parent_observation_id=( - trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id - ), + status_message=trace_info.error or "", + parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id), ) else: span_data = LangfuseSpan( @@ -188,7 +186,7 @@ class LangFuseDataTrace(BaseTraceInstance): end_time=finished_at, metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), - status_message=trace_info.error if trace_info.error else "", + status_message=trace_info.error or "", ) self.add_span(langfuse_span_data=span_data) @@ -212,7 +210,7 @@ class LangFuseDataTrace(BaseTraceInstance): output=outputs, metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), - status_message=trace_info.error if trace_info.error else "", + status_message=trace_info.error or "", usage=generation_usage, ) @@ -277,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance): output=message_data.answer, metadata=metadata, level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), - status_message=message_data.error if message_data.error else "", + status_message=message_data.error or "", usage=generation_usage, ) @@ -319,7 +317,7 @@ class LangFuseDataTrace(BaseTraceInstance): end_time=trace_info.end_time, metadata=trace_info.metadata, level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), - status_message=message_data.error if message_data.error else "", + status_message=message_data.error or "", usage=generation_usage, ) diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index f3fc46d99a..05c932fb99 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -35,49 +35,32 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field( - None, description="Serialized data of the run" - ) + serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] = Field( - None, description="Session ID associated with the run" - ) - session_name: Optional[str] = Field( - None, description="Session name associated with the run" - ) - reference_example_id: Optional[str] = Field( - None, description="Reference example ID associated with the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + session_id: Optional[str] = Field(None, description="Session ID associated with the run") + session_name: Optional[str] = Field(None, description="Session name associated with the run") + reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name values = info.data if v == {} or v is None: return v usage_metadata = { - "input_tokens": values.get('input_tokens', 0), - "output_tokens": values.get('output_tokens', 0), - "total_tokens": values.get('total_tokens', 0), + "input_tokens": values.get("input_tokens", 0), + "output_tokens": values.get("output_tokens", 0), + "total_tokens": values.get("total_tokens", 0), } file_list = values.get("file_list", []) if isinstance(v, str): @@ -133,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): return v return v + @classmethod @field_validator("start_time", "end_time") def format_time(cls, v, info: ValidationInfo): if not isinstance(v, datetime): @@ -143,25 +127,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") end_time: Optional[datetime | str] = Field(None, description="End time of the run") error: Optional[str] = Field(None, description="Error message of the run") inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 9cbc805fe7..37cbea13fd 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -82,7 +82,7 @@ class LangSmithDataTrace(BaseTraceInstance): langsmith_run = LangSmithRunModel( file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, - id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id, + id=trace_info.workflow_app_log_id or trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, inputs=trace_info.workflow_run_inputs, run_type=LangSmithRunType.tool, @@ -94,7 +94,7 @@ class LangSmithDataTrace(BaseTraceInstance): }, error=trace_info.error, tags=["workflow"], - parent_run_id=trace_info.message_id if trace_info.message_id else None, + parent_run_id=trace_info.message_id or None, ) self.add_run(langsmith_run) @@ -133,7 +133,7 @@ class LangSmithDataTrace(BaseTraceInstance): else: inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} - created_at = node_execution.created_at if node_execution.created_at else datetime.now() + created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) @@ -159,8 +159,8 @@ class LangSmithDataTrace(BaseTraceInstance): run_type = LangSmithRunType.llm metadata.update( { - 'ls_provider': process_data.get('model_provider', ''), - 'ls_model_name': process_data.get('model_name', ''), + "ls_provider": process_data.get("model_provider", ""), + "ls_model_name": process_data.get("model_name", ""), } ) elif node_type == "knowledge-retrieval": @@ -180,9 +180,7 @@ class LangSmithDataTrace(BaseTraceInstance): extra={ "metadata": metadata, }, - parent_run_id=trace_info.workflow_app_log_id - if trace_info.workflow_app_log_id - else trace_info.workflow_run_id, + parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id, tags=["node_execution"], ) @@ -385,12 +383,10 @@ class LangSmithDataTrace(BaseTraceInstance): start_time=datetime.now(), ) - project_url = self.langsmith_client.get_run_url(run=run_data, - project_id=self.project_id, - project_name=self.project_name) - return project_url.split('/r/')[0] + project_url = self.langsmith_client.get_run_url( + run=run_data, project_id=self.project_id, project_name=self.project_name + ) + return project_url.split("/r/")[0] except Exception as e: logger.debug(f"LangSmith get run url failed: {str(e)}") raise ValueError(f"LangSmith get run url failed: {str(e)}") - - diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index aefab6ed16..6f17bade97 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -36,17 +36,17 @@ from tasks.ops_trace_task import process_trace_tasks provider_config_map = { TracingProviderEnum.LANGFUSE.value: { - 'config_class': LangfuseConfig, - 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host', 'project_key'], - 'trace_instance': LangFuseDataTrace + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, }, TracingProviderEnum.LANGSMITH.value: { - 'config_class': LangSmithConfig, - 'secret_keys': ['api_key'], - 'other_keys': ['project', 'endpoint'], - 'trace_instance': LangSmithDataTrace - } + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + }, } @@ -64,14 +64,17 @@ class OpsTraceManager: :return: encrypted tracing configuration """ # Get the configuration class and the keys that require encryption - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} # Encrypt necessary keys for key in secret_keys: if key in tracing_config: - if '*' in tracing_config[key]: + if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config new_config[key] = current_trace_config.get(key, tracing_config[key]) else: @@ -94,8 +97,11 @@ class OpsTraceManager: :param tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in tracing_config: @@ -114,8 +120,11 @@ class OpsTraceManager: :param decrypt_tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in decrypt_tracing_config: @@ -133,9 +142,11 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None @@ -164,21 +175,21 @@ class OpsTraceManager: if app_id is None: return None - app: App = db.session.query(App).filter( - App.id == app_id - ).first() + app: App = db.session.query(App).filter(App.id == app_id).first() app_ops_trace_config = json.loads(app.tracing) if app.tracing else None if app_ops_trace_config is not None: - tracing_provider = app_ops_trace_config.get('tracing_provider') + tracing_provider = app_ops_trace_config.get("tracing_provider") else: return None # decrypt_token decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider) - if app_ops_trace_config.get('enabled'): - trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \ - provider_config_map[tracing_provider]['config_class'] + if app_ops_trace_config.get("enabled"): + trace_instance, config_class = ( + provider_config_map[tracing_provider]["trace_instance"], + provider_config_map[tracing_provider]["config_class"], + ) tracing_instance = trace_instance(config_class(**decrypt_trace_config)) return tracing_instance @@ -192,9 +203,11 @@ class OpsTraceManager: conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if conversation_data.app_model_config_id: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation_data.app_model_config_id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation_data.app_model_config_id) + .first() + ) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: app_model_config = conversation_data.override_model_configs @@ -210,7 +223,7 @@ class OpsTraceManager: :return: """ # auth check - if tracing_provider not in provider_config_map.keys() and tracing_provider is not None: + if tracing_provider not in provider_config_map and tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") app_config: App = db.session.query(App).filter(App.id == app_id).first() @@ -231,10 +244,7 @@ class OpsTraceManager: """ app: App = db.session.query(App).filter(App.id == app_id).first() if not app.tracing: - return { - "enabled": False, - "tracing_provider": None - } + return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) return app_trace_config @@ -246,8 +256,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() @@ -259,8 +271,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).get_project_key() @@ -272,8 +286,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).get_project_url() @@ -287,7 +303,7 @@ class TraceTask: conversation_id: Optional[str] = None, user_id: Optional[str] = None, timer: Optional[Any] = None, - **kwargs + **kwargs, ): self.trace_type = trace_type self.message_id = message_id @@ -310,9 +326,7 @@ class TraceTask: self.workflow_run, self.conversation_id, self.user_id ), TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( - self.message_id, self.timer, **self.kwargs - ), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( self.message_id, self.timer, **self.kwargs ), @@ -337,32 +351,29 @@ class TraceTask: workflow_run_id = workflow_run.id workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_status = workflow_run.status - workflow_run_inputs = ( - json.loads(workflow_run.inputs) if workflow_run.inputs else {} - ) - workflow_run_outputs = ( - json.loads(workflow_run.outputs) if workflow_run.outputs else {} - ) + workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} + workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} workflow_run_version = workflow_run.version - error = workflow_run.error if workflow_run.error else "" + error = workflow_run.error or "" total_tokens = workflow_run.total_tokens - file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else [] + file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" # get workflow_app_log_id - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - tenant_id=tenant_id, - app_id=workflow_run.app_id, - workflow_run_id=workflow_run.id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog) + .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) + .first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None # get message_id - message_data = db.session.query(Message.id).filter_by( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id - ).first() + message_data = ( + db.session.query(Message.id) + .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) + .first() + ) message_id = str(message_data.id) if message_data else None metadata = { @@ -441,7 +452,7 @@ class TraceTask: message_tokens=message_tokens, answer_tokens=message_data.answer_tokens, total_tokens=message_tokens + message_data.answer_tokens, - error=message_data.error if message_data.error else "", + error=message_data.error or "", inputs=inputs, outputs=message_data.answer, file_list=file_list, @@ -470,13 +481,13 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( - message_id=workflow_app_log_id if workflow_app_log_id else message_id, + message_id=workflow_app_log_id or message_id, inputs=inputs, message_data=message_data.to_dict(), flagged=moderation_result.flagged, @@ -510,13 +521,13 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( - message_id=workflow_app_log_id if workflow_app_log_id else message_id, + message_id=workflow_app_log_id or message_id, message_data=message_data.to_dict(), inputs=message_data.message, outputs=message_data.answer, @@ -558,7 +569,7 @@ class TraceTask: dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, - inputs=message_data.query if message_data.query else message_data.inputs, + inputs=message_data.query or message_data.inputs, documents=[doc.model_dump() for doc in documents], start_time=timer.get("start"), end_time=timer.get("end"), @@ -569,9 +580,9 @@ class TraceTask: return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get('tool_name') - tool_inputs = kwargs.get('tool_inputs') - tool_outputs = kwargs.get('tool_outputs') + tool_name = kwargs.get("tool_name") + tool_inputs = kwargs.get("tool_inputs") + tool_outputs = kwargs.get("tool_outputs") message_data = get_message_data(message_id) if not message_data: return {} @@ -586,11 +597,11 @@ class TraceTask: if tool_name in agent_thought.tools: created_time = agent_thought.created_at tool_meta_data = agent_thought.tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - time_cost = tool_meta_data.get('time_cost', 0) + tool_config = tool_meta_data.get("tool_config", {}) + time_cost = tool_meta_data.get("time_cost", 0) end_time = created_time + timedelta(seconds=time_cost) - error = tool_meta_data.get('error', "") - tool_parameters = tool_meta_data.get('tool_parameters', {}) + error = tool_meta_data.get("error", "") + tool_parameters = tool_meta_data.get("tool_parameters", {}) metadata = { "message_id": message_id, "tool_name": tool_name, @@ -684,8 +695,7 @@ class TraceQueueManager: self.start_timer() def add_trace_task(self, trace_task: TraceTask): - global trace_manager_timer - global trace_manager_queue + global trace_manager_timer, trace_manager_queue try: if self.trace_instance: trace_task.app_id = self.app_id @@ -715,9 +725,7 @@ class TraceQueueManager: def start_timer(self): global trace_manager_timer if trace_manager_timer is None or not trace_manager_timer.is_alive(): - trace_manager_timer = threading.Timer( - trace_manager_interval, self.run - ) + trace_manager_timer = threading.Timer(trace_manager_interval, self.run) trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.daemon = False trace_manager_timer.start() diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 3b2e04abb7..498685b342 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -20,19 +20,19 @@ def get_message_data(message_id): @contextmanager def measure_time(): - timing_info = {'start': datetime.now(), 'end': None} + timing_info = {"start": datetime.now(), "end": None} try: yield timing_info finally: - timing_info['end'] = datetime.now() + timing_info["end"] = datetime.now() def replace_text_with_content(data): if isinstance(data, dict): new_data = {} for key, value in data.items(): - if key == 'text': - new_data['content'] = value + if key == "text": + new_data["content"] = value else: new_data[key] = replace_text_with_content(value) return new_data diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index b1e1065aff..a98f4169a7 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -17,14 +17,15 @@ from models.model import App, AppMode, EndUser class PluginAppBackwardsInvocation(BaseBackwardsInvocation): @classmethod def invoke_app( - cls, app_id: str, - user_id: str, + cls, + app_id: str, + user_id: str, tenant_id: str, conversation_id: Optional[str], query: Optional[str], stream: bool, inputs: Mapping, - files: list[dict], + files: list[dict], ) -> Generator[dict | str, None, None] | dict: """ invoke app @@ -37,28 +38,28 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in [AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value]: + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}: if not query: raise ValueError("missing query") - + return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) - elif app.mode in [AppMode.WORKFLOW.value]: + elif app.mode == AppMode.WORKFLOW.value: return cls.invoke_workflow_app(app, user, stream, inputs, files) - elif app.mode in [AppMode.COMPLETION]: + elif app.mode == AppMode.COMPLETION: return cls.invoke_completion_app(app, user, stream, inputs, files) raise ValueError("unexpected app type") @classmethod def invoke_chat_app( - cls, + cls, app: App, - user: Account | EndUser, + user: Account | EndUser, conversation_id: str, query: str, stream: bool, inputs: Mapping, - files: list[dict], + files: list[dict], ) -> Generator[dict | str, None, None] | dict: """ invoke chat app @@ -67,11 +68,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): workflow = app.workflow if not workflow: raise ValueError("unexpected app type") - + return AdvancedChatAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, + app_model=app, + workflow=workflow, + user=user, args={ "inputs": inputs, "query": query, @@ -79,12 +80,12 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): "conversation_id": conversation_id, }, invoke_from=InvokeFrom.SERVICE_API, - stream=stream + stream=stream, ) elif app.mode == AppMode.AGENT_CHAT.value: return AgentChatAppGenerator().generate( - app_model=app, - user=user, + app_model=app, + user=user, args={ "inputs": inputs, "query": query, @@ -92,12 +93,12 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): "conversation_id": conversation_id, }, invoke_from=InvokeFrom.SERVICE_API, - stream=stream + stream=stream, ) elif app.mode == AppMode.CHAT.value: return ChatAppGenerator().generate( - app_model=app, - user=user, + app_model=app, + user=user, args={ "inputs": inputs, "query": query, @@ -105,19 +106,19 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): "conversation_id": conversation_id, }, invoke_from=InvokeFrom.SERVICE_API, - stream=stream + stream=stream, ) else: raise ValueError("unexpected app type") @classmethod def invoke_workflow_app( - cls, + cls, app: App, - user: EndUser | Account, + user: EndUser | Account, stream: bool, inputs: Mapping, - files: list[dict], + files: list[dict], ): """ invoke workflow app @@ -127,13 +128,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): raise ValueError("") return WorkflowAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args={ - 'inputs': inputs, - 'files': files - }, + app_model=app, + workflow=workflow, + user=user, + args={"inputs": inputs, "files": files}, invoke_from=InvokeFrom.SERVICE_API, stream=stream, call_depth=1, @@ -141,23 +139,20 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): @classmethod def invoke_completion_app( - cls, + cls, app: App, - user: EndUser | Account, + user: EndUser | Account, stream: bool, inputs: Mapping, - files: list[dict], + files: list[dict], ): """ invoke completion app """ return CompletionAppGenerator().generate( - app_model=app, - user=user, - args={ - 'inputs': inputs, - 'files': files - }, + app_model=app, + user=user, + args={"inputs": inputs, "files": files}, invoke_from=InvokeFrom.SERVICE_API, stream=stream, ) @@ -173,24 +168,21 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): user = db.session.query(Account).filter(Account.id == user_id).first() if not user: - raise ValueError('user not found') + raise ValueError("user not found") return user - + @classmethod def _get_app(cls, app_id: str, tenant_id: str) -> App: """ get app """ try: - app = db.session.query(App). \ - filter(App.id == app_id). \ - filter(App.tenant_id == tenant_id). \ - first() + app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first() except Exception: raise ValueError("app not found") - + if not app: raise ValueError("app not found") - - return app \ No newline at end of file + + return app diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 78414e2aa5..627b335225 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -32,6 +32,7 @@ class RequestInvokeLLM(BaseRequestInvokeModel): """ Request to invoke LLM """ + model_type: ModelType = ModelType.LLM mode: str model_parameters: dict[str, Any] = Field(default_factory=dict) @@ -40,19 +41,19 @@ class RequestInvokeLLM(BaseRequestInvokeModel): stop: Optional[list[str]] = Field(default_factory=list) stream: Optional[bool] = False - @field_validator('prompt_messages', mode='before') - def convert_prompt_messages(cls, v): + @field_validator("prompt_messages", mode="before") + def convert_prompt_messages(self, v): if not isinstance(v, list): - raise ValueError('prompt_messages must be a list') + raise ValueError("prompt_messages must be a list") for i in range(len(v)): - if v[i]['role'] == PromptMessageRole.USER.value: + if v[i]["role"] == PromptMessageRole.USER.value: v[i] = UserPromptMessage(**v[i]) - elif v[i]['role'] == PromptMessageRole.ASSISTANT.value: + elif v[i]["role"] == PromptMessageRole.ASSISTANT.value: v[i] = AssistantPromptMessage(**v[i]) - elif v[i]['role'] == PromptMessageRole.SYSTEM.value: + elif v[i]["role"] == PromptMessageRole.SYSTEM.value: v[i] = SystemPromptMessage(**v[i]) - elif v[i]['role'] == PromptMessageRole.TOOL.value: + elif v[i]["role"] == PromptMessageRole.TOOL.value: v[i] = ToolPromptMessage(**v[i]) else: v[i] = PromptMessage(**v[i]) @@ -95,10 +96,12 @@ class RequestInvokeNode(BaseModel): Request to invoke node """ + class RequestInvokeApp(BaseModel): """ Request to invoke app """ + app_id: str inputs: dict[str, Any] query: Optional[str] = None @@ -107,12 +110,14 @@ class RequestInvokeApp(BaseModel): user: Optional[str] = None files: list[dict] = Field(default_factory=list) + class RequestInvokeEncrypt(BaseModel): """ Request to encryption """ + opt: Literal["encrypt", "decrypt"] namespace: Literal["endpoint"] identity: str data: dict = Field(default_factory=dict) - config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping) \ No newline at end of file + config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 22420fea2c..ce8038d14e 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -22,18 +22,22 @@ class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. """ + def __init__(self, with_variable_tmpl: bool = False) -> None: self.with_variable_tmpl = with_variable_tmpl - def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], - inputs: dict, - query: str, - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def get_prompt( + self, + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], + inputs: dict, + query: str, + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: inputs = {key: str(value) for key, value in inputs.items()} prompt_messages = [] @@ -48,7 +52,7 @@ class AdvancedPromptTransform(PromptTransform): context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) elif model_mode == ModelMode.CHAT: prompt_messages = self._get_chat_model_prompt_messages( @@ -60,20 +64,22 @@ class AdvancedPromptTransform(PromptTransform): context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _get_completion_model_prompt_messages(self, - prompt_template: CompletionModelPromptTemplate, - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _get_completion_model_prompt_messages( + self, + prompt_template: CompletionModelPromptTemplate, + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: """ Get completion model prompt messages. """ @@ -81,7 +87,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = [] - if prompt_template.edition_type == 'basic' or not prompt_template.edition_type: + if prompt_template.edition_type == "basic" or not prompt_template.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -96,15 +102,13 @@ class AdvancedPromptTransform(PromptTransform): role_prefix=role_prefix, prompt_template=prompt_template, prompt_inputs=prompt_inputs, - model_config=model_config + model_config=model_config, ) if query: prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) + prompt = prompt_template.format(prompt_inputs) else: prompt = raw_prompt prompt_inputs = inputs @@ -122,16 +126,18 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages - def _get_chat_model_prompt_messages(self, - prompt_template: list[ChatModelMessage], - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def _get_chat_model_prompt_messages( + self, + prompt_template: list[ChatModelMessage], + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: """ Get chat model prompt messages. """ @@ -142,22 +148,20 @@ class AdvancedPromptTransform(PromptTransform): for prompt_item in raw_prompt_list: raw_prompt = prompt_item.text - if prompt_item.edition_type == 'basic' or not prompt_item.edition_type: + if prompt_item.edition_type == "basic" or not prompt_item.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) - elif prompt_item.edition_type == 'jinja2': + prompt = prompt_template.format(prompt_inputs) + elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs prompt = Jinja2Formatter.format(prompt, prompt_inputs) else: - raise ValueError(f'Invalid edition type: {prompt_item.edition_type}') + raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") if prompt_item.role == PromptMessageRole.USER: prompt_messages.append(UserPromptMessage(content=prompt)) @@ -168,17 +172,14 @@ class AdvancedPromptTransform(PromptTransform): if query and query_prompt_template: prompt_template = PromptTemplateParser( - template=query_prompt_template, - with_variable_tmpl=self.with_variable_tmpl + template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl ) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - prompt_inputs['#sys.query#'] = query + prompt_inputs["#sys.query#"] = query prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - query = prompt_template.format( - prompt_inputs - ) + query = prompt_template.format(prompt_inputs) if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) @@ -203,7 +204,7 @@ class AdvancedPromptTransform(PromptTransform): last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: prompt_message_contents.append(file.prompt_message_content) @@ -220,38 +221,39 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#context#' in prompt_template.variable_keys: + if "#context#" in prompt_template.variable_keys: if context: - prompt_inputs['#context#'] = context + prompt_inputs["#context#"] = context else: - prompt_inputs['#context#'] = '' + prompt_inputs["#context#"] = "" return prompt_inputs def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#query#' in prompt_template.variable_keys: + if "#query#" in prompt_template.variable_keys: if query: - prompt_inputs['#query#'] = query + prompt_inputs["#query#"] = query else: - prompt_inputs['#query#'] = '' + prompt_inputs["#query#"] = "" return prompt_inputs - def _set_histories_variable(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - raw_prompt: str, - role_prefix: MemoryConfig.RolePrefix, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigWithCredentialsEntity) -> dict: - if '#histories#' in prompt_template.variable_keys: + def _set_histories_variable( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + raw_prompt: str, + role_prefix: MemoryConfig.RolePrefix, + prompt_template: PromptTemplateParser, + prompt_inputs: dict, + model_config: ModelConfigWithCredentialsEntity, + ) -> dict: + if "#histories#" in prompt_template.variable_keys: if memory: - inputs = {'#histories#': '', **prompt_inputs} + inputs = {"#histories#": "", **prompt_inputs} prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - tmp_human_message = UserPromptMessage( - content=prompt_template.format(prompt_inputs) - ) + tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs)) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) @@ -260,10 +262,10 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, - ai_prefix=role_prefix.assistant + ai_prefix=role_prefix.assistant, ) - prompt_inputs['#histories#'] = histories + prompt_inputs["#histories#"] = histories else: - prompt_inputs['#histories#'] = '' + prompt_inputs["#histories#"] = "" return prompt_inputs diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index af0075ea91..caa1793ea8 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -17,12 +17,14 @@ class AgentHistoryPromptTransform(PromptTransform): """ History Prompt Transform for Agent App """ - def __init__(self, - model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage], - history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, - ): + + def __init__( + self, + model_config: ModelConfigWithCredentialsEntity, + prompt_messages: list[PromptMessage], + history_messages: list[PromptMessage], + memory: Optional[TokenBufferMemory] = None, + ): self.model_config = model_config self.prompt_messages = prompt_messages self.history_messages = history_messages @@ -45,9 +47,7 @@ class AgentHistoryPromptTransform(PromptTransform): model_type_instance = cast(LargeLanguageModel, model_type_instance) curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - self.history_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages ) if curr_message_tokens <= max_token_limit: return self.history_messages @@ -63,9 +63,7 @@ class AgentHistoryPromptTransform(PromptTransform): # a message is start with UserPromptMessage if isinstance(prompt_message, UserPromptMessage): curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - prompt_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages ) # if current message token is overflow, drop all the prompts in current message and break if curr_message_tokens > max_token_limit: diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 61df69163c..c8e7b414df 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -9,27 +9,31 @@ class ChatModelMessage(BaseModel): """ Chat Message. """ + text: str role: PromptMessageRole - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class CompletionModelPromptTemplate(BaseModel): """ Completion Model Prompt Template. """ + text: str - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class MemoryConfig(BaseModel): """ Memory Config. """ + class RolePrefix(BaseModel): """ Role Prefix. """ + user: str assistant: str @@ -37,6 +41,7 @@ class MemoryConfig(BaseModel): """ Window Config. """ + enabled: bool size: Optional[int] = None diff --git a/api/core/prompt/prompt_templates/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py index da40534d99..0ab7f526cc 100644 --- a/api/core/prompt/prompt_templates/advanced_prompt_templates.py +++ b/api/core/prompt/prompt_templates/advanced_prompt_templates.py @@ -1,83 +1,45 @@ -CONTEXT = "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n" +CONTEXT = "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n" # noqa: E501 -BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n" +BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n" # noqa: E501 CHAT_APP_COMPLETION_PROMPT_CONFIG = { "completion_prompt_config": { "prompt": { - "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " + "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501 }, - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - } + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, }, - "stop": ["Human:"] + "stop": ["Human:"], } -CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } -} +CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}} -COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } -} +COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}} COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["Human:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["Human:"], } BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { "completion_prompt_config": { "prompt": { - "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" + "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" # noqa: E501 }, - "conversation_histories_role": { - "user_prefix": "用户", - "assistant_prefix": "助手" - } + "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, }, - "stop": ["用户:"] + "stop": ["用户:"], } -BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } +BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } + "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["用户:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["用户:"], } diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index b86d3fa815..87acdb3c49 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -9,75 +9,78 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: - def _append_chat_histories(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _append_chat_histories( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token( + self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + ) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: + def _get_history_messages_from_memory( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None, + ) -> str: """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } + kwargs = {"max_token_limit": max_token_limit} if human_prefix: - kwargs['human_prefix'] = human_prefix + kwargs["human_prefix"] = human_prefix if ai_prefix: - kwargs['ai_prefix'] = ai_prefix + kwargs["ai_prefix"] = ai_prefix if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: - kwargs['message_limit'] = memory_config.window.size + kwargs["message_limit"] = memory_config.window.size - return memory.get_history_prompt_text( - **kwargs - ) + return memory.get_history_prompt_text(**kwargs) - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int) -> list[PromptMessage]: + def _get_history_messages_list_from_memory( + self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int + ) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit, message_limit=memory_config.window.size - if (memory_config.window.enabled - and memory_config.window.size is not None - and memory_config.window.size > 0) - else None + if ( + memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + ) + else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fd7ed0181b..7479560520 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -22,11 +22,11 @@ if TYPE_CHECKING: class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' + COMPLETION = "completion" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'ModelMode': + def value_of(cls, value: str) -> "ModelMode": """ Get value of given mode. @@ -36,7 +36,7 @@ class ModelMode(enum.Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") prompt_file_contents = {} @@ -47,16 +47,17 @@ class SimplePromptTransform(PromptTransform): Simple Prompt Transform for Chatbot App Basic Mode. """ - def get_prompt(self, - app_mode: AppMode, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: + def get_prompt( + self, + app_mode: AppMode, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode.value_of(model_config.mode) @@ -69,7 +70,7 @@ class SimplePromptTransform(PromptTransform): files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -80,19 +81,21 @@ class SimplePromptTransform(PromptTransform): files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages, stops - def get_prompt_str_and_rules(self, app_mode: AppMode, - model_config: ModelConfigWithCredentialsEntity, - pre_prompt: str, - inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, - ) -> tuple[str, dict]: + def get_prompt_str_and_rules( + self, + app_mode: AppMode, + model_config: ModelConfigWithCredentialsEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -101,74 +104,75 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, has_context=context is not None, query_in_prompt=query is not None, - with_memory_prompt=histories is not None + with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs} + variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} - for v in prompt_template_config['special_variable_keys']: + for v in prompt_template_config["special_variable_keys"]: # support #context#, #query# and #histories# - if v == '#context#': - variables['#context#'] = context if context else '' - elif v == '#query#': - variables['#query#'] = query if query else '' - elif v == '#histories#': - variables['#histories#'] = histories if histories else '' + if v == "#context#": + variables["#context#"] = context or "" + elif v == "#query#": + variables["#query#"] = query or "" + elif v == "#histories#": + variables["#histories#"] = histories or "" - prompt_template = prompt_template_config['prompt_template'] + prompt_template = prompt_template_config["prompt_template"] prompt = prompt_template.format(variables) - return prompt, prompt_template_config['prompt_rules'] + return prompt, prompt_template_config["prompt_rules"] - def get_prompt_template(self, app_mode: AppMode, - provider: str, - model: str, - pre_prompt: str, - has_context: bool, - query_in_prompt: bool, - with_memory_prompt: bool = False) -> dict: - prompt_rules = self._get_prompt_rule( - app_mode=app_mode, - provider=provider, - model=model - ) + def get_prompt_template( + self, + app_mode: AppMode, + provider: str, + model: str, + pre_prompt: str, + has_context: bool, + query_in_prompt: bool, + with_memory_prompt: bool = False, + ) -> dict: + prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys = [] special_variable_keys = [] - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt' and has_context: - prompt += prompt_rules['context_prompt'] - special_variable_keys.append('#context#') - elif order == 'pre_prompt' and pre_prompt: - prompt += pre_prompt + '\n' + prompt = "" + for order in prompt_rules["system_prompt_orders"]: + if order == "context_prompt" and has_context: + prompt += prompt_rules["context_prompt"] + special_variable_keys.append("#context#") + elif order == "pre_prompt" and pre_prompt: + prompt += pre_prompt + "\n" pre_prompt_template = PromptTemplateParser(template=pre_prompt) custom_variable_keys = pre_prompt_template.variable_keys - elif order == 'histories_prompt' and with_memory_prompt: - prompt += prompt_rules['histories_prompt'] - special_variable_keys.append('#histories#') + elif order == "histories_prompt" and with_memory_prompt: + prompt += prompt_rules["histories_prompt"] + special_variable_keys.append("#histories#") if query_in_prompt: - prompt += prompt_rules.get('query_prompt', '{{#query#}}') - special_variable_keys.append('#query#') + prompt += prompt_rules.get("query_prompt", "{{#query#}}") + special_variable_keys.append("#query#") return { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, - "prompt_rules": prompt_rules + "prompt_rules": prompt_rules, } - def _get_chat_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_chat_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] # get prompt @@ -178,7 +182,7 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, inputs=inputs, query=None, - context=context + context=context, ) if prompt and query: @@ -193,7 +197,7 @@ class SimplePromptTransform(PromptTransform): ) ), prompt_messages=prompt_messages, - model_config=model_config + model_config=model_config, ) if query: @@ -203,15 +207,17 @@ class SimplePromptTransform(PromptTransform): return prompt_messages, None - def _get_completion_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_completion_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( app_mode=app_mode, @@ -219,13 +225,11 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, inputs=inputs, query=query, - context=context + context=context, ) if memory: - tmp_human_message = UserPromptMessage( - content=prompt - ) + tmp_human_message = UserPromptMessage(content=prompt) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( @@ -236,8 +240,8 @@ class SimplePromptTransform(PromptTransform): ) ), max_token_limit=rest_tokens, - human_prefix=prompt_rules.get('human_prefix', 'Human'), - ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), ) # get prompt @@ -248,10 +252,10 @@ class SimplePromptTransform(PromptTransform): inputs=inputs, query=query, context=context, - histories=histories + histories=histories, ) - stops = prompt_rules.get('stops') + stops = prompt_rules.get("stops") if stops is not None and len(stops) == 0: stops = None @@ -277,22 +281,18 @@ class SimplePromptTransform(PromptTransform): :param model: model name :return: """ - prompt_file_name = self._prompt_file_name( - app_mode=app_mode, - provider=provider, - model=model - ) + prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model) # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') - json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") + json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json") # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: + with open(json_file_path, encoding="utf-8") as json_file: content = json.load(json_file) # Store the content of the prompt file @@ -303,21 +303,21 @@ class SimplePromptTransform(PromptTransform): def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False - if provider == 'baichuan': + if provider == "baichuan": is_baichuan = True else: baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + if provider in baichuan_supported_providers and "baichuan" in model.lower(): is_baichuan = True if is_baichuan: if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' + return "baichuan_completion" else: - return 'baichuan_chat' + return "baichuan_chat" # common if app_mode == AppMode.COMPLETION: - return 'common_completion' + return "common_completion" else: - return 'common_chat' + return "common_chat" diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index befdceeda5..29494db221 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -25,26 +25,29 @@ class PromptMessageUtil: tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: - role = 'user' + role = "user" elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' + role = "assistant" if isinstance(prompt_message, AssistantPromptMessage): - tool_calls = [{ - 'id': tool_call.id, - 'type': 'function', - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments, + tool_calls = [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls] + for tool_call in prompt_message.tool_calls + ] elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' + role = "system" elif prompt_message.role == PromptMessageRole.TOOL: - role = 'tool' + role = "tool" else: continue - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -53,27 +56,25 @@ class PromptMessageUtil: text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content - prompt = { - "role": role, - "text": text, - "files": files - } - + prompt = {"role": role, "text": text, "files": files} + if tool_calls: - prompt['tool_calls'] = tool_calls + prompt["tool_calls"] = tool_calls prompts.append(prompt) else: prompt_message = prompt_messages[0] - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -82,21 +83,23 @@ class PromptMessageUtil: text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content params = { - "role": 'user', + "role": "user", "text": text, } if files: - params['files'] = files + params["files"] = files prompts.append(params) diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 3e68492df2..8111559675 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -38,8 +38,8 @@ class PromptTemplateParser: return value prompt = re.sub(self.regex, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False): - return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text) + return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 67eee2c294..3a1fe300df 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -90,8 +90,7 @@ class ProviderManager: # Initialize trial provider records if not exist provider_name_to_provider_records_dict = self._init_trial_provider_records( - tenant_id, - provider_name_to_provider_records_dict + tenant_id, provider_name_to_provider_records_dict ) # Get all provider model records of the workspace @@ -107,22 +106,20 @@ class ProviderManager: provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) # Get All load balancing configs - provider_name_to_provider_load_balancing_model_configs_dict \ - = self._get_all_provider_load_balancing_configs(tenant_id) - - provider_configurations = ProviderConfigurations( - tenant_id=tenant_id + provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( + tenant_id ) + provider_configurations = ProviderConfigurations(tenant_id=tenant_id) + # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: - # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, - data=provider_entity, - name_func=lambda x: x.provider, + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, ): continue @@ -132,18 +129,11 @@ class ProviderManager: # Convert to custom configuration custom_configuration = self._to_custom_configuration( - tenant_id, - provider_entity, - provider_records, - provider_model_records + tenant_id, provider_entity, provider_records, provider_model_records ) # Convert to system configuration - system_configuration = self._to_system_configuration( - tenant_id, - provider_entity, - provider_records - ) + system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) # Get preferred provider type preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) @@ -173,14 +163,15 @@ class ProviderManager: provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) # Get provider load balancing configs - provider_load_balancing_configs \ - = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name) + provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_name + ) # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, provider_model_settings=provider_model_settings, - load_balancing_model_configs=provider_load_balancing_configs + load_balancing_model_configs=provider_load_balancing_configs, ) provider_configuration = ProviderConfiguration( @@ -190,7 +181,7 @@ class ProviderManager: using_provider_type=using_provider_type, system_configuration=system_configuration, custom_configuration=custom_configuration, - model_settings=model_settings + model_settings=model_settings, ) provider_configurations[provider_name] = provider_configuration @@ -219,7 +210,7 @@ class ProviderManager: return ProviderModelBundle( configuration=provider_configuration, provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: @@ -231,11 +222,14 @@ class ProviderManager: :return: """ # Get the corresponding TenantDefaultModel record - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record @@ -244,20 +238,18 @@ class ProviderManager: provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) if available_models: - available_model = next((model for model in available_models if model.model == "gpt-4"), - available_models[0]) + available_model = next( + (model for model in available_models if model.model == "gpt-4"), available_models[0] + ) default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), provider_name=available_model.provider.provider, - model_name=available_model.model + model_name=available_model.model, ) db.session.add(default_model) db.session.commit() @@ -276,8 +268,8 @@ class ProviderManager: label=provider_schema.label, icon_small=provider_schema.icon_small, icon_large=provider_schema.icon_large, - supported_model_types=provider_schema.supported_model_types - ) + supported_model_types=provider_schema.supported_model_types, + ), ) def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: @@ -291,15 +283,13 @@ class ProviderManager: provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - all_models = provider_configurations.get_models( - model_type=model_type, - only_active=False - ) + all_models = provider_configurations.get_models(model_type=model_type, only_active=False) return all_models[0].provider.provider, all_models[0].model - def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ - -> TenantDefaultModel: + def update_default_model_record( + self, tenant_id: str, model_type: ModelType, provider: str, model: str + ) -> TenantDefaultModel: """ Update default model record. @@ -314,10 +304,7 @@ class ProviderManager: raise ValueError(f"Provider {provider} does not exist.") # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) # check if the model is exist in available models model_names = [model.model for model in available_models] @@ -325,11 +312,14 @@ class ProviderManager: raise ValueError(f"Model {model} does not exist.") # Get the list of available models from get_configurations and check if it is LLM - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # create or update TenantDefaultModel record if default_model: @@ -358,11 +348,7 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - providers = db.session.query(Provider) \ - .filter( - Provider.tenant_id == tenant_id, - Provider.is_valid == True - ).all() + providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() provider_name_to_provider_records_dict = defaultdict(list) for provider in providers: @@ -379,11 +365,11 @@ class ProviderManager: :return: """ # Get all provider model records of the workspace - provider_models = db.session.query(ProviderModel) \ - .filter( - ProviderModel.tenant_id == tenant_id, - ProviderModel.is_valid == True - ).all() + provider_models = ( + db.session.query(ProviderModel) + .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + .all() + ) provider_name_to_provider_model_records_dict = defaultdict(list) for provider_model in provider_models: @@ -399,10 +385,11 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ - .filter( - TenantPreferredModelProvider.tenant_id == tenant_id - ).all() + preferred_provider_types = ( + db.session.query(TenantPreferredModelProvider) + .filter(TenantPreferredModelProvider.tenant_id == tenant_id) + .all() + ) provider_name_to_preferred_provider_type_records_dict = { preferred_provider_type.provider_name: preferred_provider_type @@ -419,15 +406,17 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = db.session.query(ProviderModelSetting) \ - .filter( - ProviderModelSetting.tenant_id == tenant_id - ).all() + provider_model_settings = ( + db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() + ) provider_name_to_provider_model_settings_dict = defaultdict(list) for provider_model_setting in provider_model_settings: - (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name] - .append(provider_model_setting)) + ( + provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( + provider_model_setting + ) + ) return provider_name_to_provider_model_settings_dict @@ -445,27 +434,30 @@ class ProviderManager: model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) else: - cache_result = cache_result.decode('utf-8') - model_load_balancing_enabled = cache_result == 'True' + cache_result = cache_result.decode("utf-8") + model_load_balancing_enabled = cache_result == "True" if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ - .filter( - LoadBalancingModelConfig.tenant_id == tenant_id - ).all() + provider_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() + ) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) for provider_load_balancing_config in provider_load_balancing_configs: - (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name] - .append(provider_load_balancing_config)) + ( + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) + ) return provider_name_to_provider_load_balancing_model_configs_dict @staticmethod - def _init_trial_provider_records(tenant_id: str, - provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: + def _init_trial_provider_records( + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] + ) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -489,8 +481,9 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) for quota in configuration.quotas: if quota.quota_type == ProviderQuotaType.TRIAL: @@ -504,19 +497,22 @@ class ProviderManager: quota_type=ProviderQuotaType.TRIAL.value, quota_limit=quota.quota_limit, quota_used=0, - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() except IntegrityError: db.session.rollback() - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value - ).first() + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, + ) + .first() + ) if provider_record and not provider_record.is_valid: provider_record.is_valid = True @@ -526,11 +522,13 @@ class ProviderManager: return provider_name_to_provider_records_dict - def _to_custom_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider], - provider_model_records: list[ProviderModel]) -> CustomConfiguration: + def _to_custom_configuration( + self, + tenant_id: str, + provider_entity: ProviderEntity, + provider_records: list[Provider], + provider_model_records: list[ProviderModel], + ) -> CustomConfiguration: """ Convert to custom configuration. @@ -543,7 +541,8 @@ class ProviderManager: # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get custom provider record @@ -563,7 +562,7 @@ class ProviderManager: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -572,11 +571,11 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if (custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{")): - provider_credentials = { - "openai_api_key": custom_provider_record.encrypted_config - } + if ( + custom_provider_record.encrypted_config + and not custom_provider_record.encrypted_config.startswith("{") + ): + provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) except JSONDecodeError: @@ -590,28 +589,23 @@ class ProviderManager: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass # cache provider credentials - provider_credentials_cache.set( - credentials=provider_credentials - ) + provider_credentials_cache.set(credentials=provider_credentials) else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials - ) + custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) # Get custom provider model credentials @@ -621,9 +615,7 @@ class ProviderManager: continue provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL ) # Get cached provider model credentials @@ -645,15 +637,13 @@ class ProviderManager: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials @@ -661,19 +651,15 @@ class ProviderManager: CustomModelConfiguration( model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), - credentials=provider_model_credentials + credentials=provider_model_credentials, ) ) - return CustomConfiguration( - provider=custom_provider_configuration, - models=custom_model_configurations - ) + return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) - def _to_system_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider]) -> SystemConfiguration: + def _to_system_configuration( + self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] + ) -> SystemConfiguration: """ Convert to system configuration. @@ -685,11 +671,11 @@ class ProviderManager: # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if provider_entity.provider not in hosting_configuration.provider_map \ - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled: - return SystemConfiguration( - enabled=False - ) + if ( + provider_entity.provider not in hosting_configuration.provider_map + or not hosting_configuration.provider_map.get(provider_entity.provider).enabled + ): + return SystemConfiguration(enabled=False) provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) @@ -699,8 +685,9 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: @@ -712,7 +699,7 @@ class ProviderManager: quota_used=0, quota_limit=0, is_valid=False, - restrict_models=provider_quota.restrict_models + restrict_models=provider_quota.restrict_models, ) else: continue @@ -724,16 +711,15 @@ class ProviderManager: quota_unit=provider_hosting_configuration.quota_unit, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, ) quota_configurations.append(quota_configuration) if len(quota_configurations) == 0: - return SystemConfiguration( - enabled=False - ) + return SystemConfiguration(enabled=False) current_quota_type = self._choice_current_using_quota_type(quota_configurations) @@ -745,7 +731,7 @@ class ProviderManager: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -760,7 +746,8 @@ class ProviderManager: # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get decoding rsa key and cipher for decrypting credentials @@ -771,9 +758,7 @@ class ProviderManager: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass @@ -781,9 +766,7 @@ class ProviderManager: current_using_credentials = provider_credentials # cache provider credentials - provider_credentials_cache.set( - credentials=current_using_credentials - ) + provider_credentials_cache.set(credentials=current_using_credentials) else: current_using_credentials = cached_provider_credentials else: @@ -794,7 +777,7 @@ class ProviderManager: enabled=True, current_quota_type=current_quota_type, quota_configurations=quota_configurations, - credentials=current_using_credentials + credentials=current_using_credentials, ) @staticmethod @@ -809,8 +792,7 @@ class ProviderManager: """ # convert to dict quota_type_to_quota_configuration_dict = { - quota_configuration.quota_type: quota_configuration - for quota_configuration in quota_configurations + quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations } last_quota_configuration = None @@ -823,7 +805,7 @@ class ProviderManager: if last_quota_configuration: return last_quota_configuration.quota_type - raise ValueError('No quota type available') + raise ValueError("No quota type available") @staticmethod def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: @@ -840,10 +822,12 @@ class ProviderManager: return secret_input_form_variables - def _to_model_settings(self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \ - -> list[ModelSettings]: + def _to_model_settings( + self, + provider_entity: ProviderEntity, + provider_model_settings: Optional[list[ProviderModelSetting]] = None, + load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + ) -> list[ModelSettings]: """ Convert to model settings. :param provider_entity: provider entity @@ -854,7 +838,8 @@ class ProviderManager: # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) model_settings = [] @@ -865,24 +850,28 @@ class ProviderManager: load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: for load_balancing_model_config in load_balancing_model_configs: - if (load_balancing_model_config.model_name == provider_model_setting.model_name - and load_balancing_model_config.model_type == provider_model_setting.model_type): + if ( + load_balancing_model_config.model_name == provider_model_setting.model_name + and load_balancing_model_config.model_type == provider_model_setting.model_type + ): if not load_balancing_model_config.enabled: continue if not load_balancing_model_config.encrypted_config: if load_balancing_model_config.name == "__inherit__": - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials={} - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + ) + ) continue provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=load_balancing_model_config.tenant_id, identity_id=load_balancing_model_config.id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) # Get cached provider model credentials @@ -897,7 +886,8 @@ class ProviderManager: # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding( - load_balancing_model_config.tenant_id) + load_balancing_model_config.tenant_id + ) for variable in model_credential_secret_variables: if variable in provider_model_credentials: @@ -905,30 +895,30 @@ class ProviderManager: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials=provider_model_credentials - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials=provider_model_credentials, + ) + ) model_settings.append( ModelSettings( model=provider_model_setting.model_name, model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, - load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [] + load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index eaad0e0f4c..3c6ab2e4cf 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -2,37 +2,35 @@ import re class CleanProcessor: - @classmethod def clean(cls, text: str, process_rule: dict) -> str: # default clean # remove invalid symbol - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) - rules = process_rule['rules'] if process_rule else None - if 'pre_processing_rules' in rules: + rules = process_rule["rules"] if process_rule else None + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text def filter_string(self, text): - return text diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py index 523bd904f2..d3bc2f765e 100644 --- a/api/core/rag/cleaner/cleaner_base.py +++ b/api/core/rag/cleaner/cleaner_base.py @@ -1,12 +1,11 @@ """Abstract interface for document cleaner implementations.""" + from abc import ABC, abstractmethod class BaseCleaner(ABC): - """Interface for clean chunk content. - """ + """Interface for clean chunk content.""" @abstractmethod def clean(self, content: str): raise NotImplementedError - diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py index 6a0b8c9046..167a919e69 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_extra_whitespace diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py index 6fc3a408da..9c682d29db 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" import re diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py index 87dc2d49fa..0cdbb171e1 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_non_ascii_chars diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py index 974a28fef1..9f42044a2d 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py @@ -1,11 +1,12 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """Replaces unicode quote characters, such as the \x91 character in a string.""" from unstructured.cleaners.core import replace_unicode_quotes + return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py index dfaf3a2787..32ae7217e8 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.translate import translate_text diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index ad9ee4f7cf..b1d6f93cff 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -12,17 +12,27 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner class DataPostProcessor: - """Interface for data post-processing document. - """ + """Interface for data post-processing document.""" - def __init__(self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - reorder_enabled: bool = False): + def __init__( + self, + tenant_id: str, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reorder_enabled: bool = False, + ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) self.reorder_runner = self._get_reorder_runner(reorder_enabled) - def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def invoke( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -31,21 +41,26 @@ class DataPostProcessor: return documents - def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]: + def _get_rerank_runner( + self, + reranking_mode: str, + tenant_id: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + ) -> Optional[RerankModelRunner | WeightRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: return WeightRerankRunner( tenant_id, Weights( vector_setting=VectorSetting( - vector_weight=weights['vector_setting']['vector_weight'], - embedding_provider_name=weights['vector_setting']['embedding_provider_name'], - embedding_model_name=weights['vector_setting']['embedding_model_name'], + vector_weight=weights["vector_setting"]["vector_weight"], + embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], + embedding_model_name=weights["vector_setting"]["embedding_model_name"], ), keyword_setting=KeywordSetting( - keyword_weight=weights['keyword_setting']['keyword_weight'], - ) - ) + keyword_weight=weights["keyword_setting"]["keyword_weight"], + ), + ), ) elif reranking_mode == RerankMode.RERANKING_MODEL.value: if reranking_model: @@ -53,9 +68,9 @@ class DataPostProcessor: model_manager = ModelManager() rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - provider=reranking_model['reranking_provider_name'], + provider=reranking_model["reranking_provider_name"], model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] + model=reranking_model["reranking_model_name"], ) except InvokeAuthorizationError: return None @@ -67,5 +82,3 @@ class DataPostProcessor: if reorder_enabled: return ReorderRunner() return None - - diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py index 71297588a4..a9a0885241 100644 --- a/api/core/rag/data_post_processor/reorder.py +++ b/api/core/rag/data_post_processor/reorder.py @@ -2,7 +2,6 @@ from core.rag.models.document import Document class ReorderRunner: - def run(self, documents: list[Document]) -> list[Document]: # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list odd_elements = documents[::2] diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a3714c2fd3..3073100746 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -24,37 +24,42 @@ class Jieba(BaseKeyword): self._config = KeywordTableConfig() def create(self, texts: list[Document], **kwargs) -> BaseKeyword: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) return self def add_texts(self, texts: list[Document], **kwargs): - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - keywords_list = kwargs.get('keywords_list', None) + keywords_list = kwargs.get("keywords_list", None) for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: - keywords = keyword_table_handler.extract_keywords(text.page_content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) else: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) @@ -63,97 +68,91 @@ class Jieba(BaseKeyword): return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: keyword_table = self._get_dataset_keyword_table() - k = kwargs.get('top_k', 4) + k = kwargs.get("top_k", 4) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) documents = [] for chunk_index in sorted_chunk_indices: - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self.dataset.id, - DocumentSegment.index_node_id == chunk_index - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) + .first() + ) if segment: - - documents.append(Document( - page_content=segment.content, - metadata={ - "doc_id": chunk_index, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - )) + documents.append( + Document( + page_content=segment.content, + metadata={ + "doc_id": chunk_index, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + ) return documents def delete(self) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: db.session.delete(dataset_keyword_table) db.session.commit() - if dataset_keyword_table.data_source_type != 'database': - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + if dataset_keyword_table.data_source_type != "database": + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" storage.delete(file_key) def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": keyword_table - } + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type - if keyword_data_source_type == 'database': + if keyword_data_source_type == "database": dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) db.session.commit() else: - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8')) + storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict['__data__']['table'] + return keyword_table_dict["__data__"]["table"] else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( dataset_id=self.dataset.id, - keyword_table='', + keyword_table="", data_source_type=keyword_data_source_type, ) - if keyword_data_source_type == 'database': - dataset_keyword_table.keyword_table = json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) + if keyword_data_source_type == "database": + dataset_keyword_table.keyword_table = json.dumps( + { + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, + }, + cls=SetEncoder, + ) db.session.add(dataset_keyword_table) db.session.commit() @@ -174,9 +173,7 @@ class Jieba(BaseKeyword): keywords_to_delete = set() for keyword, node_idxs in keyword_table.items(): if node_idxs_to_delete.intersection(node_idxs): - keyword_table[keyword] = node_idxs.difference( - node_idxs_to_delete - ) + keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete) if not keyword_table[keyword]: keywords_to_delete.add(keyword) @@ -202,13 +199,14 @@ class Jieba(BaseKeyword): reverse=True, ) - return sorted_chunk_indices[: k] + return sorted_chunk_indices[:k] def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.index_node_id == node_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) + .first() + ) if document_segment: document_segment.keywords = keywords db.session.add(document_segment) @@ -224,14 +222,14 @@ class Jieba(BaseKeyword): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for pre_segment_data in pre_segment_data_list: - segment = pre_segment_data['segment'] - if pre_segment_data['keywords']: - segment.keywords = pre_segment_data['keywords'] - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, - pre_segment_data['keywords']) + segment = pre_segment_data["segment"] + if pre_segment_data["keywords"]: + segment.keywords = pre_segment_data["keywords"] + keyword_table = self._add_text_to_keyword_table( + keyword_table, segment.index_node_id, pre_segment_data["keywords"] + ) else: - keywords = keyword_table_handler.extract_keywords(segment.content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ad669ef515..4b1ade8e3f 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS class JiebaKeywordTableHandler: - def __init__(self): default_tfidf.stop_words = STOPWORDS @@ -30,4 +29,4 @@ class JiebaKeywordTableHandler: if len(sub_tokens) > 1: results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) - return results \ No newline at end of file + return results diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py index c616a15cf0..9abe78d6ef 100644 --- a/api/core/rag/datasource/keyword/jieba/stopwords.py +++ b/api/core/rag/datasource/keyword/jieba/stopwords.py @@ -1,90 +1,1380 @@ STOPWORDS = { - "during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've", - "ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her", - "an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t", - "theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven", - "for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now", - "their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which", - "m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't", - "such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves", - "been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because", - "down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't", - "as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after", - "over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there", - "himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below", - "人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本", - "本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但", - "不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着", - "趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当", - "当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", "而且", "而是", - "而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各", - "各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈", - "哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之", - "或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既", - "既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着", - "进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可", - "可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外", - "另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪", - "哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么", - "那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦", - "呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但", - "恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何", - "如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候", - "什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着", - "所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同", - "同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎", - "无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然", - "要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "依", "依照", - "矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由", - "由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是", - "云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着", - "者", "这", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样", - "正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从", - "自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令", - "纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":", - "\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_", - "+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…", - " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二", - "三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[", - "]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时", - "按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加", - "本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没", - "并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭", - "不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了", - "不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要", - "不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止", - "不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去", - "长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心", - "乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去", - "除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹", - "此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速", - "从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从", - "打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事", - "大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场", - "当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿", - "到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独", - "独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等", - "二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得", - "分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢", - "敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为", - "公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底", - "归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须", - "何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然", - "活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆", - "即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之", - "简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此", - "借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量", - "尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外", - "举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非", - "均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲", - "来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻", - "立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次", - "屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢", - "每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪", - "难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通", - "其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧", - "恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自", - "顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然", - "人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述", - "如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n" + "during", + "when", + "but", + "then", + "further", + "isn", + "mustn't", + "until", + "own", + "i", + "couldn", + "y", + "only", + "you've", + "ours", + "who", + "where", + "ourselves", + "has", + "to", + "was", + "didn't", + "themselves", + "if", + "against", + "through", + "her", + "an", + "your", + "can", + "those", + "didn", + "about", + "aren't", + "shan't", + "be", + "not", + "these", + "again", + "so", + "t", + "theirs", + "weren", + "won't", + "won", + "itself", + "just", + "same", + "while", + "why", + "doesn", + "aren", + "him", + "haven", + "for", + "you'll", + "that", + "we", + "am", + "d", + "by", + "having", + "wasn't", + "than", + "weren't", + "out", + "from", + "now", + "their", + "too", + "hadn", + "o", + "needn", + "most", + "it", + "under", + "needn't", + "any", + "some", + "few", + "ll", + "hers", + "which", + "m", + "you're", + "off", + "other", + "had", + "she", + "you'd", + "do", + "you", + "does", + "s", + "will", + "each", + "wouldn't", + "hasn't", + "such", + "more", + "whom", + "she's", + "my", + "yours", + "yourself", + "of", + "on", + "very", + "hadn't", + "with", + "yourselves", + "been", + "ma", + "them", + "mightn't", + "shan", + "mustn", + "they", + "what", + "both", + "that'll", + "how", + "is", + "he", + "because", + "down", + "haven't", + "are", + "no", + "it's", + "our", + "being", + "the", + "or", + "above", + "myself", + "once", + "don't", + "doesn't", + "as", + "nor", + "here", + "herself", + "hasn", + "mightn", + "have", + "its", + "all", + "were", + "ain", + "this", + "at", + "after", + "over", + "shouldn't", + "into", + "before", + "don", + "wouldn", + "re", + "couldn't", + "wasn", + "in", + "should", + "there", + "himself", + "isn't", + "should've", + "doing", + "ve", + "shouldn", + "a", + "did", + "and", + "his", + "between", + "me", + "up", + "below", + "人民", + "末##末", + "啊", + "阿", + "哎", + "哎呀", + "哎哟", + "唉", + "俺", + "俺们", + "按", + "按照", + "吧", + "吧哒", + "把", + "罢了", + "被", + "本", + "本着", + "比", + "比方", + "比如", + "鄙人", + "彼", + "彼此", + "边", + "别", + "别的", + "别说", + "并", + "并且", + "不比", + "不成", + "不单", + "不但", + "不独", + "不管", + "不光", + "不过", + "不仅", + "不拘", + "不论", + "不怕", + "不然", + "不如", + "不特", + "不惟", + "不问", + "不只", + "朝", + "朝着", + "趁", + "趁着", + "乘", + "冲", + "除", + "除此之外", + "除非", + "除了", + "此", + "此间", + "此外", + "从", + "从而", + "打", + "待", + "但", + "但是", + "当", + "当着", + "到", + "得", + "的", + "的话", + "等", + "等等", + "地", + "第", + "叮咚", + "对", + "对于", + "多", + "多少", + "而", + "而况", + "而且", + "而是", + "而外", + "而言", + "而已", + "尔后", + "反过来", + "反过来说", + "反之", + "非但", + "非徒", + "否则", + "嘎", + "嘎登", + "该", + "赶", + "个", + "各", + "各个", + "各位", + "各种", + "各自", + "给", + "根据", + "跟", + "故", + "故此", + "固然", + "关于", + "管", + "归", + "果然", + "果真", + "过", + "哈", + "哈哈", + "呵", + "和", + "何", + "何处", + "何况", + "何时", + "嘿", + "哼", + "哼唷", + "呼哧", + "乎", + "哗", + "还是", + "还有", + "换句话说", + "换言之", + "或", + "或是", + "或者", + "极了", + "及", + "及其", + "及至", + "即", + "即便", + "即或", + "即令", + "即若", + "即使", + "几", + "几时", + "己", + "既", + "既然", + "既是", + "继而", + "加之", + "假如", + "假若", + "假使", + "鉴于", + "将", + "较", + "较之", + "叫", + "接着", + "结果", + "借", + "紧接着", + "进而", + "尽", + "尽管", + "经", + "经过", + "就", + "就是", + "就是说", + "据", + "具体地说", + "具体说来", + "开始", + "开外", + "靠", + "咳", + "可", + "可见", + "可是", + "可以", + "况且", + "啦", + "来", + "来着", + "离", + "例如", + "哩", + "连", + "连同", + "两者", + "了", + "临", + "另", + "另外", + "另一方面", + "论", + "嘛", + "吗", + "慢说", + "漫说", + "冒", + "么", + "每", + "每当", + "们", + "莫若", + "某", + "某个", + "某些", + "拿", + "哪", + "哪边", + "哪儿", + "哪个", + "哪里", + "哪年", + "哪怕", + "哪天", + "哪些", + "哪样", + "那", + "那边", + "那儿", + "那个", + "那会儿", + "那里", + "那么", + "那么些", + "那么样", + "那时", + "那些", + "那样", + "乃", + "乃至", + "呢", + "能", + "你", + "你们", + "您", + "宁", + "宁可", + "宁肯", + "宁愿", + "哦", + "呕", + "啪达", + "旁人", + "呸", + "凭", + "凭借", + "其", + "其次", + "其二", + "其他", + "其它", + "其一", + "其余", + "其中", + "起", + "起见", + "岂但", + "恰恰相反", + "前后", + "前者", + "且", + "然而", + "然后", + "然则", + "让", + "人家", + "任", + "任何", + "任凭", + "如", + "如此", + "如果", + "如何", + "如其", + "如若", + "如上所述", + "若", + "若非", + "若是", + "啥", + "上下", + "尚且", + "设若", + "设使", + "甚而", + "甚么", + "甚至", + "省得", + "时候", + "什么", + "什么样", + "使得", + "是", + "是的", + "首先", + "谁", + "谁知", + "顺", + "顺着", + "似的", + "虽", + "虽然", + "虽说", + "虽则", + "随", + "随着", + "所", + "所以", + "他", + "他们", + "他人", + "它", + "它们", + "她", + "她们", + "倘", + "倘或", + "倘然", + "倘若", + "倘使", + "腾", + "替", + "通过", + "同", + "同时", + "哇", + "万一", + "往", + "望", + "为", + "为何", + "为了", + "为什么", + "为着", + "喂", + "嗡嗡", + "我", + "我们", + "呜", + "呜呼", + "乌乎", + "无论", + "无宁", + "毋宁", + "嘻", + "吓", + "相对而言", + "像", + "向", + "向着", + "嘘", + "呀", + "焉", + "沿", + "沿着", + "要", + "要不", + "要不然", + "要不是", + "要么", + "要是", + "也", + "也罢", + "也好", + "一", + "一般", + "一旦", + "一方面", + "一来", + "一切", + "一样", + "一则", + "依", + "依照", + "矣", + "以", + "以便", + "以及", + "以免", + "以至", + "以至于", + "以致", + "抑或", + "因", + "因此", + "因而", + "因为", + "哟", + "用", + "由", + "由此可见", + "由于", + "有", + "有的", + "有关", + "有些", + "又", + "于", + "于是", + "于是乎", + "与", + "与此同时", + "与否", + "与其", + "越是", + "云云", + "哉", + "再说", + "再者", + "在", + "在下", + "咱", + "咱们", + "则", + "怎", + "怎么", + "怎么办", + "怎么样", + "怎样", + "咋", + "照", + "照着", + "者", + "这", + "这边", + "这儿", + "这个", + "这会儿", + "这就是说", + "这里", + "这么", + "这么点儿", + "这么些", + "这么样", + "这时", + "这些", + "这样", + "正如", + "吱", + "之", + "之类", + "之所以", + "之一", + "只是", + "只限", + "只要", + "只有", + "至", + "至于", + "诸位", + "着", + "着呢", + "自", + "自从", + "自个儿", + "自各儿", + "自己", + "自家", + "自身", + "综上所述", + "总的来看", + "总的来说", + "总的说来", + "总而言之", + "总之", + "纵", + "纵令", + "纵然", + "纵使", + "遵照", + "作为", + "兮", + "呃", + "呗", + "咚", + "咦", + "喏", + "啐", + "喔唷", + "嗬", + "嗯", + "嗳", + "~", + "!", + ".", + ":", + '"', + "'", + "(", + ")", + "*", + "A", + "白", + "社会主义", + "--", + "..", + ">>", + " [", + " ]", + "", + "<", + ">", + "/", + "\\", + "|", + "-", + "_", + "+", + "=", + "&", + "^", + "%", + "#", + "@", + "`", + ";", + "$", + "(", + ")", + "——", + "—", + "¥", + "·", + "...", + "‘", + "’", + "〉", + "〈", + "…", + " ", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "二", + "三", + "四", + "五", + "六", + "七", + "八", + "九", + "零", + ">", + "<", + "@", + "#", + "$", + "%", + "︿", + "&", + "*", + "+", + "~", + "|", + "[", + "]", + "{", + "}", + "啊哈", + "啊呀", + "啊哟", + "挨次", + "挨个", + "挨家挨户", + "挨门挨户", + "挨门逐户", + "挨着", + "按理", + "按期", + "按时", + "按说", + "暗地里", + "暗中", + "暗自", + "昂然", + "八成", + "白白", + "半", + "梆", + "保管", + "保险", + "饱", + "背地里", + "背靠背", + "倍感", + "倍加", + "本人", + "本身", + "甭", + "比起", + "比如说", + "比照", + "毕竟", + "必", + "必定", + "必将", + "必须", + "便", + "别人", + "并非", + "并肩", + "并没", + "并没有", + "并排", + "并无", + "勃然", + "不", + "不必", + "不常", + "不大", + "不但...而且", + "不得", + "不得不", + "不得了", + "不得已", + "不迭", + "不定", + "不对", + "不妨", + "不管怎样", + "不会", + "不仅...而且", + "不仅仅", + "不仅仅是", + "不经意", + "不可开交", + "不可抗拒", + "不力", + "不了", + "不料", + "不满", + "不免", + "不能不", + "不起", + "不巧", + "不然的话", + "不日", + "不少", + "不胜", + "不时", + "不是", + "不同", + "不能", + "不要", + "不外", + "不外乎", + "不下", + "不限", + "不消", + "不已", + "不亦乐乎", + "不由得", + "不再", + "不择手段", + "不怎么", + "不曾", + "不知不觉", + "不止", + "不止一次", + "不至于", + "才", + "才能", + "策略地", + "差不多", + "差一点", + "常", + "常常", + "常言道", + "常言说", + "常言说得好", + "长此下去", + "长话短说", + "长期以来", + "长线", + "敞开儿", + "彻夜", + "陈年", + "趁便", + "趁机", + "趁热", + "趁势", + "趁早", + "成年", + "成年累月", + "成心", + "乘机", + "乘胜", + "乘势", + "乘隙", + "乘虚", + "诚然", + "迟早", + "充分", + "充其极", + "充其量", + "抽冷子", + "臭", + "初", + "出", + "出来", + "出去", + "除此", + "除此而外", + "除此以外", + "除开", + "除去", + "除却", + "除外", + "处处", + "川流不息", + "传", + "传说", + "传闻", + "串行", + "纯", + "纯粹", + "此后", + "此中", + "次第", + "匆匆", + "从不", + "从此", + "从此以后", + "从古到今", + "从古至今", + "从今以后", + "从宽", + "从来", + "从轻", + "从速", + "从头", + "从未", + "从无到有", + "从小", + "从新", + "从严", + "从优", + "从早到晚", + "从中", + "从重", + "凑巧", + "粗", + "存心", + "达旦", + "打从", + "打开天窗说亮话", + "大", + "大不了", + "大大", + "大抵", + "大都", + "大多", + "大凡", + "大概", + "大家", + "大举", + "大略", + "大面儿上", + "大事", + "大体", + "大体上", + "大约", + "大张旗鼓", + "大致", + "呆呆地", + "带", + "殆", + "待到", + "单", + "单纯", + "单单", + "但愿", + "弹指之间", + "当场", + "当儿", + "当即", + "当口儿", + "当然", + "当庭", + "当头", + "当下", + "当真", + "当中", + "倒不如", + "倒不如说", + "倒是", + "到处", + "到底", + "到了儿", + "到目前为止", + "到头", + "到头来", + "得起", + "得天独厚", + "的确", + "等到", + "叮当", + "顶多", + "定", + "动不动", + "动辄", + "陡然", + "都", + "独", + "独自", + "断然", + "顿时", + "多次", + "多多", + "多多少少", + "多多益善", + "多亏", + "多年来", + "多年前", + "而后", + "而论", + "而又", + "尔等", + "二话不说", + "二话没说", + "反倒", + "反倒是", + "反而", + "反手", + "反之亦然", + "反之则", + "方", + "方才", + "方能", + "放量", + "非常", + "非得", + "分期", + "分期分批", + "分头", + "奋勇", + "愤然", + "风雨无阻", + "逢", + "弗", + "甫", + "嘎嘎", + "该当", + "概", + "赶快", + "赶早不赶晚", + "敢", + "敢情", + "敢于", + "刚", + "刚才", + "刚好", + "刚巧", + "高低", + "格外", + "隔日", + "隔夜", + "个人", + "各式", + "更", + "更加", + "更进一步", + "更为", + "公然", + "共", + "共总", + "够瞧的", + "姑且", + "古来", + "故而", + "故意", + "固", + "怪", + "怪不得", + "惯常", + "光", + "光是", + "归根到底", + "归根结底", + "过于", + "毫不", + "毫无", + "毫无保留地", + "毫无例外", + "好在", + "何必", + "何尝", + "何妨", + "何苦", + "何乐而不为", + "何须", + "何止", + "很", + "很多", + "很少", + "轰然", + "后来", + "呼啦", + "忽地", + "忽然", + "互", + "互相", + "哗啦", + "话说", + "还", + "恍然", + "会", + "豁然", + "活", + "伙同", + "或多或少", + "或许", + "基本", + "基本上", + "基于", + "极", + "极大", + "极度", + "极端", + "极力", + "极其", + "极为", + "急匆匆", + "即将", + "即刻", + "即是说", + "几度", + "几番", + "几乎", + "几经", + "既...又", + "继之", + "加上", + "加以", + "间或", + "简而言之", + "简言之", + "简直", + "见", + "将才", + "将近", + "将要", + "交口", + "较比", + "较为", + "接连不断", + "接下来", + "皆可", + "截然", + "截至", + "藉以", + "借此", + "借以", + "届时", + "仅", + "仅仅", + "谨", + "进来", + "进去", + "近", + "近几年来", + "近来", + "近年来", + "尽管如此", + "尽可能", + "尽快", + "尽量", + "尽然", + "尽如人意", + "尽心竭力", + "尽心尽力", + "尽早", + "精光", + "经常", + "竟", + "竟然", + "究竟", + "就此", + "就地", + "就算", + "居然", + "局外", + "举凡", + "据称", + "据此", + "据实", + "据说", + "据我所知", + "据悉", + "具体来说", + "决不", + "决非", + "绝", + "绝不", + "绝顶", + "绝对", + "绝非", + "均", + "喀", + "看", + "看来", + "看起来", + "看上去", + "看样子", + "可好", + "可能", + "恐怕", + "快", + "快要", + "来不及", + "来得及", + "来讲", + "来看", + "拦腰", + "牢牢", + "老", + "老大", + "老老实实", + "老是", + "累次", + "累年", + "理当", + "理该", + "理应", + "历", + "立", + "立地", + "立刻", + "立马", + "立时", + "联袂", + "连连", + "连日", + "连日来", + "连声", + "连袂", + "临到", + "另方面", + "另行", + "另一个", + "路经", + "屡", + "屡次", + "屡次三番", + "屡屡", + "缕缕", + "率尔", + "率然", + "略", + "略加", + "略微", + "略为", + "论说", + "马上", + "蛮", + "满", + "没", + "没有", + "每逢", + "每每", + "每时每刻", + "猛然", + "猛然间", + "莫", + "莫不", + "莫非", + "莫如", + "默默地", + "默然", + "呐", + "那末", + "奈", + "难道", + "难得", + "难怪", + "难说", + "内", + "年复一年", + "凝神", + "偶而", + "偶尔", + "怕", + "砰", + "碰巧", + "譬如", + "偏偏", + "乒", + "平素", + "颇", + "迫于", + "扑通", + "其后", + "其实", + "奇", + "齐", + "起初", + "起来", + "起首", + "起头", + "起先", + "岂", + "岂非", + "岂止", + "迄", + "恰逢", + "恰好", + "恰恰", + "恰巧", + "恰如", + "恰似", + "千", + "千万", + "千万千万", + "切", + "切不可", + "切莫", + "切切", + "切勿", + "窃", + "亲口", + "亲身", + "亲手", + "亲眼", + "亲自", + "顷", + "顷刻", + "顷刻间", + "顷刻之间", + "请勿", + "穷年累月", + "取道", + "去", + "权时", + "全都", + "全力", + "全年", + "全然", + "全身心", + "然", + "人人", + "仍", + "仍旧", + "仍然", + "日复一日", + "日见", + "日渐", + "日益", + "日臻", + "如常", + "如此等等", + "如次", + "如今", + "如期", + "如前所述", + "如上", + "如下", + "汝", + "三番两次", + "三番五次", + "三天两头", + "瑟瑟", + "沙沙", + "上", + "上来", + "上去", + "一个", + "月", + "日", + "\n", } diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index b77c6562b2..4b9ec460e6 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -8,7 +8,6 @@ from models.dataset import Dataset class BaseKeyword(ABC): - def __init__(self, dataset: Dataset): self.dataset = dataset @@ -31,15 +30,12 @@ class BaseKeyword(ABC): def delete(self) -> None: raise NotImplementedError - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: - for text in texts[:]: - doc_id = text.metadata['doc_id'] + for text in texts.copy(): + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -47,4 +43,4 @@ class BaseKeyword(ABC): return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 6ac610f82b..3c99f33be6 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -20,9 +20,7 @@ class Keyword: raise ValueError("Keyword store must be specified.") if keyword_type == "jieba": - return Jieba( - dataset=self._dataset - ) + return Jieba(dataset=self._dataset) else: raise ValueError(f"Keyword store {keyword_type} is not supported.") @@ -41,10 +39,7 @@ class Keyword: def delete(self) -> None: self._keyword_processor.delete() - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: return self._keyword_processor.search(query, **kwargs) def __getattr__(self, name): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 0dac9bfae6..afac1bf300 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -12,73 +12,83 @@ from extensions.ext_database import db from models.dataset import Dataset default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } class RetrievalService: - @classmethod - def retrieve(cls, retrieval_method: str, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + def retrieve( + cls, + retrieval_method: str, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float] = 0.0, + reranking_model: Optional[dict] = None, + reranking_mode: Optional[str] = "reranking_model", + weights: Optional[dict] = None, + ): + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword - if retrieval_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if retrieval_method == "keyword_search": + keyword_thread = threading.Thread( + target=RetrievalService.keyword_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(keyword_thread) keyword_thread.start() # retrieval_model source with semantic if RetrievalMethod.is_support_semantic_search(retrieval_method): - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrieval_method': retrieval_method, - 'exceptions': exceptions, - }) + embedding_thread = threading.Thread( + target=RetrievalService.embedding_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "score_threshold": score_threshold, + "reranking_model": reranking_model, + "all_documents": all_documents, + "retrieval_method": retrieval_method, + "exceptions": exceptions, + }, + ) threads.append(embedding_thread) embedding_thread.start() # retrieval source with full text if RetrievalMethod.is_support_fulltext_search(retrieval_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrieval_method': retrieval_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + full_text_index_thread = threading.Thread( + target=RetrievalService.full_text_index_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "retrieval_method": retrieval_method, + "score_threshold": score_threshold, + "top_k": top_k, + "reranking_model": reranking_model, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -86,110 +96,117 @@ class RetrievalService: thread.join() if exceptions: - exception_message = ';\n'.join(exceptions) + exception_message = ";\n".join(exceptions) raise Exception(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) return all_documents @classmethod - def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, all_documents: list, exceptions: list): + def keyword_search( + cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - keyword = Keyword( - dataset=dataset - ) + keyword = Keyword(dataset=dataset) - documents = keyword.search( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrieval_method: str, exceptions: list): + def embedding_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) documents = vector.search_by_vector( cls.escape_query_for_search(query), - search_type='similarity_score_threshold', + search_type="similarity_score_threshold", top_k=top_k, score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + filter={"group_id": [dataset.id]}, ) if documents: - if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrieval_method: str, exceptions: list): + def full_text_index_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() vector_processor = Vector( dataset=dataset, ) - documents = vector_processor.search_by_full_text( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) if documents: - if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: @@ -197,4 +214,4 @@ class RetrievalService: @staticmethod def escape_query_for_search(query: str) -> str: - return query.replace('"', '\\"') \ No newline at end of file + return query.replace('"', '\\"') diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index b78e2a59b1..612542dab1 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel): namespace_password: str = (None,) metrics: str = ("cosine",) read_timeout: int = 60000 + def to_analyticdb_client_params(self): return { "access_key_id": self.access_key_id, @@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel): "read_timeout": self.read_timeout, } + class AnalyticdbVector(BaseVector): _instance = None _init = False @@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector): except: raise ImportError(_import_err_msg) self.config = config - self._client_config = open_api_models.Config( - user_agent="dify", **config.to_analyticdb_client_params() - ) + self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) self._client = Client(self._client_config) self._initialize() AnalyticdbVector._init = True @@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector): def _initialize_vector_database(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector): def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + try: request = gpdb_20160503_models.DescribeNamespaceRequest( dbinstance_id=self.config.instance_id, @@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector): ) self._client.create_namespace(request) else: - raise ValueError( - f"failed to create namespace {self.config.namespace}: {e}" - ) + raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") def _create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): @@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector): ) self._client.create_collection(request) else: - raise ValueError( - f"failed to create collection {self._collection_name}: {e}" - ) + raise ValueError(f"failed to create collection {self._collection_name}: {e}") redis_client.set(collection_exist_cache_key, 1, ex=3600) def get_type(self) -> str: @@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector): self._create_collection_if_not_exists(dimension) self.add_texts(texts, embeddings) - def add_texts( - self, documents: list[Document], embeddings: list[list[float]], **kwargs - ): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): metadata = { @@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector): def text_exists(self, id: str) -> bool: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector): vector=None, content=None, top_k=1, - filter=f"ref_doc_id='{id}'" + filter=f"ref_doc_id='{id}'", ) response = self._client.query_collection_data(request) return len(response.body.matches.match) > 0 def delete_by_ids(self, ids: list[str]) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + ids_str = ",".join(f"'{id}'" for id in ids) ids_str = f"({ids_str})" request = gpdb_20160503_models.DeleteCollectionDataRequest( @@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector): ) self._client.delete_collection_data(request) - def search_by_vector( - self, query_vector: list[float], **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold") or 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = float(kwargs.get("score_threshold") or 0.0) request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector): def delete(self) -> None: try: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionRequest( collection=self._collection_name, dbinstance_id=self.config.instance_id, @@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector): except Exception as e: raise e + class AnalyticdbVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict["vector_store"][ - "class_prefix" - ] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) - ) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) # handle optional params if dify_config.ANALYTICDB_KEY_ID is None: diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 3629887b44..610aa498ab 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -27,21 +27,20 @@ class ChromaConfig(BaseModel): settings = Settings( # auth chroma_client_auth_provider=self.auth_provider, - chroma_client_auth_credentials=self.auth_credentials + chroma_client_auth_credentials=self.auth_credentials, ) return { - 'host': self.host, - 'port': self.port, - 'ssl': False, - 'tenant': self.tenant, - 'database': self.database, - 'settings': settings, + "host": self.host, + "port": self.port, + "ssl": False, + "tenant": self.tenant, + "database": self.database, + "settings": settings, } class ChromaVector(BaseVector): - def __init__(self, collection_name: str, config: ChromaConfig): super().__init__(collection_name) self._client_config = config @@ -58,9 +57,9 @@ class ChromaVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return self._client.get_or_create_collection(collection_name) @@ -76,7 +75,7 @@ class ChromaVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {'$eq': value}}) + collection.delete(where={key: {"$eq": value}}) def delete(self): self._client.delete_collection(self._collection_name) @@ -93,26 +92,26 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) - ids: list[str] = results['ids'][0] - documents: list[str] = results['documents'][0] - metadatas: dict[str, Any] = results['metadatas'][0] - distances: list[float] = results['distances'][0] + ids: list[str] = results["ids"][0] + documents: list[str] = results["documents"][0] + metadatas: dict[str, Any] = results["metadatas"][0] + distances: list[float] = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] metadata = metadatas[index] if distance >= score_threshold: - metadata['score'] = distance + metadata["score"] = distance doc = Document( page_content=documents[index], metadata=metadata, ) docs.append(doc) - # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -123,15 +122,12 @@ class ChromaVector(BaseVector): class ChromaVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - index_struct_dict = { - "type": VectorType.CHROMA, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) return ChromaVector( diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 233539756f..8d57855120 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -26,15 +26,16 @@ class ElasticSearchConfig(BaseModel): username: str password: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PORT is required") - if not values['username']: + if not values["username"]: raise ValueError("config USERNAME is required") - if not values['password']: + if not values["password"]: raise ValueError("config PASSWORD is required") return values @@ -50,10 +51,10 @@ class ElasticSearchVector(BaseVector): def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: parsed_url = urlparse(config.host) - if parsed_url.scheme in ['http', 'https']: - hosts = f'{config.host}:{config.port}' + if parsed_url.scheme in {"http", "https"}: + hosts = f"{config.host}:{config.port}" else: - hosts = f'http://{config.host}:{config.port}' + hosts = f"http://{config.host}:{config.port}" client = Elasticsearch( hosts=hosts, basic_auth=(config.username, config.password), @@ -68,45 +69,41 @@ class ElasticSearchVector(BaseVector): def _get_version(self) -> str: info = self._client.info() - return info['version']['number'] + return info["version"]["number"] def _check_version(self): - if self._version < '8.0.0': + if self._version < "8.0.0": raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: - return 'elasticsearch' + return "elasticsearch" def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) for i in range(len(documents)): - self._client.index(index=self._collection_name, - id=uuids[i], - document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] if embeddings[i] else None, - Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {} - }) + self._client.index( + index=self._collection_name, + id=uuids[i], + document={ + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i] or None, + Field.METADATA_KEY.value: documents[i].metadata or {}, + }, + ) self._client.indices.refresh(index=self._collection_name) return uuids def text_exists(self, id: str) -> bool: - return self._client.exists(index=self._collection_name, id=id).__bool__() + return bool(self._client.exists(index=self._collection_name, id=id)) def delete_by_ids(self, ids: list[str]) -> None: for id in ids: self._client.delete(index=self._collection_name, id=id) def delete_by_metadata_field(self, key: str, value: str) -> None: - query_str = { - 'query': { - 'match': { - f'metadata.{key}': f'{value}' - } - } - } + query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit['_id'] for hit in results['hits']['hits']] + ids = [hit["_id"] for hit in results["hits"]["hits"]] if ids: self.delete_by_ids(ids) @@ -115,44 +112,44 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 10) - knn = { - "field": Field.VECTOR.value, - "query_vector": query_vector, - "k": top_k - } + knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k} results = self._client.search(index=self._collection_name, knn=knn, size=top_k) docs_and_scores = [] - for hit in results['hits']['hits']: + for hit in results["hits"]["hits"]: docs_and_scores.append( - (Document(page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score'])) + ( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ), + hit["_score"], + ) + ) docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = { - "match": { - Field.CONTENT_KEY.value: query - } - } + query_str = {"match": {Field.CONTENT_KEY.value: query}} results = self._client.search(index=self._collection_name, query=query_str) docs = [] - for hit in results['hits']['hits']: - docs.append(Document( - page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value], - )) + for hit in results["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ) + ) return docs @@ -162,11 +159,11 @@ class ElasticSearchVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name}' + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name} already exists.") return @@ -179,14 +176,14 @@ class ElasticSearchVector(BaseVector): Field.VECTOR.value: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, - "similarity": "cosine" + "similarity": "cosine", }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } } self._client.indices.create(index=self._collection_name, mappings=mappings) @@ -197,22 +194,21 @@ class ElasticSearchVector(BaseVector): class ElasticSearchVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) config = current_app.config return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - host=config.get('ELASTICSEARCH_HOST'), - port=config.get('ELASTICSEARCH_PORT'), - username=config.get('ELASTICSEARCH_USERNAME'), - password=config.get('ELASTICSEARCH_PASSWORD'), + host=config.get("ELASTICSEARCH_HOST"), + port=config.get("ELASTICSEARCH_PORT"), + username=config.get("ELASTICSEARCH_USERNAME"), + password=config.get("ELASTICSEARCH_PASSWORD"), ), - attributes=[] + attributes=[], ) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index c1c73d1c0d..bdca59f869 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -27,44 +27,40 @@ class MilvusConfig(BaseModel): batch_size: int = 100 database: str = "default" - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values.get('uri'): + if not values.get("uri"): raise ValueError("config MILVUS_URI is required") - if not values.get('user'): + if not values.get("user"): raise ValueError("config MILVUS_USER is required") - if not values.get('password'): + if not values.get("password"): raise ValueError("config MILVUS_PASSWORD is required") return values def to_milvus_params(self): return { - 'uri': self.uri, - 'token': self.token, - 'user': self.user, - 'password': self.password, - 'db_name': self.database, + "uri": self.uri, + "token": self.token, + "user": self.user, + "password": self.password, + "db_name": self.database, } class MilvusVector(BaseVector): - def __init__(self, collection_name: str, config: MilvusConfig): super().__init__(collection_name) self._client_config = config self._client = self._init_client(config) - self._consistency_level = 'Session' + self._consistency_level = "Session" self._fields = [] def get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = { - 'metric_type': 'IP', - 'index_type': "HNSW", - 'params': {"M": 8, "efConstruction": 64} - } + index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} metadatas = [d.metadata for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) @@ -75,7 +71,7 @@ class MilvusVector(BaseVector): insert_dict = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata + Field.METADATA_KEY.value: documents[i].metadata, } insert_dict_list.append(insert_dict) # Total insert count @@ -84,22 +80,20 @@ class MilvusVector(BaseVector): pks: list[str] = [] for i in range(0, total_count, 1000): - batch_insert_list = insert_dict_list[i:i + 1000] + batch_insert_list = insert_dict_list[i : i + 1000] # Insert into the collection. try: ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) pks.extend(ids) except MilvusException as e: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) + logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) raise e return pks def get_ids_by_metadata_field(self, key: str, value: str): - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["{key}"] == "{value}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] + ) if result: return [item["id"] for item in result] else: @@ -107,17 +101,15 @@ class MilvusVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str): if self._client.has_collection(self._collection_name): - ids = self.get_ids_by_metadata_field(key, value) if ids: self._client.delete(collection_name=self._collection_name, pks=ids) def delete_by_ids(self, ids: list[str]) -> None: if self._client.has_collection(self._collection_name): - - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] in {ids}', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] + ) if result: ids = [item["id"] for item in result] self._client.delete(collection_name=self._collection_name, pks=ids) @@ -130,29 +122,28 @@ class MilvusVector(BaseVector): if not self._client.has_collection(self._collection_name): return False - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] == "{id}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"] + ) return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Set search parameters. - results = self._client.search(collection_name=self._collection_name, - data=[query_vector], - limit=kwargs.get('top_k', 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], - ) + results = self._client.search( + collection_name=self._collection_name, + data=[query_vector], + limit=kwargs.get("top_k", 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) # Organize results. docs = [] for result in results[0]: - metadata = result['entity'].get(Field.METADATA_KEY.value) - metadata['score'] = result['distance'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if result['distance'] > score_threshold: - doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), - metadata=metadata) + metadata = result["entity"].get(Field.METADATA_KEY.value) + metadata["score"] = result["distance"] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if result["distance"] > score_threshold: + doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -161,11 +152,11 @@ class MilvusVector(BaseVector): return [] def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return # Grab the existing collection if it exists @@ -180,19 +171,11 @@ class MilvusVector(BaseVector): fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) # Create the text field - fields.append( - FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) - ) + fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) # Create the primary key field - fields.append( - FieldSchema( - Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True - ) - ) + fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create the schema for the collection schema = CollectionSchema(fields) @@ -208,9 +191,12 @@ class MilvusVector(BaseVector): # Create the collection collection_name = self._collection_name - self._client.create_collection(collection_name=collection_name, - schema=schema, index_params=index_params_obj, - consistency_level=self._consistency_level) + self._client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_params_obj, + consistency_level=self._consistency_level, + ) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _init_client(self, config) -> MilvusClient: @@ -221,13 +207,12 @@ class MilvusVector(BaseVector): class MilvusVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) return MilvusVector( collection_name=collection_name, @@ -237,5 +222,5 @@ class MilvusVectorFactory(AbstractVectorFactory): user=dify_config.MILVUS_USER, password=dify_config.MILVUS_PASSWORD, database=dify_config.MILVUS_DATABASE, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 05e75effef..2320a69a30 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -31,12 +31,11 @@ class SortOrder(Enum): class MyScaleVector(BaseVector): - def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"): super().__init__(collection_name) self._config = config self._metric = metric - self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC + self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC self._client = get_client( host=config.host, port=config.port, @@ -80,7 +79,7 @@ class MyScaleVector(BaseVector): doc_id, self.escape_str(doc.page_content), embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {} + json.dumps(doc.metadata) if doc.metadata else {}, ) values.append(str(row)) ids.append(doc_id) @@ -93,7 +92,7 @@ class MyScaleVector(BaseVector): @staticmethod def escape_str(value: Any) -> str: - return "".join(" " if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in {"\\", "'"} else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") @@ -101,7 +100,8 @@ class MyScaleVector(BaseVector): def delete_by_ids(self, ids: list[str]) -> None: self._client.command( - f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}") + f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" + ) def get_ids_by_metadata_field(self, key: str, value: str): rows = self._client.query( @@ -122,9 +122,12 @@ class MyScaleVector(BaseVector): def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get('score_threshold') or 0.0 - where_str = f"WHERE dist < {1 - score_threshold}" if \ - self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" + score_threshold = float(kwargs.get("score_threshold") or 0.0) + where_str = ( + f"WHERE dist < {1 - score_threshold}" + if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 + else "" + ) sql = f""" SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} @@ -133,7 +136,7 @@ class MyScaleVector(BaseVector): return [ Document( page_content=r["text"], - vector=r['vector'], + vector=r["vector"], metadata=r["metadata"], ) for r in self._client.query(sql).named_results() @@ -149,13 +152,12 @@ class MyScaleVector(BaseVector): class MyScaleVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) return MyScaleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index c95d202173..8d2e0a86ab 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -28,11 +28,12 @@ class OpenSearchConfig(BaseModel): password: Optional[str] = None secure: bool = False - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values.get('host'): + if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") - if not values.get('port'): + if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") return values @@ -44,19 +45,18 @@ class OpenSearchConfig(BaseModel): def to_opensearch_params(self) -> dict[str, Any]: params = { - 'hosts': [{'host': self.host, 'port': self.port}], - 'use_ssl': self.secure, - 'verify_certs': self.secure, + "hosts": [{"host": self.host, "port": self.port}], + "use_ssl": self.secure, + "verify_certs": self.secure, } if self.user and self.password: - params['http_auth'] = (self.user, self.password) + params["http_auth"] = (self.user, self.password) if self.secure: - params['ssl_context'] = self.create_ssl_context() + params["ssl_context"] = self.create_ssl_context() return params class OpenSearchVector(BaseVector): - def __init__(self, collection_name: str, config: OpenSearchConfig): super().__init__(collection_name) self._client_config = config @@ -81,7 +81,7 @@ class OpenSearchVector(BaseVector): Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, - } + }, } actions.append(action) @@ -90,8 +90,8 @@ class OpenSearchVector(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) - if response['hits']['hits']: - return [hit['_id'] for hit in response['hits']['hits']] + if response["hits"]["hits"]: + return [hit["_id"] for hit in response["hits"]["hits"]] else: return None @@ -110,7 +110,7 @@ class OpenSearchVector(BaseVector): actual_ids = [] for doc_id in ids: - es_ids = self.get_ids_by_metadata_field('doc_id', doc_id) + es_ids = self.get_ids_by_metadata_field("doc_id", doc_id) if es_ids: actual_ids.extend(es_ids) else: @@ -122,9 +122,9 @@ class OpenSearchVector(BaseVector): helpers.bulk(self._client, actions) except BulkIndexError as e: for error in e.errors: - delete_error = error.get('delete', {}) - status = delete_error.get('status') - doc_id = delete_error.get('_id') + delete_error = error.get("delete", {}) + status = delete_error.get("status") + doc_id = delete_error.get("_id") if status == 404: logger.warning(f"Document not found for deletion: {doc_id}") @@ -151,15 +151,8 @@ class OpenSearchVector(BaseVector): raise ValueError("All elements in query_vector should be floats") query = { - "size": kwargs.get('top_k', 4), - "query": { - "knn": { - Field.VECTOR.value: { - Field.VECTOR.value: query_vector, - "k": kwargs.get('top_k', 4) - } - } - } + "size": kwargs.get("top_k", 4), + "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, } try: @@ -169,17 +162,17 @@ class OpenSearchVector(BaseVector): raise docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value, {}) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value, {}) # Make sure metadata is a dictionary if metadata is None: metadata = {} - metadata['score'] = hit['_score'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if hit['_score'] > score_threshold: - doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + metadata["score"] = hit["_score"] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if hit["_score"] > score_threshold: + doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -190,32 +183,28 @@ class OpenSearchVector(BaseVector): response = self._client.search(index=self._collection_name.lower(), body=full_text_query) docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value) - vector = hit['_source'].get(Field.VECTOR.value) - page_content = hit['_source'].get(Field.CONTENT_KEY.value) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value) + vector = hit["_source"].get(Field.VECTOR.value) + page_content = hit["_source"].get(Field.CONTENT_KEY.value) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name.lower()}' + lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name.lower()} already exists.") return if not self._client.indices.exists(index=self._collection_name.lower()): index_body = { - "settings": { - "index": { - "knn": True - } - }, + "settings": {"index": {"knn": True}}, "mappings": { "properties": { Field.CONTENT_KEY.value: {"type": "text"}, @@ -226,20 +215,17 @@ class OpenSearchVector(BaseVector): "name": "hnsw", "space_type": "l2", "engine": "faiss", - "parameters": { - "ef_construction": 64, - "m": 8 - } - } + "parameters": {"ef_construction": 64, "m": 8}, + }, }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } - } + }, } self._client.indices.create(index=self._collection_name.lower(), body=index_body) @@ -248,17 +234,14 @@ class OpenSearchVector(BaseVector): class OpenSearchVectorFactory(AbstractVectorFactory): - def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) - + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( host=dify_config.OPENSEARCH_HOST, @@ -268,7 +251,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory): secure=dify_config.OPENSEARCH_SECURE, ) - return OpenSearchVector( - collection_name=collection_name, - config=open_search_config - ) + return OpenSearchVector(collection_name=collection_name, config=open_search_config) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index aa2c6171c3..77ec45b4d3 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -31,7 +31,8 @@ class OracleVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config ORACLE_HOST is required") @@ -103,9 +104,16 @@ class OracleVector(BaseVector): arraysize=cursor.arraysize, outconverter=self.numpy_converter_out, ) - def _create_connection_pool(self, config: OracleVectorConfig): - return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1) + def _create_connection_pool(self, config: OracleVectorConfig): + return oracledb.create_pool( + user=config.user, + password=config.password, + dsn="{}:{}/{}".format(config.host, config.port, config.database), + min=1, + max=50, + increment=1, + ) @contextmanager def _get_cursor(self): @@ -136,13 +144,15 @@ class OracleVector(BaseVector): doc_id, doc.page_content, json.dumps(doc.metadata), - #array.array("f", embeddings[i]), + # array.array("f", embeddings[i]), numpy.array(embeddings[i]), ) ) - #print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") + # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: - cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values) + cur.executemany( + f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values + ) return pks def text_exists(self, id: str) -> bool: @@ -157,7 +167,8 @@ class OracleVector(BaseVector): for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) return docs - #def get_ids_by_metadata_field(self, key: str, value: str): + + # def get_ids_by_metadata_field(self, key: str, value: str): # with self._get_cursor() as cur: # cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" ) # idss = [] @@ -184,10 +195,12 @@ class OracleVector(BaseVector): top_k = kwargs.get("top_k", 5) with self._get_cursor() as cur: cur.execute( - f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)] + f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" + f" ORDER BY distance fetch first {top_k} rows only", + [numpy.array(query_vector)], ) docs = [] - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) for record in cur: metadata, text, distance = record score = 1 - distance @@ -199,10 +212,10 @@ class OracleVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: # Check which language the query is in - zh_pattern = re.compile('[\u4e00-\u9fa5]+') + zh_pattern = re.compile("[\u4e00-\u9fa5]+") match = zh_pattern.search(query) entities = [] # match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split @@ -210,7 +223,7 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名 + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: @@ -220,22 +233,23 @@ class OracleVector(BaseVector): entities.append(current_entity) else: try: - nltk.data.find('tokenizers/punkt') - nltk.data.find('corpora/stopwords') + nltk.data.find("tokenizers/punkt") + nltk.data.find("corpora/stopwords") except LookupError: - nltk.download('punkt') - nltk.download('stopwords') + nltk.download("punkt") + nltk.download("stopwords") print("run download") - e_str = re.sub(r'[^\w ]', '', query) + e_str = re.sub(r"[^\w ]", "", query) all_tokens = nltk.word_tokenize(e_str) - stop_words = stopwords.words('english') + stop_words = stopwords.words("english") for token in all_tokens: if token not in stop_words: entities.append(token) with self._get_cursor() as cur: cur.execute( - f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", - [" ACCUM ".join(entities)] + f"select meta, text, embedding FROM {self.table_name}" + f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", + [" ACCUM ".join(entities)], ) docs = [] for record in cur: @@ -273,8 +287,7 @@ class OracleVectorFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) return OracleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a48224070f..de2d65b223 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -31,27 +31,29 @@ class PgvectoRSConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PGVECTO_RS_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config PGVECTO_RS_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config PGVECTO_RS_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config PGVECTO_RS_DATABASE is required") return values class PGVectoRS(BaseVector): - def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): super().__init__(collection_name) self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self._client = create_engine(self._url) with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) @@ -80,9 +82,9 @@ class PGVectoRS(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -133,9 +135,7 @@ class PGVectoRS(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self._client) as session: - select_statement = sql_text( - f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; " - ) + select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ") result = session.execute(select_statement).fetchall() if result: return [item[0] for item in result] @@ -143,12 +143,11 @@ class PGVectoRS(BaseVector): return None def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete_by_ids(self, ids: list[str]) -> None: @@ -156,13 +155,13 @@ class PGVectoRS(BaseVector): select_statement = sql_text( f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); " ) - result = session.execute(select_statement, {'doc_ids': ids}).fetchall() + result = session.execute(select_statement, {"doc_ids": ids}).fetchall() if result: ids = [item[0] for item in result] if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete(self) -> None: @@ -187,7 +186,7 @@ class PGVectoRS(BaseVector): query_vector, ).label("distance"), ) - .limit(kwargs.get('top_k', 2)) + .limit(kwargs.get("top_k", 2)) .order_by("distance") ) res = session.execute(stmt) @@ -198,11 +197,10 @@ class PGVectoRS(BaseVector): for record, dis in results: metadata = record.meta score = 1 - dis - metadata['score'] = score - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + metadata["score"] = score + score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc = Document(page_content=record.text, - metadata=metadata) + doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) return docs @@ -225,13 +223,12 @@ class PGVectoRS(BaseVector): class PGVectoRSFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) return PGVectoRS( @@ -243,5 +240,5 @@ class PGVectoRSFactory(AbstractVectorFactory): password=dify_config.PGVECTO_RS_PASSWORD, database=dify_config.PGVECTO_RS_DATABASE, ), - dim=dim + dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index c9f2f35af0..79879d4f63 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -24,7 +24,8 @@ class PGVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") @@ -138,11 +139,12 @@ class PGVector(BaseVector): with self._get_cursor() as cur: cur.execute( - f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}", + f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" + f" ORDER BY distance LIMIT {top_k}", (json.dumps(query_vector),), ) docs = [] - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) for record in cur: metadata, text, distance = record score = 1 - distance @@ -201,8 +203,7 @@ class PGVectorFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) return PGVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 297bff928e..f418e3ca05 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -48,28 +48,25 @@ class QdrantConfig(BaseModel): prefer_grpc: bool = False def to_qdrant_params(self): - if self.endpoint and self.endpoint.startswith('path:'): - path = self.endpoint.replace('path:', '') + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") if not os.path.isabs(path): path = os.path.join(self.root_path, path) - return { - 'path': path - } + return {"path": path} else: return { - 'url': self.endpoint, - 'api_key': self.api_key, - 'timeout': self.timeout, - 'verify': self.endpoint.startswith('https'), - 'grpc_port': self.grpc_port, - 'prefer_grpc': self.prefer_grpc + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": self.endpoint.startswith("https"), + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, } class QdrantVector(BaseVector): - - def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): + def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) @@ -80,10 +77,7 @@ class QdrantVector(BaseVector): return VectorType.QDRANT def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: @@ -97,9 +91,9 @@ class QdrantVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str, vector_size: int): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return collection_name = collection_name or uuid.uuid4().hex @@ -110,12 +104,19 @@ class QdrantVector(BaseVector): all_collection_name.append(collection.name) if collection_name not in all_collection_name: from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( size=vector_size, distance=rest.Distance[self._distance_func], ) - hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) self._client.recreate_collection( collection_name=collection_name, vectors_config=vectors_config, @@ -124,21 +125,24 @@ class QdrantVector(BaseVector): ) # create group_id payload index - self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) # create doc_id payload index - self._client.create_payload_index(collection_name, Field.DOC_ID.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) # create full text index text_index_params = TextIndexParams( type=TextIndexType.TEXT, tokenizer=TokenizerType.MULTILINGUAL, min_token_len=2, max_token_len=20, - lowercase=True + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params ) - self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, - field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -147,26 +151,23 @@ class QdrantVector(BaseVector): metadatas = [d.metadata for d in documents] added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, embeddings, metadatas, uuids, 64, self._group_id - ): - self._client.upsert( - collection_name=self._collection_name, points=points - ) + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) return added_ids def _generate_rest_batches( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - group_id: Optional[str] = None, + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest + texts_iterator = iter(texts) embeddings_iterator = iter(embeddings) metadatas_iterator = iter(metadatas or []) @@ -203,13 +204,13 @@ class QdrantVector(BaseVector): @classmethod def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[list[dict]], - content_payload_key: str, - metadata_payload_key: str, - group_id: str, - group_payload_key: str + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, ) -> list[dict]: payloads = [] for i, text in enumerate(texts): @@ -219,18 +220,11 @@ class QdrantVector(BaseVector): "calling .from_texts or .add_texts on Qdrant instance." ) metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - group_payload_key: group_id - } - ) + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) return payloads def delete_by_metadata_field(self, key: str, value: str): - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -248,9 +242,7 @@ class QdrantVector(BaseVector): self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -275,9 +267,7 @@ class QdrantVector(BaseVector): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -288,7 +278,6 @@ class QdrantVector(BaseVector): raise e def delete_by_ids(self, ids: list[str]) -> None: - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -304,9 +293,7 @@ class QdrantVector(BaseVector): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -324,15 +311,13 @@ class QdrantVector(BaseVector): all_collection_name.append(collection.name) if self._collection_name not in all_collection_name: return False - response = self._client.retrieve( - collection_name=self._collection_name, - ids=[id] - ) + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) return len(response) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + filter = models.Filter( must=[ models.FieldCondition( @@ -348,22 +333,22 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=kwargs.get("score_threshold", .0) + score_threshold=float(kwargs.get("score_threshold") or 0.0), ) docs = [] for result in results: metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: - metadata['score'] = result.score + metadata["score"] = result.score doc = Document( page_content=result.payload.get(Field.CONTENT_KEY.value), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -372,6 +357,7 @@ class QdrantVector(BaseVector): List of documents most similar to the query text and distance for each. """ from qdrant_client.http import models + scroll_filter = models.Filter( must=[ models.FieldCondition( @@ -381,24 +367,21 @@ class QdrantVector(BaseVector): models.FieldCondition( key="page_content", match=models.MatchText(text=query), - ) + ), ] ) response = self._client.scroll( collection_name=self._collection_name, scroll_filter=scroll_filter, - limit=kwargs.get('top_k', 2), + limit=kwargs.get("top_k", 2), with_payload=True, - with_vectors=True - + with_vectors=True, ) results = response[0] documents = [] for result in results: if result: - document = self._document_from_scored_point( - result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value - ) + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) documents.append(document) return documents @@ -410,10 +393,10 @@ class QdrantVector(BaseVector): @classmethod def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), @@ -425,24 +408,25 @@ class QdrantVector(BaseVector): class QdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) if not dataset.index_struct_dict: - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) config = current_app.config return QdrantVector( @@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory): root_path=config.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED - ) + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + ), ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 63ad0682d7..f47f75718a 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -33,28 +33,30 @@ class RelytConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config RELYT_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config RELYT_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config RELYT_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config RELYT_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config RELYT_DATABASE is required") return values class RelytVector(BaseVector): - def __init__(self, collection_name: str, config: RelytConfig, group_id: str): super().__init__(collection_name) self.embedding_dimension = 1536 self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self.client = create_engine(self._url) self._fields = [] self._group_id = group_id @@ -70,9 +72,9 @@ class RelytVector(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -110,7 +112,7 @@ class RelytVector(BaseVector): ids = [str(uuid.uuid1()) for _ in documents] metadatas = [d.metadata for d in documents] for metadata in metadatas: - metadata['group_id'] = self._group_id + metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] # Define the table schema @@ -125,29 +127,26 @@ class RelytVector(BaseVector): ) chunks_table_data = [] - with self.client.connect() as conn: - with conn.begin(): - for document, metadata, chunk_id, embedding in zip( - texts, metadatas, ids, embeddings - ): - chunks_table_data.append( - { - "id": chunk_id, - "embedding": embedding, - "document": document, - "metadata": metadata, - } - ) + with self.client.connect() as conn, conn.begin(): + for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): + chunks_table_data.append( + { + "id": chunk_id, + "embedding": embedding, + "document": document, + "metadata": metadata, + } + ) - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(chunks_table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: conn.execute(insert(chunks_table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(chunks_table).values(chunks_table_data)) return ids @@ -186,25 +185,22 @@ class RelytVector(BaseVector): ) try: - with self.client.connect() as conn: - with conn.begin(): - delete_condition = chunks_table.c.id.in_(ids) - conn.execute(chunks_table.delete().where(delete_condition)) - return True + with self.client.connect() as conn, conn.begin(): + delete_condition = chunks_table.c.id.in_(ids) + conn.execute(chunks_table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_uuids(ids) def delete_by_ids(self, ids: list[str]) -> None: - with Session(self.client) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ ) @@ -228,38 +224,34 @@ class RelytVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self.similarity_search_with_score_by_vector( - k=int(kwargs.get('top_k')), - embedding=query_vector, - filter=kwargs.get('filter') + k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter") ) # Organize results. docs = [] for document, score in results: - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if 1 - score > score_threshold: docs.append(document) return docs def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: list[float], + k: int = 4, + filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided try: from sqlalchemy.engine import Row except ImportError: - raise ImportError( - "Could not import Row from sqlalchemy.engine. " - "Please 'pip install sqlalchemy>=1.4'." - ) + raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: conditions = [ - f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1 + f"metadata->>{key!r} in ({', '.join(map(repr, value))})" + if len(value) > 1 else f"metadata->>{key!r} = {value[0]!r}" for key, value in filter.items() ] @@ -305,13 +297,12 @@ class RelytVector(BaseVector): class RelytVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.RELYT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name)) return RelytVector( collection_name=collection_name, @@ -322,5 +313,5 @@ class RelytVectorFactory(AbstractVectorFactory): password=dify_config.RELYT_PASSWORD, database=dify_config.RELYT_DATABASE, ), - group_id=dataset.id + group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3325a1028e..faa373017b 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -25,16 +25,11 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = 1, - replicas: int = 2, + shard: int = (1,) + replicas: int = (2,) def to_tencent_params(self): - return { - 'url': self.url, - 'username': self.username, - 'key': self.api_key, - 'timeout': self.timeout - } + return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} class TencentVector(BaseVector): @@ -61,25 +56,19 @@ class TencentVector(BaseVector): return self._client.create_database(database_name=self._client_config.database) def get_type(self) -> str: - return 'tencent' + return "tencent" def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def _has_collection(self) -> bool: collections = self._db.list_collections() - for collection in collections: - if collection.collection_name == self._collection_name: - return True - return False + return any(collection.collection_name == self._collection_name for collection in collections) def _create_collection(self, dimension: int) -> None: - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return @@ -101,9 +90,7 @@ class TencentVector(BaseVector): raise ValueError("unsupported metric_type") params = vdb_index.HNSWParams(m=16, efconstruction=200) index = vdb_index.Index( - vdb_index.FilterIndex( - self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY - ), + vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), vdb_index.VectorIndex( self.field_vector, dimension, @@ -111,12 +98,8 @@ class TencentVector(BaseVector): metric_type, params, ), - vdb_index.FilterIndex( - self.field_text, enum.FieldType.String, enum.IndexType.FILTER - ), - vdb_index.FilterIndex( - self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER - ), + vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), + vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER), ) self._db.create_collection( @@ -163,15 +146,14 @@ class TencentVector(BaseVector): self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - - res = self._db.collection(self._collection_name).search(vectors=[query_vector], - params=document.HNSWSearchParams( - ef=kwargs.get("ef", 10)), - retrieve_vector=False, - limit=kwargs.get('top_k', 4), - timeout=self._client_config.timeout, - ) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + res = self._db.collection(self._collection_name).search( + vectors=[query_vector], + params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), + retrieve_vector=False, + limit=kwargs.get("top_k", 4), + timeout=self._client_config.timeout, + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -200,15 +182,13 @@ class TencentVector(BaseVector): class TencentVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) return TencentVector( collection_name=collection_name, @@ -220,5 +200,5 @@ class TencentVectorFactory(AbstractVectorFactory): database=dify_config.TENCENT_VECTOR_DB_DATABASE, shard=dify_config.TENCENT_VECTOR_DB_SHARD, replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index d3685c0991..20490d3215 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -28,47 +28,57 @@ class TiDBVectorConfig(BaseModel): database: str program_name: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config TIDB_VECTOR_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config TIDB_VECTOR_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config TIDB_VECTOR_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config TIDB_VECTOR_DATABASE is required") - if not values['program_name']: + if not values["program_name"]: raise ValueError("config APPLICATION_NAME is required") return values class TiDBVector(BaseVector): - def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: from tidb_vector.sqlalchemy import VectorType + return Table( self._collection_name, self._orm_base.metadata, - Column('id', String(36), primary_key=True, nullable=False), - Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"), + Column("id", String(36), primary_key=True, nullable=False), + Column( + "vector", + VectorType(dim), + nullable=False, + comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})", + ), Column("text", TEXT, nullable=False), Column("meta", JSON, nullable=False), Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), - Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), - extend_existing=True + Column( + "update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ), + extend_existing=True, ) - def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'): + def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"): super().__init__(collection_name) self._client_config = config - self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" - f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}") + self._url = ( + f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" + f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}" + ) self._distance_func = distance_func.lower() self._engine = create_engine(self._url) self._orm_base = declarative_base() @@ -83,9 +93,9 @@ class TiDBVector(BaseVector): def _create_collection(self, dimension: int): logger.info("_create_collection, collection_name " + self._collection_name) - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return with Session(self._engine) as session: @@ -114,31 +124,28 @@ class TiDBVector(BaseVector): texts = [d.page_content for d in documents] chunks_table_data = [] - with self._engine.connect() as conn: - with conn.begin(): - for id, text, meta, embedding in zip( - ids, texts, metas, embeddings - ): - chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) + with self._engine.connect() as conn, conn.begin(): + for id, text, meta, embedding in zip(ids, texts, metas, embeddings): + chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: conn.execute(insert(table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(table).values(chunks_table_data)) return ids def text_exists(self, id: str) -> bool: - result = self.get_ids_by_metadata_field('doc_id', id) + result = self.get_ids_by_metadata_field("doc_id", id) return bool(result) def delete_by_ids(self, ids: list[str]) -> None: with Session(self._engine) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """ ) @@ -152,11 +159,10 @@ class TiDBVector(BaseVector): raise ValueError("No ids provided to delete.") table = self._table(self._dimension) try: - with self._engine.connect() as conn: - with conn.begin(): - delete_condition = table.c.id.in_(ids) - conn.execute(table.delete().where(delete_condition)) - return True + with self._engine.connect() as conn, conn.begin(): + delete_condition = table.c.id.in_(ids) + conn.execute(table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False @@ -179,21 +185,23 @@ class TiDBVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 - filter = kwargs.get('filter') + score_threshold = float(kwargs.get("score_threshold") or 0.0) + filter = kwargs.get("filter") distance = 1 - score_threshold query_vector_str = ", ".join(format(x) for x in query_vector) query_vector_str = "[" + query_vector_str + "]" - logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}") + logger.debug( + f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}" + ) docs = [] - if self._distance_func == 'l2': - tidb_func = 'Vec_l2_distance' - elif self._distance_func == 'cosine': - tidb_func = 'Vec_Cosine_distance' + if self._distance_func == "l2": + tidb_func = "Vec_l2_distance" + elif self._distance_func == "cosine": + tidb_func = "Vec_Cosine_distance" else: - tidb_func = 'Vec_Cosine_distance' + tidb_func = "Vec_Cosine_distance" with Session(self._engine) as session: select_statement = sql_text( @@ -208,7 +216,7 @@ class TiDBVector(BaseVector): results = [(row[0], row[1], row[2]) for row in res] for meta, text, distance in results: metadata = json.loads(meta) - metadata['score'] = 1 - distance + metadata["score"] = 1 - distance docs.append(Document(page_content=text, metadata=metadata)) return docs @@ -224,15 +232,13 @@ class TiDBVector(BaseVector): class TiDBVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) return TiDBVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 3f70e8b608..1a0dc7f48b 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -7,7 +7,6 @@ from core.rag.models.document import Document class BaseVector(ABC): - def __init__(self, collection_name: str): self._collection_name = collection_name @@ -39,26 +38,19 @@ class BaseVector(ABC): raise NotImplementedError @abstractmethod - def search_by_vector( - self, - query_vector: list[float], - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: raise NotImplementedError @abstractmethod - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def delete(self) -> None: raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: - for text in texts[:]: - doc_id = text.metadata['doc_id'] + for text in texts.copy(): + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -66,7 +58,7 @@ class BaseVector(ABC): return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 627d7c3aeb..ca90233b7f 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC): @staticmethod def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: - index_struct_dict = { - "type": vector_type, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} return index_struct_dict class Vector: def __init__(self, dataset: Dataset, attributes: list = None): if attributes is None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page'] + attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes @@ -39,7 +36,7 @@ class Vector: def _init_vector(self) -> BaseVector: vector_type = dify_config.VECTOR_STORE if self._dataset.index_struct_dict: - vector_type = self._dataset.index_struct_dict['type'] + vector_type = self._dataset.index_struct_dict["type"] if not vector_type: raise ValueError("Vector store must be specified.") @@ -52,45 +49,59 @@ class Vector: match vector_type: case VectorType.CHROMA: from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory + return ChromaVectorFactory case VectorType.MILVUS: from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory + return MilvusVectorFactory case VectorType.MYSCALE: from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory + return MyScaleVectorFactory case VectorType.PGVECTOR: from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory + return PGVectorFactory case VectorType.PGVECTO_RS: from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory + return PGVectoRSFactory case VectorType.QDRANT: from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory + return QdrantVectorFactory case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory + return RelytVectorFactory case VectorType.ELASTICSEARCH: from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory + return TiDBVectorFactory case VectorType.WEAVIATE: from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory + return WeaviateVectorFactory case VectorType.TENCENT: from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory + return TencentVectorFactory case VectorType.ORACLE: from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory + return OracleVectorFactory case VectorType.OPENSEARCH: from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory + return OpenSearchVectorFactory case VectorType.ANALYTICDB: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory + return AnalyticdbVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -98,22 +109,14 @@ class Vector: def create(self, texts: list = None, **kwargs): if texts: embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) - self._vector_processor.create( - texts=texts, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs) def add_texts(self, documents: list[Document], **kwargs): - if kwargs.get('duplicate_check', False): + if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) - self._vector_processor.create( - texts=documents, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs) def text_exists(self, id: str) -> bool: return self._vector_processor.text_exists(id) @@ -124,24 +127,18 @@ class Vector: def delete_by_metadata_field(self, key: str, value: str) -> None: self._vector_processor.delete_by_metadata_field(key, value) - def search_by_vector( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) def delete(self) -> None: self._vector_processor.delete() # delete collection redis cache if self._vector_processor.collection_name: - collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name) redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: @@ -151,14 +148,13 @@ class Vector: tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model - + model=self._dataset.embedding_model, ) return CacheEmbedding(embedding_model) def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: - for text in texts[:]: - doc_id = text.metadata['doc_id'] + for text in texts.copy(): + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 317ca6abc8..ba04ea879d 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -2,17 +2,17 @@ from enum import Enum class VectorType(str, Enum): - ANALYTICDB = 'analyticdb' - CHROMA = 'chroma' - MILVUS = 'milvus' - MYSCALE = 'myscale' - PGVECTOR = 'pgvector' - PGVECTO_RS = 'pgvecto-rs' - QDRANT = 'qdrant' - RELYT = 'relyt' - TIDB_VECTOR = 'tidb_vector' - WEAVIATE = 'weaviate' - OPENSEARCH = 'opensearch' - TENCENT = 'tencent' - ORACLE = 'oracle' - ELASTICSEARCH = 'elasticsearch' + ANALYTICDB = "analyticdb" + CHROMA = "chroma" + MILVUS = "milvus" + MYSCALE = "myscale" + PGVECTOR = "pgvector" + PGVECTO_RS = "pgvecto-rs" + QDRANT = "qdrant" + RELYT = "relyt" + TIDB_VECTOR = "tidb_vector" + WEAVIATE = "weaviate" + OPENSEARCH = "opensearch" + TENCENT = "tencent" + ORACLE = "oracle" + ELASTICSEARCH = "elasticsearch" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 205fe850c3..6eee344b9b 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -22,15 +22,15 @@ class WeaviateConfig(BaseModel): api_key: Optional[str] = None batch_size: int = 100 - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: + if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") return values class WeaviateVector(BaseVector): - def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): super().__init__(collection_name) self._client = self._init_client(config) @@ -43,10 +43,7 @@ class WeaviateVector(BaseVector): try: client = weaviate.Client( - url=config.endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None + url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None ) except requests.exceptions.ConnectionError: raise ConnectionError("Vector database connection error") @@ -68,10 +65,10 @@ class WeaviateVector(BaseVector): def get_collection_name(self, dataset: Dataset) -> str: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + if not class_prefix.endswith("_Node"): # original class_prefix - class_prefix += '_Node' + class_prefix += "_Node" return class_prefix @@ -79,10 +76,7 @@ class WeaviateVector(BaseVector): return Dataset.gen_collection_name_by_id(dataset_id) def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): # create collection @@ -91,9 +85,9 @@ class WeaviateVector(BaseVector): self.add_texts(texts, embeddings) def _create_collection(self): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return schema = self._default_schema(self._collection_name) @@ -129,17 +123,9 @@ class WeaviateVector(BaseVector): # check whether the index already exists schema = self._default_schema(self._collection_name) if self._client.schema.contains(schema): - where_filter = { - "operator": "Equal", - "path": [key], - "valueText": value - } + where_filter = {"operator": "Equal", "path": [key], "valueText": value} - self._client.batch.delete_objects( - class_name=self._collection_name, - where=where_filter, - output='minimal' - ) + self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal") def delete(self): # check whether the index already exists @@ -154,11 +140,19 @@ class WeaviateVector(BaseVector): # check whether the index already exists if not self._client.schema.contains(schema): return False - result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": id, - }).with_limit(1).do() + result = ( + self._client.query.get(collection_name) + .with_additional(["id"]) + .with_where( + { + "path": ["doc_id"], + "operator": "Equal", + "valueText": id, + } + ) + .with_limit(1) + .do() + ) if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") @@ -211,13 +205,13 @@ class WeaviateVector(BaseVector): docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -240,15 +234,15 @@ class WeaviateVector(BaseVector): if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) query_obj = query_obj.with_additional(["vector"]) - properties = ['text'] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() + properties = ["text"] + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][collection_name]: text = res.pop(Field.TEXT_KEY.value) - additional = res.pop('_additional') - docs.append(Document(page_content=text, vector=additional['vector'], metadata=res)) + additional = res.pop("_additional") + docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs def _default_schema(self, index_name: str) -> dict: @@ -271,20 +265,19 @@ class WeaviateVector(BaseVector): class WeaviateVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( endpoint=dify_config.WEAVIATE_ENDPOINT, api_key=dify_config.WEAVIATE_API_KEY, - batch_size=dify_config.WEAVIATE_BATCH_SIZE + batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), - attributes=attributes + attributes=attributes, ) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 96a15be742..319a2612c7 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -12,10 +12,10 @@ from models.dataset import Dataset, DocumentSegment class DatasetDocumentStore: def __init__( - self, - dataset: Dataset, - user_id: str, - document_id: Optional[str] = None, + self, + dataset: Dataset, + user_id: str, + document_id: Optional[str] = None, ): self._dataset = dataset self._user_id = user_id @@ -41,9 +41,9 @@ class DatasetDocumentStore: @property def docs(self) -> dict[str, Document]: - document_segments = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id - ).all() + document_segments = ( + db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() + ) output = {} for document_segment in document_segments: @@ -55,48 +55,45 @@ class DatasetDocumentStore: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) return output - def add_documents( - self, docs: Sequence[Document], allow_update: bool = True - ) -> None: - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == self._document_id - ).scalar() + def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == self._document_id) + .scalar() + ) if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == 'high_quality': + if self._dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model + model=self._dataset.embedding_model, ) for doc in docs: if not isinstance(doc, Document): raise ValueError("doc must be a Document") - segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: raise ValueError( - f"doc_id {doc.metadata['doc_id']} already exists. " - "Set allow_update to True to overwrite." + f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite." ) # calc embedding use tokens if embedding_model: - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[doc.page_content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content]) else: tokens = 0 @@ -107,8 +104,8 @@ class DatasetDocumentStore: tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, document_id=self._document_id, - index_node_id=doc.metadata['doc_id'], - index_node_hash=doc.metadata['doc_hash'], + index_node_id=doc.metadata["doc_id"], + index_node_hash=doc.metadata["doc_hash"], position=max_position, content=doc.page_content, word_count=len(doc.page_content), @@ -116,15 +113,15 @@ class DatasetDocumentStore: enabled=False, created_by=self._user_id, ) - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) else: segment_document.content = doc.page_content - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') - segment_document.index_node_hash = doc.metadata['doc_hash'] + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") + segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens @@ -135,9 +132,7 @@ class DatasetDocumentStore: result = self.get_document_segment(doc_id) return result is not None - def get_document( - self, doc_id: str, raise_error: bool = True - ) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -153,7 +148,7 @@ class DatasetDocumentStore: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) def delete_document(self, doc_id: str, raise_error: bool = True) -> None: @@ -188,9 +183,10 @@ class DatasetDocumentStore: return document_segment.index_node_hash def get_document_segment(self, doc_id: str) -> DocumentSegment: - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id, - DocumentSegment.index_node_id == doc_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) + .first() + ) return document_segment diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index abfdafcfa2..e46ab8b7fd 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -4,6 +4,7 @@ The goal is to facilitate decoupling of content loading from content parsing cod In addition, content loading code should provide a lazy loading interface by default. """ + from __future__ import annotations import contextlib @@ -11,7 +12,7 @@ import mimetypes from abc import ABC, abstractmethod from collections.abc import Generator, Iterable, Mapping from io import BufferedReader, BytesIO -from pathlib import PurePath +from pathlib import Path, PurePath from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, model_validator @@ -55,8 +56,7 @@ class Blob(BaseModel): def as_string(self) -> str: """Read data as a string.""" if self.data is None and self.path: - with open(str(self.path), encoding=self.encoding) as f: - return f.read() + return Path(str(self.path)).read_text(encoding=self.encoding) elif isinstance(self.data, bytes): return self.data.decode(self.encoding) elif isinstance(self.data, str): @@ -71,8 +71,7 @@ class Blob(BaseModel): elif isinstance(self.data, str): return self.data.encode(self.encoding) elif self.data is None and self.path: - with open(str(self.path), "rb") as f: - return f.read() + return Path(str(self.path)).read_bytes() else: raise ValueError(f"Unable to get bytes for blob {self}") diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 0470569f39..5b67403902 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import csv from typing import Optional @@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + source_column: Optional[str] = None, + csv_args: Optional[dict] = None, ): """Initialize with file path.""" self._file_path = file_path @@ -57,7 +58,7 @@ class CSVExtractor(BaseExtractor): docs = [] try: # load csv file into pandas dataframe - df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args) + df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args) # check source column exists if self.source_column and self.source_column not in df.columns: @@ -67,7 +68,7 @@ class CSVExtractor(BaseExtractor): for i, row in df.iterrows(): content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns) - source = row[self.source_column] if self.source_column else '' + source = row[self.source_column] if self.source_column else "" metadata = {"source": source, "row": i} doc = Document(page_content=content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 7479b1d97b..3692b5d19d 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -10,6 +10,7 @@ class NotionInfo(BaseModel): """ Notion import info. """ + notion_workspace_id: str notion_obj_id: str notion_page_type: str @@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel): """ website import info. """ + provider: str job_id: str url: str @@ -43,6 +45,7 @@ class ExtractSetting(BaseModel): """ Model class for provider response. """ + datasource_type: str upload_file: Optional[UploadFile] = None notion_info: Optional[NotionInfo] = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index 526c66042c..fc33165719 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import os from typing import Optional @@ -17,23 +18,18 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding self._autodetect_encoding = autodetect_encoding def extract(self) -> list[Document]: - """ Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" + """Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" documents = [] file_extension = os.path.splitext(self._file_path)[-1].lower() - if file_extension == '.xlsx': + if file_extension == ".xlsx": wb = load_workbook(self._file_path, data_only=True) for sheet_name in wb.sheetnames: sheet = wb[sheet_name] @@ -44,35 +40,38 @@ class ExcelExtractor(BaseExtractor): continue df = pd.DataFrame(data, columns=cols) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for index, row in df.iterrows(): page_content = [] for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): - cell = sheet.cell(row=index + 2, - column=col_index + 1) # +2 to account for header and 1-based index + cell = sheet.cell( + row=index + 2, column=col_index + 1 + ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" page_content.append(f'"{k}":"{value}"') else: page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) - elif file_extension == '.xls': - excel_file = pd.ExcelFile(self._file_path, engine='xlrd') + elif file_extension == ".xls": + excel_file = pd.ExcelFile(self._file_path, engine="xlrd") for sheet_name in excel_file.sheet_names: df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for _, row in df.iterrows(): page_content = [] for k, v in row.items(): if pd.notna(v): page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) else: raise ValueError(f"Unsupported file extension: {file_extension}") diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index f7a08135f5..fe7eaa32e6 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -29,61 +29,62 @@ from core.rag.models.document import Document from extensions.ext_storage import storage from models.model import UploadFile -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json'] -USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" +SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"] +USER_AGENT = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124" + " Safari/537.36" +) class ExtractProcessor: @classmethod - def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ - -> Union[list[Document], str]: + def load_from_upload_file( + cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False + ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=upload_file, - document_model='text_model' + datasource_type="upload_file", upload_file=upload_file, document_model="text_model" ) if return_text: - delimiter = '\n' + delimiter = "\n" return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) else: return cls.extract(extract_setting, is_automatic) @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = ssrf_proxy.get(url, headers={ - "User-Agent": USER_AGENT - }) + response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT}) with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(url).suffix - if not suffix and suffix != '.': + if not suffix and suffix != ".": # get content-type - if response.headers.get('Content-Type'): - suffix = '.' + response.headers.get('Content-Type').split('/')[-1] + if response.headers.get("Content-Type"): + suffix = "." + response.headers.get("Content-Type").split("/")[-1] else: - content_disposition = response.headers.get('Content-Disposition') + content_disposition = response.headers.get("Content-Disposition") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = '.' + re.search(r'\.(\w+)$', filename).group(1) + suffix = "." + re.search(r"\.(\w+)$", filename).group(1) file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - with open(file_path, 'wb') as file: - file.write(response.content) - extract_setting = ExtractSetting( - datasource_type="upload_file", - document_model='text_model' - ) + Path(file_path).write_bytes(response.content) + extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: - delimiter = '\n' - return delimiter.join([document.page_content for document in cls.extract( - extract_setting=extract_setting, file_path=file_path)]) + delimiter = "\n" + return delimiter.join( + [ + document.page_content + for document in cls.extract(extract_setting=extract_setting, file_path=file_path) + ] + ) else: return cls.extract(extract_setting=extract_setting, file_path=file_path) @classmethod - def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, - file_path: str = None) -> list[Document]: + def extract( + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None + ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: @@ -96,50 +97,56 @@ class ExtractProcessor: etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY - if etl_type == 'Unstructured': - if file_extension == '.xlsx' or file_extension == '.xls': + if etl_type == "Unstructured": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: - extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ + elif file_extension in {".md", ".markdown"}: + extractor = ( + UnstructuredMarkdownExtractor(file_path, unstructured_api_url) + if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + ) + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == '.msg': + elif file_extension == ".msg": extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) - elif file_extension == '.eml': + elif file_extension == ".eml": extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) - elif file_extension == '.ppt': + elif file_extension == ".ppt": extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key) - elif file_extension == '.pptx': + elif file_extension == ".pptx": extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) - elif file_extension == '.xml': + elif file_extension == ".xml": extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url) else: # txt - extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ + extractor = ( + UnstructuredTextExtractor(file_path, unstructured_api_url) + if is_automatic else TextExtractor(file_path, autodetect_encoding=True) + ) else: - if file_extension == '.xlsx' or file_extension == '.xls': + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: + elif file_extension in {".md", ".markdown"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path) else: # txt @@ -155,13 +162,13 @@ class ExtractProcessor: ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: - if extract_setting.website_info.provider == 'firecrawl': + if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, job_id=extract_setting.website_info.job_id, tenant_id=extract_setting.website_info.tenant_id, mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content + only_main_content=extract_setting.website_info.only_main_content, ) return extractor.extract() else: diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py index c490e59332..582eca94df 100644 --- a/api/core/rag/extractor/extractor_base.py +++ b/api/core/rag/extractor/extractor_base.py @@ -1,12 +1,11 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod class BaseExtractor(ABC): - """Interface for extract files. - """ + """Interface for extract files.""" @abstractmethod def extract(self): raise NotImplementedError - diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 2b85ad9739..17c2087a0a 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -9,108 +9,98 @@ from extensions.ext_storage import storage class FirecrawlApp: def __init__(self, api_key=None, base_url=None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' - if self.api_key is None and self.base_url == 'https://api.firecrawl.dev': - raise ValueError('No API key provided') + self.base_url = base_url or "https://api.firecrawl.dev" + if self.api_key is None and self.base_url == "https://api.firecrawl.dev": + raise ValueError("No API key provided") def scrape_url(self, url, params=None) -> dict: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } - json_data = {'url': url} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + json_data = {"url": url} if params: json_data.update(params) - response = requests.post( - f'{self.base_url}/v0/scrape', - headers=headers, - json=json_data - ) + response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: response = response.json() - if response['success'] == True: - data = response['data'] + if response["success"] == True: + data = response["data"] return { - 'title': data.get('metadata').get('title'), - 'description': data.get('metadata').get('description'), - 'source_url': data.get('metadata').get('sourceURL'), - 'markdown': data.get('markdown') + "title": data.get("metadata").get("title"), + "description": data.get("metadata").get("description"), + "source_url": data.get("metadata").get("sourceURL"), + "markdown": data.get("markdown"), } else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') - elif response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}') + elif response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}') + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}") def crawl_url(self, url, params=None) -> str: headers = self._prepare_headers() - json_data = {'url': url} + json_data = {"url": url} if params: json_data.update(params) - response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: - job_id = response.json().get('jobId') + job_id = response.json().get("jobId") return job_id else: - self._handle_error(response, 'start crawl job') + self._handle_error(response, "start crawl job") def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() - response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers) + response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers) if response.status_code == 200: crawl_status_response = response.json() - if crawl_status_response.get('status') == 'completed': - total = crawl_status_response.get('total', 0) + if crawl_status_response.get("status") == "completed": + total = crawl_status_response.get("total", 0) if total == 0: - raise Exception('Failed to check crawl status. Error: No page found') - data = crawl_status_response.get('data', []) + raise Exception("Failed to check crawl status. Error: No page found") + data = crawl_status_response.get("data", []) url_data_list = [] for item in data: - if isinstance(item, dict) and 'metadata' in item and 'markdown' in item: + if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - 'title': item.get('metadata').get('title'), - 'description': item.get('metadata').get('description'), - 'source_url': item.get('metadata').get('sourceURL'), - 'markdown': item.get('markdown') + "title": item.get("metadata").get("title"), + "description": item.get("metadata").get("description"), + "source_url": item.get("metadata").get("sourceURL"), + "markdown": item.get("markdown"), } url_data_list.append(url_data) if url_data_list: - file_key = 'website_files/' + job_id + '.txt' + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(url_data_list).encode('utf-8')) + storage.save(file_key, json.dumps(url_data_list).encode("utf-8")) return { - 'status': 'completed', - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': url_data_list + "status": "completed", + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": url_data_list, } else: return { - 'status': crawl_status_response.get('status'), - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': [] + "status": crawl_status_response.get("status"), + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": [], } else: - self._handle_error(response, 'check crawl status') + self._handle_error(response, "check crawl status") def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5): for attempt in range(retries): response = requests.post(url, headers=headers, json=data) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response @@ -119,13 +109,11 @@ class FirecrawlApp: for attempt in range(retries): response = requests.get(url, headers=headers) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response def _handle_error(self, response, action): - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}') - - + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py index 8e2f107e5e..b33ce167c2 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py @@ -5,7 +5,7 @@ from services.website_service import WebsiteService class FirecrawlWebExtractor(BaseExtractor): """ - Crawl and scrape websites and return content in clean llm-ready markdown. + Crawl and scrape websites and return content in clean llm-ready markdown. Args: @@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor): mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'. """ - def __init__( - self, - url: str, - job_id: str, - tenant_id: str, - mode: str = 'crawl', - only_main_content: bool = False - ): + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False): """Initialize with url, api_key, base_url and mode.""" self._url = url self.job_id = job_id @@ -33,28 +26,31 @@ class FirecrawlWebExtractor(BaseExtractor): def extract(self) -> list[Document]: """Extract content from the URL.""" documents = [] - if self.mode == 'crawl': - crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id) + if self.mode == "crawl": + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id) if crawl_data is None: return [] - document = Document(page_content=crawl_data.get('markdown', ''), - metadata={ - 'source_url': crawl_data.get('source_url'), - 'description': crawl_data.get('description'), - 'title': crawl_data.get('title') - } - ) + document = Document( + page_content=crawl_data.get("markdown", ""), + metadata={ + "source_url": crawl_data.get("source_url"), + "description": crawl_data.get("description"), + "title": crawl_data.get("title"), + }, + ) documents.append(document) - elif self.mode == 'scrape': - scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id, - self.only_main_content) + elif self.mode == "scrape": + scrape_data = WebsiteService.get_scrape_url_data( + "firecrawl", self._url, self.tenant_id, self.only_main_content + ) - document = Document(page_content=scrape_data.get('markdown', ''), - metadata={ - 'source_url': scrape_data.get('source_url'), - 'description': scrape_data.get('description'), - 'title': scrape_data.get('title') - } - ) + document = Document( + page_content=scrape_data.get("markdown", ""), + metadata={ + "source_url": scrape_data.get("source_url"), + "description": scrape_data.get("description"), + "title": scrape_data.get("title"), + }, + ) documents.append(document) return documents diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 0c17a47b32..69ca9d5d63 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -1,6 +1,7 @@ """Document loader helpers.""" import concurrent.futures +from pathlib import Path from typing import NamedTuple, Optional, cast @@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding import chardet def read_and_detect(file_path: str) -> list[dict]: - with open(file_path, "rb") as f: - rawdata = f.read() + rawdata = Path(file_path).read_bytes() return cast(list[dict], chardet.detect_all(rawdata)) with concurrent.futures.ThreadPoolExecutor() as executor: @@ -37,9 +37,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding try: encodings = future.result(timeout=timeout) except concurrent.futures.TimeoutError: - raise TimeoutError( - f"Timeout reached while detecting encoding for {file_path}" - ) + raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}") if all(encoding["encoding"] is None for encoding in encodings): raise RuntimeError(f"Could not detect encoding for {file_path}") diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index ceb5306255..560c2d1d84 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from bs4 import BeautifulSoup from core.rag.extractor.extractor_base import BaseExtractor @@ -6,7 +7,6 @@ from core.rag.models.document import Document class HtmlExtractor(BaseExtractor): - """ Load html files. @@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str - ): + def __init__(self, file_path: str): """Initialize with file path.""" self._file_path = file_path @@ -27,8 +24,8 @@ class HtmlExtractor(BaseExtractor): def _load_as_text(self) -> str: with open(self._file_path, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') + soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() - text = text.strip() if text else '' + text = text.strip() if text else "" - return text \ No newline at end of file + return text diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index b24cf2e170..849852ac23 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -1,5 +1,7 @@ """Abstract interface for document loader implementations.""" + import re +from pathlib import Path from typing import Optional, cast from core.rag.extractor.extractor_base import BaseExtractor @@ -16,12 +18,12 @@ class MarkdownExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - remove_hyperlinks: bool = False, - remove_images: bool = False, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, + self, + file_path: str, + remove_hyperlinks: bool = False, + remove_images: bool = False, + encoding: Optional[str] = None, + autodetect_encoding: bool = True, ): """Initialize with file path.""" self._file_path = file_path @@ -78,13 +80,10 @@ class MarkdownExtractor(BaseExtractor): if current_header is not None: # pass linting, assert keys are defined markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) - for key, value in markdown_tups + (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ] else: - markdown_tups = [ - (key, re.sub("\n", "", value)) for key, value in markdown_tups - ] + markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] return markdown_tups @@ -104,15 +103,13 @@ class MarkdownExtractor(BaseExtractor): """Parse file into tuples.""" content = "" try: - with open(filepath, encoding=self._encoding) as f: - content = f.read() + content = Path(filepath).read_text(encoding=self._encoding) except UnicodeDecodeError as e: if self._autodetect_encoding: detected_encodings = detect_file_encodings(filepath) for encoding in detected_encodings: try: - with open(filepath, encoding=encoding.encoding) as f: - content = f.read() + content = Path(filepath).read_text(encoding=encoding.encoding) break except UnicodeDecodeError: continue diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 7e839804c8..87a4ce08bf 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -21,22 +21,21 @@ RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" # if user want split by headings, use the corresponding splitter HEADING_SPLITTER = { - 'heading_1': '# ', - 'heading_2': '## ', - 'heading_3': '### ', + "heading_1": "# ", + "heading_2": "## ", + "heading_3": "### ", } + class NotionExtractor(BaseExtractor): - def __init__( - self, - notion_workspace_id: str, - notion_obj_id: str, - notion_page_type: str, - tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, - + self, + notion_workspace_id: str, + notion_obj_id: str, + notion_page_type: str, + tenant_id: str, + document_model: Optional[DocumentModel] = None, + notion_access_token: Optional[str] = None, ): self._notion_access_token = None self._document_model = document_model @@ -46,46 +45,38 @@ class NotionExtractor(BaseExtractor): if notion_access_token: self._notion_access_token = notion_access_token else: - self._notion_access_token = self._get_access_token(tenant_id, - self._notion_workspace_id) + self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: raise ValueError( - "Must specify `integration_token` or set environment " - "variable `NOTION_INTEGRATION_TOKEN`." + "Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`." ) self._notion_access_token = integration_token def extract(self) -> list[Document]: - self.update_last_edited_time( - self._document_model - ) + self.update_last_edited_time(self._document_model) text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) return text_docs - def _load_data_as_documents( - self, notion_obj_id: str, notion_page_type: str - ) -> list[Document]: + def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]: docs = [] - if notion_page_type == 'database': + if notion_page_type == "database": # get all the pages in the database page_text_documents = self._get_notion_database_data(notion_obj_id) docs.extend(page_text_documents) - elif notion_page_type == 'page': + elif notion_page_type == "page": page_text_list = self._get_notion_block_data(notion_obj_id) - docs.append(Document(page_content='\n'.join(page_text_list))) + docs.append(Document(page_content="\n".join(page_text_list))) else: raise ValueError("notion page type not supported") return docs - def _get_notion_database_data( - self, database_id: str, query_dict: dict[str, Any] = {} - ) -> list[Document]: + def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -100,50 +91,50 @@ class NotionExtractor(BaseExtractor): data = res.json() database_content = [] - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: return [] for result in data["results"]: - properties = result['properties'] + properties = result["properties"] data = {} for property_name, property_value in properties.items(): - type = property_value['type'] - if type == 'multi_select': + type = property_value["type"] + if type == "multi_select": value = [] multi_select_list = property_value[type] for multi_select in multi_select_list: - value.append(multi_select['name']) - elif type == 'rich_text' or type == 'title': + value.append(multi_select["name"]) + elif type in {"rich_text", "title"}: if len(property_value[type]) > 0: - value = property_value[type][0]['plain_text'] + value = property_value[type][0]["plain_text"] else: - value = '' - elif type == 'select' or type == 'status': + value = "" + elif type in {"select", "status"}: if property_value[type]: - value = property_value[type]['name'] + value = property_value[type]["name"] else: - value = '' + value = "" else: value = property_value[type] data[property_name] = value row_dict = {k: v for k, v in data.items() if v} - row_content = '' + row_content = "" for key, value in row_dict.items(): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} - value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items()) - row_content = row_content + f'{key}:{value_content}\n' + value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) + row_content = row_content + f"{key}:{value_content}\n" else: - row_content = row_content + f'{key}:{value}\n' + row_content = row_content + f"{key}:{value}\n" database_content.append(row_content) - return [Document(page_content='\n'.join(database_content))] + return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", block_url, @@ -152,14 +143,14 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) text += "\n\n" @@ -175,17 +166,15 @@ class NotionExtractor(BaseExtractor): result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -199,7 +188,7 @@ class NotionExtractor(BaseExtractor): start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -209,16 +198,16 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: break for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) result_lines_arr.append(text) @@ -233,17 +222,15 @@ class NotionExtractor(BaseExtractor): result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=num_tabs + 1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: - result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}') + result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -260,7 +247,7 @@ class NotionExtractor(BaseExtractor): start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while not done: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -270,28 +257,28 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() # get table headers text table_header_cell_texts = [] - table_header_cells = data["results"][0]['table_row']['cells'] + table_header_cells = data["results"][0]["table_row"]["cells"] for table_header_cell in table_header_cells: if table_header_cell: for table_header_cell_text in table_header_cell: text = table_header_cell_text["text"]["content"] table_header_cell_texts.append(text) else: - table_header_cell_texts.append('') + table_header_cell_texts.append("") # Initialize Markdown table with headers markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n" - markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n" # Process data to format each row in Markdown table format results = data["results"] for i in range(len(results) - 1): column_texts = [] - table_column_cells = data["results"][i + 1]['table_row']['cells'] + table_column_cells = data["results"][i + 1]["table_row"]["cells"] for j in range(len(table_column_cells)): if table_column_cells[j]: for table_column_cell_text in table_column_cells[j]: @@ -315,10 +302,8 @@ class NotionExtractor(BaseExtractor): last_edited_time = self.get_notion_last_edited_time() data_source_info = document_model.data_source_info_dict - data_source_info['last_edited_time'] = last_edited_time - update_params = { - DocumentModel.data_source_info: json.dumps(data_source_info) - } + data_source_info["last_edited_time"] = last_edited_time + update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} DocumentModel.query.filter_by(id=document_model.id).update(update_params) db.session.commit() @@ -326,7 +311,7 @@ class NotionExtractor(BaseExtractor): def get_notion_last_edited_time(self) -> str: obj_id = self._notion_obj_id page_type = self._notion_page_type - if page_type == 'database': + if page_type == "database": retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) @@ -341,7 +326,7 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - json=query_dict + json=query_dict, ) data = res.json() @@ -352,14 +337,16 @@ class NotionExtractor(BaseExtractor): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', ) ).first() if not data_source_binding: - raise Exception(f'No notion data source binding found for tenant {tenant_id} ' - f'and notion workspace {notion_workspace_id}') + raise Exception( + f"No notion data source binding found for tenant {tenant_id} " + f"and notion workspace {notion_workspace_id}" + ) return data_source_binding.access_token diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 0864fec6c8..57cb9610ba 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from collections.abc import Iterator from typing import Optional @@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - file_cache_key: Optional[str] = None - ): + def __init__(self, file_path: str, file_cache_key: Optional[str] = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key def extract(self) -> list[Document]: - plaintext_file_key = '' + plaintext_file_key = "" plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode('utf-8') + text = storage.load(self._file_cache_key).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -43,12 +40,12 @@ class PdfExtractor(BaseExtractor): # save plaintext file for caching if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) + storage.save(plaintext_file_key, text.encode("utf-8")) return documents def load( - self, + self, ) -> Iterator[Document]: """Lazy load given path as pages.""" blob = Blob.from_path(self._file_path) diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index ac5d0920cf..b2b51d71d7 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,4 +1,6 @@ """Abstract interface for document loader implementations.""" + +from pathlib import Path from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor @@ -14,12 +16,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding @@ -29,15 +26,13 @@ class TextExtractor(BaseExtractor): """Load from file path.""" text = "" try: - with open(self._file_path, encoding=self._encoding) as f: - text = f.read() + text = Path(self._file_path).read_text(encoding=self._encoding) except UnicodeDecodeError as e: if self._autodetect_encoding: detected_encodings = detect_file_encodings(self._file_path) for encoding in detected_encodings: try: - with open(self._file_path, encoding=encoding.encoding) as f: - text = f.read() + text = Path(self._file_path).read_text(encoding=encoding.encoding) break except UnicodeDecodeError: continue diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 0323b14a4a..a525c9e9e3 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -8,13 +8,12 @@ logger = logging.getLogger(__name__) class UnstructuredWordExtractor(BaseExtractor): - """Loader that uses unstructured to load word documents. - """ + """Loader that uses unstructured to load word documents.""" def __init__( - self, - file_path: str, - api_url: str, + self, + file_path: str, + api_url: str, ): """Initialize with file path.""" self._file_path = file_path @@ -24,9 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor): from unstructured.__version__ import __version__ as __unstructured_version__ from unstructured.file_utils.filetype import FileType, detect_filetype - unstructured_version = tuple( - int(x) for x in __unstructured_version__.split(".") - ) + unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: import magic # noqa: F401 @@ -53,6 +50,7 @@ class UnstructuredWordExtractor(BaseExtractor): elements = partition_docx(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2e704f187d..34c6811b67 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -26,6 +26,7 @@ class UnstructuredEmailExtractor(BaseExtractor): def extract(self) -> list[Document]: from unstructured.partition.email import partition_email + elements = partition_email(filename=self._file_path) # noinspection PyBroadException @@ -34,15 +35,16 @@ class UnstructuredEmailExtractor(BaseExtractor): element_text = element.text.strip() padding_needed = 4 - len(element_text) % 4 - element_text += '=' * padding_needed + element_text += "=" * padding_needed element_decode = base64.b64decode(element_text) - soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser') + soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") element.text = soup.get_text() except Exception: pass from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 44cf958ea2..fa50fa76b2 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -28,6 +28,7 @@ class UnstructuredEpubExtractor(BaseExtractor): elements = partition_epub(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 144b4e0c1d..fc3ff10693 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -38,6 +38,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): elements = partition_md(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index ad09b79eb0..8091e83e85 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -14,11 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredMsgExtractor(BaseExtractor): elements = partition_msg(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index d354b593ed..b69394b3b1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -14,12 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str, - api_key: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index 6fcbb5feb9..6ed4a0dfb3 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -14,11 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py index f4a4adbc16..22dfdd2075 100644 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py @@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredTextExtractor(BaseExtractor): elements = partition_text(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 6aef8e0f7e..3bffc01fbf 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -14,11 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredXmlExtractor(BaseExtractor): elements = partition_xml(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 15822867bb..7352ef378b 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import datetime import logging import mimetypes @@ -6,8 +7,8 @@ import os import re import tempfile import uuid -import xml.etree.ElementTree as ET from urllib.parse import urlparse +from xml.etree import ElementTree import requests from docx import Document as DocxDocument @@ -21,6 +22,7 @@ from models.model import UploadFile logger = logging.getLogger(__name__) + class WordExtractor(BaseExtractor): """Load docx files. @@ -43,12 +45,11 @@ class WordExtractor(BaseExtractor): r = requests.get(self.file_path) if r.status_code != 200: - raise ValueError( - f"Check the url of your file; returned status code {r.status_code}" - ) + raise ValueError(f"Check the url of your file; returned status code {r.status_code}") self.web_path = self.file_path - self.temp_file = tempfile.NamedTemporaryFile() + # TODO: use a better way to handle the file + self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115 self.temp_file.write(r.content) self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): @@ -60,11 +61,13 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, 'storage') - return [Document( - page_content=content, - metadata={"source": self.file_path}, - )] + content = self.parse_docx(self.file_path, "storage") + return [ + Document( + page_content=content, + metadata={"source": self.file_path}, + ) + ] @staticmethod def _is_valid_url(url: str) -> bool: @@ -84,18 +87,18 @@ class WordExtractor(BaseExtractor): url = rel.reltype response = requests.get(url, stream=True) if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers['Content-Type']) + image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) else: continue else: - image_ext = rel.target_ref.split('.')[-1] + image_ext = rel.target_ref.split(".")[-1] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) @@ -112,12 +115,14 @@ class WordExtractor(BaseExtractor): created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + image_map[rel.target_part] = ( + f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + ) return image_map @@ -148,7 +153,7 @@ class WordExtractor(BaseExtractor): if col_index >= total_cols: break cell_content = self._parse_cell(cell, image_map).strip() - cell_colspan = cell.grid_span if cell.grid_span else 1 + cell_colspan = cell.grid_span or 1 for i in range(cell_colspan): if col_index + i < total_cols: row_cells[col_index + i] = cell_content if i == 0 else "" @@ -167,8 +172,8 @@ class WordExtractor(BaseExtractor): def _parse_cell_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if not image_id: continue @@ -184,16 +189,16 @@ class WordExtractor(BaseExtractor): def _parse_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): - embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): + embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if embed_id: rel_target = run.part.rels[embed_id].target_ref if rel_target in image_map: paragraph_content.append(image_map[rel_target]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ' '.join(paragraph_content) if paragraph_content else '' + return " ".join(paragraph_content) if paragraph_content else "" def parse_docx(self, docx_path, image_folder): doc = DocxDocument(docx_path) @@ -204,60 +209,59 @@ class WordExtractor(BaseExtractor): image_map = self._extract_images_from_docx(doc, image_folder) hyperlinks_url = None - url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+') + url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") for para in doc.paragraphs: for run in para.runs: if run.text and hyperlinks_url: - result = f' [{run.text}]({hyperlinks_url}) ' + result = f" [{run.text}]({hyperlinks_url}) " run.text = result hyperlinks_url = None - if 'HYPERLINK' in run.element.xml: + if "HYPERLINK" in run.element.xml: try: - xml = ET.XML(run.element.xml) + xml = ElementTree.XML(run.element.xml) x_child = [c for c in xml.iter() if c is not None] for x in x_child: if x_child is None: continue - if x.tag.endswith('instrText'): + if x.tag.endswith("instrText"): for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: logger.error(e) - - - def parse_paragraph(paragraph): paragraph_content = [] for run in paragraph.runs: - if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'): + if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"): drawing_elements = run.element.findall( - './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing') + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" + ) for drawing in drawing_elements: blip_elements = drawing.findall( - './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip') + ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip" + ) for blip in blip_elements: embed_id = blip.get( - '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + ) if embed_id: image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: paragraph_content.append(image_map[image_part]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ''.join(paragraph_content) if paragraph_content else '' + return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() tables = doc.tables.copy() for element in doc.element.body: - if hasattr(element, 'tag'): - if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph + if hasattr(element, "tag"): + if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) if parsed_paragraph: content.append(parsed_paragraph) - elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table + elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) content.append(self._table_to_markdown(table, image_map)) - return '\n'.join(content) - + return "\n".join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 630387fe3a..be857bd122 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod from typing import Optional @@ -15,8 +16,7 @@ from models.dataset import Dataset, DatasetProcessRule class BaseIndexProcessor(ABC): - """Interface for extract files. - """ + """Interface for extract files.""" @abstractmethod def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -34,18 +34,24 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ - if processing_rule['mode'] == "custom": + if processing_rule["mode"] == "custom": # The user-defined segmentation rule - rules = processing_rule['rules'] + rules = processing_rule["rules"] segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: @@ -53,22 +59,22 @@ class BaseIndexProcessor(ABC): separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get('chunk_overlap', 0) or 0, + chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index df43a64910..9b855ece2c 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -7,8 +7,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess class IndexProcessorFactory: - """IndexProcessorInit. - """ + """IndexProcessorInit.""" def __init__(self, index_type: str): self._index_type = index_type @@ -22,7 +21,6 @@ class IndexProcessorFactory: if self._index_type == IndexType.PARAGRAPH_INDEX.value: return ParagraphIndexProcessor() elif self._index_type == IndexType.QA_INDEX.value: - return QAIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index bd7f6093bd..ed5712220f 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import uuid from typing import Optional @@ -15,33 +16,32 @@ from models.dataset import Dataset class ParagraphIndexProcessor(BaseIndexProcessor): - def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -55,7 +55,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return all_documents def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: @@ -63,7 +63,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword.create(documents) def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -76,17 +76,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.delete() - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: # Set search parameters. - results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index a44fd98036..1dbc473281 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import logging import re import threading @@ -23,33 +24,33 @@ from models.dataset import Dataset class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) # Split the text documents into nodes. all_documents = [] all_qa_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -61,14 +62,18 @@ class QAIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': kwargs.get('tenant_id'), - 'document_node': doc, - 'all_qa_documents': all_qa_documents, - 'document_language': kwargs.get('doc_language', 'English')}) + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": kwargs.get("tenant_id"), + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -76,9 +81,8 @@ class QAIndexProcessor(BaseIndexProcessor): return all_qa_documents def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: - # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -86,7 +90,7 @@ class QAIndexProcessor(BaseIndexProcessor): df = pd.read_csv(file) text_docs = [] for index, row in df.iterrows(): - data = Document(page_content=row[0], metadata={'answer': row[1]}) + data = Document(page_content=row[0], metadata={"answer": row[1]}) text_docs.append(data) if len(text_docs) == 0: raise ValueError("The CSV file is empty.") @@ -96,7 +100,7 @@ class QAIndexProcessor(BaseIndexProcessor): return text_docs def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) @@ -107,17 +111,29 @@ class QAIndexProcessor(BaseIndexProcessor): else: vector.delete() - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict): + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ): # Set search parameters. - results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) @@ -134,12 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor): document_qa_list = self._format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -151,10 +167,4 @@ class QAIndexProcessor(BaseIndexProcessor): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 6f3c1c5d34..0ff1fdb81c 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -55,9 +55,7 @@ class BaseDocumentTransformer(ABC): """ @abstractmethod - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform a list of documents. Args: @@ -68,9 +66,7 @@ class BaseDocumentTransformer(ABC): """ @abstractmethod - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a list of documents. Args: diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/constants/rerank_mode.py index afbb9fd89d..d4894e3cc6 100644 --- a/api/core/rag/rerank/constants/rerank_mode.py +++ b/api/core/rag/rerank/constants/rerank_mode.py @@ -2,7 +2,5 @@ from enum import Enum class RerankMode(Enum): - - RERANKING_MODEL = 'reranking_model' - WEIGHTED_SCORE = 'weighted_score' - + RERANKING_MODEL = "reranking_model" + WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index d9067da288..6356ff87ab 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -8,8 +8,14 @@ class RerankModelRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -23,19 +29,15 @@ class RerankModelRunner: doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) documents = unique_documents rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) rerank_documents = [] @@ -45,12 +47,12 @@ class RerankModelRunner: rerank_document = Document( page_content=result.text, metadata={ - "doc_id": documents[result.index].metadata['doc_id'], - "doc_hash": documents[result.index].metadata['doc_hash'], - "document_id": documents[result.index].metadata['document_id'], - "dataset_id": documents[result.index].metadata['dataset_id'], - 'score': result.score - } + "doc_id": documents[result.index].metadata["doc_id"], + "doc_hash": documents[result.index].metadata["doc_hash"], + "document_id": documents[result.index].metadata["document_id"], + "dataset_id": documents[result.index].metadata["dataset_id"], + "score": result.score, + }, ) rerank_documents.append(rerank_document) diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d8a7873982..16d6b879a4 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -13,13 +13,18 @@ from core.rag.rerank.entity.weight import VectorSetting, Weights class WeightRerankRunner: - def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -34,8 +39,8 @@ class WeightRerankRunner: doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -47,13 +52,15 @@ class WeightRerankRunner: query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): # format document - score = self.weights.vector_setting.vector_weight * query_vector_score + \ - self.weights.keyword_setting.keyword_weight * query_score + score = ( + self.weights.vector_setting.vector_weight * query_vector_score + + self.weights.keyword_setting.keyword_weight * query_score + ) if score_threshold and score < score_threshold: continue - document.metadata['score'] = score + document.metadata["score"] = score rerank_documents.append(document) - rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True) + rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -70,7 +77,7 @@ class WeightRerankRunner: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -113,8 +120,8 @@ class WeightRerankRunner: intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: @@ -132,8 +139,9 @@ class WeightRerankRunner: return similarities - def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document], - vector_setting: VectorSetting) -> list[float]: + def _calculate_cosine( + self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting + ) -> list[float]: """ Calculate Cosine scores :param query: search query @@ -149,15 +157,14 @@ class WeightRerankRunner: tenant_id=tenant_id, provider=vector_setting.embedding_provider_name, model_type=ModelType.TEXT_EMBEDDING, - model=vector_setting.embedding_model_name - + model=vector_setting.embedding_model_name, ) cache_embedding = CacheEmbedding(embedding_model) query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if 'score' in document.metadata: - query_vector_scores.append(document.metadata['score']) + if "score" in document.metadata: + query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy vec1 = np.array(query_vector) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index db01652f89..124c58f0fe 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -32,14 +32,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -48,15 +45,18 @@ class DatasetRetrieval: self.application_generate_entity = application_generate_entity def retrieve( - self, app_id: str, user_id: str, tenant_id: str, - model_config: ModelConfigWithCredentialsEntity, - config: DatasetEntity, - query: str, - invoke_from: InvokeFrom, - show_retrieve_source: bool, - hit_callback: DatasetIndexToolCallbackHandler, - message_id: str, - memory: Optional[TokenBufferMemory] = None, + self, + app_id: str, + user_id: str, + tenant_id: str, + model_config: ModelConfigWithCredentialsEntity, + config: DatasetEntity, + query: str, + invoke_from: InvokeFrom, + show_retrieve_source: bool, + hit_callback: DatasetIndexToolCallbackHandler, + message_id: str, + memory: Optional[TokenBufferMemory] = None, ) -> Optional[str]: """ Retrieve dataset. @@ -84,16 +84,12 @@ class DatasetRetrieval: model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=model_config.provider, - model=model_config.model + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if not model_schema: @@ -102,39 +98,46 @@ class DatasetRetrieval: planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: continue available_datasets.append(dataset) all_documents = [] - user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' + user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( - app_id, tenant_id, user_id, user_from, available_datasets, query, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, model_instance, - model_config, planning_strategy, message_id + model_config, + planning_strategy, + message_id, ) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: all_documents = self.multiple_retrieve( - app_id, tenant_id, user_id, user_from, - available_datasets, query, retrieve_config.top_k, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, + retrieve_config.top_k, retrieve_config.score_threshold, retrieve_config.rerank_mode, retrieve_config.reranking_model, @@ -145,89 +148,89 @@ class DatasetRetrieval: document_score_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if show_retrieve_source: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ).first() - document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': invoke_from.to_source(), - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": invoke_from.to_source(), + "score": document_score_list.get(segment.index_node_id, None), } - if invoke_from.to_source() == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if invoke_from.to_source() == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 if hit_callback: hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) - return '' + return "" def single_retrieve( - self, app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - model_instance: ModelInstance, - model_config: ModelConfigWithCredentialsEntity, - planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + model_instance: ModelInstance, + model_config: ModelConfigWithCredentialsEntity, + planning_strategy: PlanningStrategy, + message_id: Optional[str] = None, ): tools = [] for dataset in available_datasets: description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") message_tool = PromptMessageTool( name=dataset.id, description=description, @@ -235,14 +238,15 @@ class DatasetRetrieval: "type": "object", "properties": {}, "required": [], - } + }, ) tools.append(message_tool) dataset_id = None if planning_strategy == PlanningStrategy.REACT_ROUTER: react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, - user_id, tenant_id) + dataset_id = react_multi_dataset_router.invoke( + query, tools, model_config, model_instance, user_id, tenant_id + ) elif planning_strategy == PlanningStrategy.ROUTER: function_call_router = FunctionCallMultiDatasetRouter() @@ -250,37 +254,37 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model or default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get retrieval method if dataset.indexing_technique == "economy": - retrieval_method = 'keyword_search' + retrieval_method = "keyword_search" else: - retrieval_method = retrieval_model_config['search_method'] + retrieval_method = retrieval_model_config["search_method"] # get reranking model - reranking_model = retrieval_model_config['reranking_model'] \ - if retrieval_model_config['reranking_enable'] else None + reranking_model = ( + retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None + ) # get score threshold - score_threshold = .0 + score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") with measure_time() as timer: results = RetrievalService.retrieve( - retrieval_method=retrieval_method, dataset_id=dataset.id, + retrieval_method=retrieval_method, + dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, + top_k=top_k, + score_threshold=score_threshold, reranking_model=reranking_model, - reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'), - weights=retrieval_model_config.get('weights', None), + reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), + weights=retrieval_model_config.get("weights", None), ) self._on_query(query, [dataset_id], app_id, user_from, user_id) @@ -291,20 +295,20 @@ class DatasetRetrieval: return [] def multiple_retrieve( - self, - app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - top_k: int, - score_threshold: float, - reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - reranking_enable: bool = True, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + top_k: int, + score_threshold: float, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reranking_enable: bool = True, + message_id: Optional[str] = None, ): threads = [] all_documents = [] @@ -312,13 +316,16 @@ class DatasetRetrieval: index_type = None for dataset in available_datasets: index_type = dataset.indexing_technique - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset.id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -327,16 +334,10 @@ class DatasetRetrieval: with measure_time() as timer: if reranking_enable: # do rerank for searched documents - data_post_processor = DataPostProcessor( - tenant_id, reranking_mode, - reranking_model, weights, False - ) + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) else: if index_type == "economy": @@ -357,30 +358,26 @@ class DatasetRetrieval: """Handle retrieval end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None + trace_manager: TraceQueueManager = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, - message_id=message_id, - documents=documents, - timer=timer + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer ) ) @@ -395,10 +392,10 @@ class DatasetRetrieval: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=app_id, created_by_role=user_from, - created_by=user_id + created_by=user_id, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -407,50 +404,50 @@ class DatasetRetrieval: def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k + ) if documents: all_documents.extend(documents) else: if top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) all_documents.extend(documents) - def to_dataset_retriever_tool(self, tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[list[DatasetRetrieverBaseTool]]: + def to_dataset_retriever_tool( + self, + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id @@ -464,18 +461,14 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: continue available_datasets.append(dataset) @@ -483,22 +476,18 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model or default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get score threshold score_threshold = None @@ -512,7 +501,7 @@ class DatasetRetrieval: score_threshold=score_threshold, hit_callbacks=[hit_callback], return_resource=return_resource, - retriever_from=invoke_from.to_source() + retriever_from=invoke_from.to_source(), ) tools.append(tool) @@ -525,8 +514,8 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), - reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), ) tools.append(tool) @@ -547,7 +536,7 @@ class DatasetRetrieval: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -590,8 +579,8 @@ class DatasetRetrieval: intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: @@ -606,21 +595,19 @@ class DatasetRetrieval: for document, score in zip(documents, similarities): # format document - document.metadata['score'] = score - documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True) + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) return documents[:top_k] if top_k else documents - def calculate_vector_score(self, all_documents: list[Document], - top_k: int, score_threshold: float) -> list[Document]: + def calculate_vector_score( + self, all_documents: list[Document], top_k: int, score_threshold: float + ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata["score"] >= score_threshold: filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) + filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True) return filter_documents[:top_k] if top_k else filter_documents - - - diff --git a/api/core/rag/retrieval/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py index 60770bd4c6..7fc78bce83 100644 --- a/api/core/rag/retrieval/output_parser/structured_chat.py +++ b/api/core/rag/retrieval/output_parser/structured_chat.py @@ -16,9 +16,7 @@ class StructuredChatOutputParser: if response["action"] == "Final Answer": return ReactFinish({"output": response["action_input"]}, text) else: - return ReactAction( - response["action"], response.get("action_input", {}), text - ) + return ReactAction(response["action"], response.get("action_input", {}), text) else: return ReactFinish({"output": text}, text) except Exception as e: diff --git a/api/core/rag/retrieval/retrieval_methods.py b/api/core/rag/retrieval/retrieval_methods.py index 12aa28a51c..eaa00bca88 100644 --- a/api/core/rag/retrieval/retrieval_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -2,9 +2,9 @@ from enum import Enum class RetrievalMethod(Enum): - SEMANTIC_SEARCH = 'semantic_search' - FULL_TEXT_SEARCH = 'full_text_search' - HYBRID_SEARCH = 'hybrid_search' + SEMANTIC_SEARCH = "semantic_search" + FULL_TEXT_SEARCH = "full_text_search" + HYBRID_SEARCH = "hybrid_search" @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 84e53952ac..06147fe7b5 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -6,14 +6,12 @@ from core.model_runtime.entities.message_entities import PromptMessageTool, Syst class FunctionCallMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, ) -> Union[str, None]: """Given input, decided what to do. Returns: @@ -26,22 +24,18 @@ class FunctionCallMultiDatasetRouter: try: prompt_messages = [ - SystemPromptMessage(content='You are a helpful AI assistant.'), - UserPromptMessage(content=query) + SystemPromptMessage(content="You are a helpful AI assistant."), + UserPromptMessage(content=query), ] result = model_instance.invoke_llm( prompt_messages=prompt_messages, tools=dataset_tools, stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) if result.message.tool_calls: # get retrieval model config return result.message.tool_calls[0].function.name return None except Exception as e: - return None \ No newline at end of file + return None diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 92f24277c1..a0494adc60 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -14,7 +14,7 @@ from core.workflow.nodes.llm.llm_node import LLMNode PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -Thought:""" +Thought:""" # noqa: E501 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. @@ -46,20 +46,18 @@ Action: "action": "Final Answer", "action_input": "Final response to human" }} -```""" +```""" # noqa: E501 class ReactMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - user_id: str, - tenant_id: str - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + user_id: str, + tenant_id: str, ) -> Union[str, None]: """Given input, decided what to do. Returns: @@ -71,23 +69,28 @@ class ReactMultiDatasetRouter: return dataset_tools[0].name try: - return self._react_invoke(query=query, model_config=model_config, - model_instance=model_instance, - tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) + return self._react_invoke( + query=query, + model_config=model_config, + model_instance=model_instance, + tools=dataset_tools, + user_id=user_id, + tenant_id=tenant_id, + ) except Exception as e: return None def _react_invoke( - self, - query: str, - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - tools: Sequence[PromptMessageTool], - user_id: str, - tenant_id: str, - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + tools: Sequence[PromptMessageTool], + user_id: str, + tenant_id: str, + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -103,18 +106,18 @@ class ReactMultiDatasetRouter: prefix=prefix, format_instructions=format_instructions, ) - stop = ['Observation:'] + stop = ["Observation:"] # handle invoke result prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=None, memory=None, - model_config=model_config + model_config=model_config, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -122,7 +125,7 @@ class ReactMultiDatasetRouter: prompt_messages=prompt_messages, stop=stop, user_id=user_id, - tenant_id=tenant_id + tenant_id=tenant_id, ) output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) @@ -130,17 +133,21 @@ class ReactMultiDatasetRouter: return react_decision.tool return None - def _invoke_llm(self, completion_param: dict, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: list[str], user_id: str, tenant_id: str - ) -> tuple[str, LLMUsage]: + def _invoke_llm( + self, + completion_param: dict, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str], + user_id: str, + tenant_id: str, + ) -> tuple[str, LLMUsage]: """ - Invoke large language model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: + Invoke large language model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: """ invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, @@ -151,9 +158,7 @@ class ReactMultiDatasetRouter: ) # handle invoke result - text, usage = self._handle_invoke_result( - invoke_result=invoke_result - ) + text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) @@ -168,7 +173,7 @@ class ReactMultiDatasetRouter: """ model = None prompt_messages = [] - full_text = '' + full_text = "" usage = None for result in invoke_result: text = result.delta.message.content @@ -189,40 +194,36 @@ class ReactMultiDatasetRouter: return full_text, usage def create_chat_prompt( - self, - query: str, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> list[ChatModelMessage]: tool_strings = [] for tool in tools: tool_strings.append( - f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") + f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query'," + f" 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}" + ) formatted_tools = "\n".join(tool_strings) unique_tool_names = {tool.name for tool in tools} tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) prompt_messages = [] - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=template - ) + system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=template) prompt_messages.append(system_prompt_messages) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=query - ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=query) prompt_messages.append(user_prompt_message) return prompt_messages def create_completion_prompt( - self, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> CompletionModelPromptTemplate: """Create prompt in the style of the zero shot agent. @@ -236,7 +237,7 @@ class ReactMultiDatasetRouter: suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. Question: {input} Thought: {agent_scratchpad} -""" +""" # noqa: E501 tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) tool_names = ", ".join([tool.name for tool in tools]) diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 0c1cb57c7f..53032b34d5 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -1,4 +1,5 @@ """Functionality for splitting text.""" + from __future__ import annotations from typing import Any, Optional @@ -18,31 +19,29 @@ from core.rag.splitter.text_splitter import ( class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ - This class is used to implement from_gpt2_encoder, to prevent using of tiktoken + This class is used to implement from_gpt2_encoder, to prevent using of tiktoken """ @classmethod def from_encoder( - cls: type[TS], - embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal[all], Set[str]] = set(), - disallowed_special: Union[Literal[all], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + embedding_model_instance: Optional[ModelInstance], + allowed_special: Union[Literal[all], Set[str]] = set(), + disallowed_special: Union[Literal[all], Collection[str]] = "all", + **kwargs: Any, ): def _token_encoder(text: str) -> int: if not text: return 0 if embedding_model_instance: - return embedding_model_instance.get_text_embedding_num_tokens( - texts=[text] - ) + return embedding_model_instance.get_text_embedding_num_tokens(texts=[text]) else: return GPT2Tokenizer.get_num_tokens(text) if issubclass(cls, TokenTextSplitter): extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2', + "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", "allowed_special": allowed_special, "disallowed_special": disallowed_special, } diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index f06f22a00e..7dd62f8de1 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -22,9 +22,7 @@ logger = logging.getLogger(__name__) TS = TypeVar("TS", bound="TextSplitter") -def _split_text_with_regex( - text: str, separator: str, keep_separator: bool -) -> list[str]: +def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]: # Now that we have the separator, split the text if separator: if keep_separator: @@ -37,19 +35,19 @@ def _split_text_with_regex( splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if (s != "" and s != '\n')] + return [s for s in splits if (s not in {"", "\n"})] class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( - self, - chunk_size: int = 4000, - chunk_overlap: int = 200, - length_function: Callable[[str], int] = len, - keep_separator: bool = False, - add_start_index: bool = False, + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, ) -> None: """Create a new TextSplitter. @@ -62,8 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): """ if chunk_overlap > chunk_size: raise ValueError( - f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " - f"({chunk_size}), should be smaller." + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap @@ -75,9 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents( - self, texts: list[str], metadatas: Optional[list[dict]] = None - ) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -119,14 +114,10 @@ class TextSplitter(BaseDocumentTransformer, ABC): index = 0 for d in splits: _len = lengths[index] - if ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - ): + if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( - f"Created a chunk of size {total}, " - f"which is longer than the specified {self._chunk_size}" + f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) @@ -136,13 +127,9 @@ class TextSplitter(BaseDocumentTransformer, ABC): # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - and total > 0 + total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): - total -= self._length_function(current_doc[0]) + ( - separator_len if len(current_doc) > 1 else 0 - ) + total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) @@ -159,28 +146,25 @@ class TextSplitter(BaseDocumentTransformer, ABC): from transformers import PreTrainedTokenizerBase if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - "Tokenizer received was not an instance of PreTrainedTokenizerBase" - ) + raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( - "Could not import transformers python package. " - "Please install it with `pip install transformers`." + "Could not import transformers python package. Please install it with `pip install transformers`." ) return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod def from_tiktoken_encoder( - cls: type[TS], - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> TS: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -217,15 +201,11 @@ class TextSplitter(BaseDocumentTransformer, ABC): return cls(length_function=_tiktoken_encoder, **kwargs) - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError @@ -267,9 +247,7 @@ class HeaderType(TypedDict): class MarkdownHeaderTextSplitter: """Splitting markdown files based on specified headers.""" - def __init__( - self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False - ): + def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): """Create a new MarkdownHeaderTextSplitter. Args: @@ -280,9 +258,7 @@ class MarkdownHeaderTextSplitter: self.return_each_line = return_each_line # Given the headers we want to split on, # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted( - headers_to_split_on, key=lambda split: len(split[0]), reverse=True - ) + self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: """Combine lines with common metadata into chunks @@ -292,10 +268,7 @@ class MarkdownHeaderTextSplitter: aggregated_chunks: list[LineType] = [] for line in lines: - if ( - aggregated_chunks - and aggregated_chunks[-1]["metadata"] == line["metadata"] - ): + if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: # If the last line in the aggregated list # has the same metadata as the current line, # append the current content to the last lines's content @@ -304,10 +277,7 @@ class MarkdownHeaderTextSplitter: # Otherwise, append the current line to the aggregated list aggregated_chunks.append(line) - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in aggregated_chunks - ] + return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] def split_text(self, text: str) -> list[Document]: """Split markdown file @@ -332,10 +302,9 @@ class MarkdownHeaderTextSplitter: for sep, name in self.headers_to_split_on: # Check if line starts with a header that we intend to split on if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) - or stripped_line[len(sep)] == " " + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " ): # Ensure we are tracking the header as metadata if name is not None: @@ -343,10 +312,7 @@ class MarkdownHeaderTextSplitter: current_header_level = sep.count("#") # Pop out headers of lower or same level from the stack - while ( - header_stack - and header_stack[-1]["level"] >= current_header_level - ): + while header_stack and header_stack[-1]["level"] >= current_header_level: # We have encountered a new header # at the same or higher level popped_header = header_stack.pop() @@ -359,7 +325,7 @@ class MarkdownHeaderTextSplitter: header: HeaderType = { "level": current_header_level, "name": name, - "data": stripped_line[len(sep):].strip(), + "data": stripped_line[len(sep) :].strip(), } header_stack.append(header) # Update initial_metadata with the current header @@ -392,9 +358,7 @@ class MarkdownHeaderTextSplitter: current_metadata = initial_metadata.copy() if current_content: - lines_with_metadata.append( - {"content": "\n".join(current_content), "metadata": current_metadata} - ) + lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) # lines_with_metadata has each line with associated header metadata # aggregate these into chunks based on common metadata @@ -402,8 +366,7 @@ class MarkdownHeaderTextSplitter: return self.aggregate_lines_to_chunks(lines_with_metadata) else: return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in lines_with_metadata + Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata ] @@ -436,12 +399,12 @@ class TokenTextSplitter(TextSplitter): """Splitting text to tokens using model tokenizer.""" def __init__( - self, - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + self, + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) @@ -488,10 +451,10 @@ class RecursiveCharacterTextSplitter(TextSplitter): """ def __init__( - self, - separators: Optional[list[str]] = None, - keep_separator: bool = True, - **kwargs: Any, + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) @@ -508,7 +471,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): break if re.search(_s, text): separator = _s - new_separators = separators[i + 1:] + new_separators = separators[i + 1 :] break splits = _split_text_with_regex(text, separator, self._keep_separator) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 3f0e7285af..ae84338cdc 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -11,23 +11,23 @@ from core.tools.tool.tool import ToolParameter class UserTool(BaseModel): author: str - name: str # identifier - label: I18nObject # label + name: str # identifier + label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None labels: list[str] = Field(default_factory=list) -UserToolProviderTypeLiteral = Optional[Literal[ - 'builtin', 'api', 'workflow' -]] + +UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] + class UserToolProvider(BaseModel): id: str author: str - name: str # identifier + name: str # identifier description: I18nObject icon: str | dict - label: I18nObject # label + label: I18nObject # label type: ToolProviderType masked_credentials: Optional[dict] = None original_credentials: Optional[dict] = None @@ -41,26 +41,27 @@ class UserToolProvider(BaseModel): # overwrite tool parameter types for temp fix tools = jsonable_encoder(self.tools) for tool in tools: - if tool.get('parameters'): - for parameter in tool.get('parameters'): - if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value: - parameter['type'] = 'files' + if tool.get("parameters"): + for parameter in tool.get("parameters"): + if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value: + parameter["type"] = "files" # ------------- return { - 'id': self.id, - 'author': self.author, - 'name': self.name, - 'description': self.description.to_dict(), - 'icon': self.icon, - 'label': self.label.to_dict(), - 'type': self.type.value, - 'team_credentials': self.masked_credentials, - 'is_team_authorization': self.is_team_authorization, - 'allow_delete': self.allow_delete, - 'tools': tools, - 'labels': self.labels, + "id": self.id, + "author": self.author, + "name": self.name, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "tools": tools, + "labels": self.labels, } + class UserToolProviderCredentials(BaseModel): - credentials: dict[str, ProviderConfig] \ No newline at end of file + credentials: dict[str, ProviderConfig] diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 55e31e8c35..37a926697b 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None pt_BR: Optional[str] = None en_US: str @@ -19,8 +20,4 @@ class I18nObject(BaseModel): self.pt_BR = self.en_US def to_dict(self) -> dict: - return { - 'zh_Hans': self.zh_Hans, - 'en_US': self.en_US, - 'pt_BR': self.pt_BR - } + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR} diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index d18d27fb02..0c15b2a371 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -7,8 +7,10 @@ from core.tools.entities.tool_entities import ToolParameter class ApiToolBundle(BaseModel): """ - This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc. + This class is used to store the schema information of an api based tool. + such as the url, the method, the parameters, etc. """ + # server_url server_url: str # method diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 2f25898f4e..b764ac62ec 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -9,27 +9,29 @@ from core.tools.entities.common_entities import I18nObject class ToolLabelEnum(Enum): - SEARCH = 'search' - IMAGE = 'image' - VIDEOS = 'videos' - WEATHER = 'weather' - FINANCE = 'finance' - DESIGN = 'design' - TRAVEL = 'travel' - SOCIAL = 'social' - NEWS = 'news' - MEDICAL = 'medical' - PRODUCTIVITY = 'productivity' - EDUCATION = 'education' - BUSINESS = 'business' - ENTERTAINMENT = 'entertainment' - UTILITIES = 'utilities' - OTHER = 'other' + SEARCH = "search" + IMAGE = "image" + VIDEOS = "videos" + WEATHER = "weather" + FINANCE = "finance" + DESIGN = "design" + TRAVEL = "travel" + SOCIAL = "social" + NEWS = "news" + MEDICAL = "medical" + PRODUCTIVITY = "productivity" + EDUCATION = "education" + BUSINESS = "business" + ENTERTAINMENT = "entertainment" + UTILITIES = "utilities" + OTHER = "other" + class ToolProviderType(str, Enum): """ - Enum class for tool provider + Enum class for tool provider """ + BUILT_IN = "builtin" WORKFLOW = "workflow" API = "api" @@ -37,7 +39,7 @@ class ToolProviderType(str, Enum): DATASET_RETRIEVAL = "dataset-retrieval" @classmethod - def value_of(cls, value: str) -> 'ToolProviderType': + def value_of(cls, value: str) -> "ToolProviderType": """ Get value of given mode. @@ -47,19 +49,21 @@ class ToolProviderType(str, Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. """ + OPENAPI = "openapi" SWAGGER = "swagger" OPENAI_PLUGIN = "openai_plugin" OPENAI_ACTIONS = "openai_actions" @classmethod - def value_of(cls, value: str) -> 'ApiProviderSchemaType': + def value_of(cls, value: str) -> "ApiProviderSchemaType": """ Get value of given mode. @@ -69,17 +73,19 @@ class ApiProviderSchemaType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. """ + NONE = "none" API_KEY = "api_key" @classmethod - def value_of(cls, value: str) -> 'ApiProviderAuthType': + def value_of(cls, value: str) -> "ApiProviderAuthType": """ Get value of given mode. @@ -89,7 +95,8 @@ class ApiProviderAuthType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ToolInvokeMessage(BaseModel): class TextMessage(BaseModel): @@ -107,7 +114,7 @@ class ToolInvokeMessage(BaseModel): stream: bool = Field(default=False, description="Whether the variable is streamed") @field_validator("variable_value", mode="before") - def transform_variable_value(cls, value, values) -> Any: + def transform_variable_value(self, value, values) -> Any: """ Only basic types and lists are allowed. """ @@ -122,11 +129,11 @@ class ToolInvokeMessage(BaseModel): return value @field_validator("variable_name", mode="before") - def transform_variable_name(cls, value) -> str: + def transform_variable_name(self, value) -> str: """ The variable name must be a string. """ - if value in ["json", "text", "files"]: + if value in {"json", "text", "files"}: raise ValueError(f"The variable name '{value}' is reserved.") return value @@ -146,7 +153,7 @@ class ToolInvokeMessage(BaseModel): """ message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None meta: dict[str, Any] | None = None - save_as: str = '' + save_as: str = "" @field_validator('message', mode='before') @classmethod @@ -166,17 +173,19 @@ class ToolInvokeMessage(BaseModel): } return v + class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - save_as: str = '' + save_as: str = "" file_var: Optional[dict[str, Any]] = None + class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - @field_validator('value', mode='before') + @field_validator("value", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -195,9 +204,9 @@ class ToolParameter(BaseModel): FILE = CommonParameterType.FILE.value class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + SCHEMA = "schema" # should be set while adding tool + FORM = "form" # should be set before invoking tool + LLM = "llm" # will be set by LLM name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") @@ -214,21 +223,28 @@ class ToolParameter(BaseModel): options: list[ToolParameterOption] = Field(default_factory=list) @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, - required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': + def get_simple_instance( + cls, + name: str, + llm_description: str, + type: ToolParameterType, + required: bool, + options: Optional[list[str]] = None, + ) -> "ToolParameter": """ - get a simple tool parameter + get a simple tool parameter - :param name: the name of the parameter - :param llm_description: the description presented to the LLM - :param type: the type of the parameter - :param required: if the parameter is required - :param options: the options of the parameter + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param type: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter """ # convert options to ToolParameterOption if options: - option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] + option_objs = [ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ] else: option_objs = [] return cls( @@ -243,18 +259,24 @@ class ToolParameter(BaseModel): options=option_objs, ) + class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", ) + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) + class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") + class ToolIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -262,22 +284,27 @@ class ToolIdentity(BaseModel): provider: str = Field(..., description="The provider of the tool") icon: Optional[str] = None + class ToolRuntimeVariableType(Enum): TEXT = "text" IMAGE = "image" + class ToolRuntimeVariable(BaseModel): type: ToolRuntimeVariableType = Field(..., description="The type of the variable") name: str = Field(..., description="The name of the variable") position: int = Field(..., description="The position of the variable") tool_name: str = Field(..., description="The name of the tool") + class ToolRuntimeTextVariable(ToolRuntimeVariable): value: str = Field(..., description="The value of the variable") + class ToolRuntimeImageVariable(ToolRuntimeVariable): value: str = Field(..., description="The path of the image") + class ToolRuntimeVariablePool(BaseModel): conversation_id: str = Field(..., description="The conversation id") user_id: str = Field(..., description="The user id") @@ -286,26 +313,26 @@ class ToolRuntimeVariablePool(BaseModel): pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables") def __init__(self, **data: Any): - pool = data.get('pool', []) + pool = data.get("pool", []) # convert pool into correct type for index, variable in enumerate(pool): - if variable['type'] == ToolRuntimeVariableType.TEXT.value: + if variable["type"] == ToolRuntimeVariableType.TEXT.value: pool[index] = ToolRuntimeTextVariable(**variable) - elif variable['type'] == ToolRuntimeVariableType.IMAGE.value: + elif variable["type"] == ToolRuntimeVariableType.IMAGE.value: pool[index] = ToolRuntimeImageVariable(**variable) super().__init__(**data) def dict(self) -> dict: return { - 'conversation_id': self.conversation_id, - 'user_id': self.user_id, - 'tenant_id': self.tenant_id, - 'pool': [variable.model_dump() for variable in self.pool], + "conversation_id": self.conversation_id, + "user_id": self.user_id, + "tenant_id": self.tenant_id, + "pool": [variable.model_dump() for variable in self.pool], } def set_text(self, tool_name: str, name: str, value: str) -> None: """ - set a text variable + set a text variable """ for variable in self.pool: if variable.name == name: @@ -326,10 +353,10 @@ class ToolRuntimeVariablePool(BaseModel): def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None: """ - set an image variable + set an image variable - :param tool_name: the name of the tool - :param value: the id of the file + :param tool_name: the name of the tool + :param value: the id of the file """ # check how many image variables are there image_variable_count = 0 @@ -357,22 +384,27 @@ class ToolRuntimeVariablePool(BaseModel): self.pool.append(variable) + class ModelToolPropertyKey(Enum): IMAGE_PARAMETER_NAME = "image_parameter_name" + class ModelToolConfiguration(BaseModel): """ Model tool configuration """ + type: str = Field(..., description="The type of the model tool") model: str = Field(..., description="The model") label: I18nObject = Field(..., description="The label of the model tool") properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + class ModelToolProviderConfiguration(BaseModel): """ Model tool provider configuration """ + provider: str = Field(..., description="The provider of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") label: I18nObject = Field(..., description="The label of the model tool") @@ -382,27 +414,30 @@ class WorkflowToolParameterConfiguration(BaseModel): """ Workflow tool configuration """ + name: str = Field(..., description="The name of the parameter") description: str = Field(..., description="The description of the parameter") form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + class ToolInvokeMeta(BaseModel): """ Tool invoke meta """ + time_cost: float = Field(..., description="The time cost of the tool invoke") error: Optional[str] = None tool_config: Optional[dict] = None @classmethod - def empty(cls) -> 'ToolInvokeMeta': + def empty(cls) -> "ToolInvokeMeta": """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> 'ToolInvokeMeta': + def error_instance(cls, error: str) -> "ToolInvokeMeta": """ Get an instance of ToolInvokeMeta with error """ @@ -410,22 +445,26 @@ class ToolInvokeMeta(BaseModel): def to_dict(self) -> dict: return { - 'time_cost': self.time_cost, - 'error': self.error, - 'tool_config': self.tool_config, + "time_cost": self.time_cost, + "error": self.error, + "tool_config": self.tool_config, } + class ToolLabel(BaseModel): """ Tool label """ + name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") icon: str = Field(..., description="The icon of the tool") + class ToolInvokeFrom(Enum): """ Enum class for tool invoke """ + WORKFLOW = "workflow" AGENT = "agent" diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py index d0be5e9355..f460df7e25 100644 --- a/api/core/tools/entities/values.py +++ b/api/core/tools/entities/values.py @@ -2,73 +2,109 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum ICONS = { - ToolLabelEnum.SEARCH: ''' + ToolLabelEnum.SEARCH: """ -''', - ToolLabelEnum.IMAGE: ''' +""", # noqa: E501 + ToolLabelEnum.IMAGE: """ -''', - ToolLabelEnum.VIDEOS: ''' +""", # noqa: E501 + ToolLabelEnum.VIDEOS: """ -''', - ToolLabelEnum.WEATHER: ''' +""", # noqa: E501 + ToolLabelEnum.WEATHER: """ -''', - ToolLabelEnum.FINANCE: ''' +""", # noqa: E501 + ToolLabelEnum.FINANCE: """ -''', - ToolLabelEnum.DESIGN: ''' +""", # noqa: E501 + ToolLabelEnum.DESIGN: """ -''', - ToolLabelEnum.TRAVEL: ''' +""", # noqa: E501 + ToolLabelEnum.TRAVEL: """ -''', - ToolLabelEnum.SOCIAL: ''' +""", # noqa: E501 + ToolLabelEnum.SOCIAL: """ -''', - ToolLabelEnum.NEWS: ''' +""", # noqa: E501 + ToolLabelEnum.NEWS: """ -''', - ToolLabelEnum.MEDICAL: ''' +""", # noqa: E501 + ToolLabelEnum.MEDICAL: """ -''', - ToolLabelEnum.PRODUCTIVITY: ''' +""", # noqa: E501 + ToolLabelEnum.PRODUCTIVITY: """ -''', - ToolLabelEnum.EDUCATION: ''' +""", # noqa: E501 + ToolLabelEnum.EDUCATION: """ -''', - ToolLabelEnum.BUSINESS: ''' +""", # noqa: E501 + ToolLabelEnum.BUSINESS: """ -''', - ToolLabelEnum.ENTERTAINMENT: ''' +""", # noqa: E501 + ToolLabelEnum.ENTERTAINMENT: """ -''', - ToolLabelEnum.UTILITIES: ''' +""", # noqa: E501 + ToolLabelEnum.UTILITIES: """ -''', - ToolLabelEnum.OTHER: ''' +""", # noqa: E501 + ToolLabelEnum.OTHER: """ -''' +""", # noqa: E501 } default_tool_label_dict = { - ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]), - ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]), - ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]), - ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]), - ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]), - ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]), - ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]), - ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]), - ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]), - ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]), - ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]), - ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]), - ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]), - ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]), - ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]), - ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]), + ToolLabelEnum.SEARCH: ToolLabel( + name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] + ), + ToolLabelEnum.IMAGE: ToolLabel( + name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] + ), + ToolLabelEnum.VIDEOS: ToolLabel( + name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] + ), + ToolLabelEnum.WEATHER: ToolLabel( + name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] + ), + ToolLabelEnum.FINANCE: ToolLabel( + name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] + ), + ToolLabelEnum.DESIGN: ToolLabel( + name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] + ), + ToolLabelEnum.TRAVEL: ToolLabel( + name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] + ), + ToolLabelEnum.SOCIAL: ToolLabel( + name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] + ), + ToolLabelEnum.NEWS: ToolLabel( + name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] + ), + ToolLabelEnum.MEDICAL: ToolLabel( + name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] + ), + ToolLabelEnum.PRODUCTIVITY: ToolLabel( + name="productivity", + label=I18nObject(en_US="Productivity", zh_Hans="生产力"), + icon=ICONS[ToolLabelEnum.PRODUCTIVITY], + ), + ToolLabelEnum.EDUCATION: ToolLabel( + name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] + ), + ToolLabelEnum.BUSINESS: ToolLabel( + name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] + ), + ToolLabelEnum.ENTERTAINMENT: ToolLabel( + name="entertainment", + label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), + icon=ICONS[ToolLabelEnum.ENTERTAINMENT], + ), + ToolLabelEnum.UTILITIES: ToolLabel( + name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] + ), + ToolLabelEnum.OTHER: ToolLabel( + name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] + ), } default_tool_labels = [v for k, v in default_tool_label_dict.items()] diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 9fd8322db1..6febf137b0 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -4,23 +4,30 @@ from core.tools.entities.tool_entities import ToolInvokeMeta class ToolProviderNotFoundError(ValueError): pass + class ToolNotFoundError(ValueError): pass + class ToolParameterValidationError(ValueError): pass + class ToolProviderCredentialValidationError(ValueError): pass + class ToolNotSupportedError(ValueError): pass + class ToolInvokeError(ValueError): pass + class ToolApiSchemaError(ValueError): pass + class ToolEngineInvokeError(Exception): - meta: ToolInvokeMeta \ No newline at end of file + meta: ToolInvokeMeta diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index 880ddc4955..307cc0a0d9 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,4 +1,3 @@ - from pydantic import Field from core.entities.provider_entities import ProviderConfig @@ -20,86 +19,70 @@ class ApiToolProviderController(ToolProviderController): tools: list[ApiTool] = Field(default_factory=list) @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': + def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": credentials_schema = { - 'auth_type': ProviderConfig( - name='auth_type', + "auth_type": ProviderConfig( + name="auth_type", required=True, type=ProviderConfig.Type.SELECT, options=[ - ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='无')), - ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) + ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")), ], - default='none', - help=I18nObject( - en_US='The auth type of the api provider', - zh_Hans='api provider 的认证类型' - ) + default="none", + help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), ) } if auth_type == ApiProviderAuthType.API_KEY: credentials_schema = { **credentials_schema, - 'api_key_header': ProviderConfig( - name='api_key_header', + "api_key_header": ProviderConfig( + name="api_key_header", required=False, - default='api_key', + default="api_key", type=ProviderConfig.Type.TEXT_INPUT, - help=I18nObject( - en_US='The header name of the api key', - zh_Hans='携带 api key 的 header 名称' - ) + help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), ), - 'api_key_value': ProviderConfig( - name='api_key_value', + "api_key_value": ProviderConfig( + name="api_key_value", required=True, type=ProviderConfig.Type.SECRET_INPUT, - help=I18nObject( - en_US='The api key', - zh_Hans='api key的值' - ) + help=I18nObject(en_US="The api key", zh_Hans="api key的值"), ), - 'api_key_header_prefix': ProviderConfig( - name='api_key_header_prefix', + "api_key_header_prefix": ProviderConfig( + name="api_key_header_prefix", required=False, - default='basic', + default="basic", type=ProviderConfig.Type.SELECT, - help=I18nObject( - en_US='The prefix of the api key header', - zh_Hans='api key header 的前缀' - ), + help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"), options=[ - ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), - ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), - ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) - ] - ) + ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")), + ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")), + ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")), + ], + ), } elif auth_type == ApiProviderAuthType.NONE: pass else: - raise ValueError(f'invalid auth type {auth_type}') + raise ValueError(f"invalid auth type {auth_type}") - user_name = db_provider.user.name if db_provider.user_id else '' + user_name = db_provider.user.name if db_provider.user_id else "" - return ApiToolProviderController(**{ - 'identity': { - 'author': user_name, - 'name': db_provider.name, - 'label': { - 'en_US': db_provider.name, - 'zh_Hans': db_provider.name + return ApiToolProviderController( + **{ + "identity": { + "author": user_name, + "name": db_provider.name, + "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description - }, - 'icon': db_provider.icon, + "credentials_schema": credentials_schema, + "provider_id": db_provider.id or "", + "tenant_id": db_provider.tenant_id or "", }, - 'credentials_schema': credentials_schema, - 'provider_id': db_provider.id or '', - 'tenant_id': db_provider.tenant_id or '', - }) + ) @property def provider_type(self) -> ToolProviderType: @@ -107,39 +90,35 @@ class ApiToolProviderController(ToolProviderController): def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ - parse tool bundle to tool + parse tool bundle to tool - :param tool_bundle: the tool bundle - :return: the tool + :param tool_bundle: the tool bundle + :return: the tool """ - return ApiTool(**{ - 'api_bundle': tool_bundle, - 'identity' : { - 'author': tool_bundle.author, - 'name': tool_bundle.operation_id, - 'label': { - 'en_US': tool_bundle.operation_id, - 'zh_Hans': tool_bundle.operation_id + return ApiTool( + **{ + "api_bundle": tool_bundle, + "identity": { + "author": tool_bundle.author, + "name": tool_bundle.operation_id, + "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, + "icon": self.identity.icon, + "provider": self.provider_id, }, - 'icon': self.identity.icon, - 'provider': self.provider_id, - }, - 'description': { - 'human': { - 'en_US': tool_bundle.summary or '', - 'zh_Hans': tool_bundle.summary or '' + "description": { + "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, + "llm": tool_bundle.summary or "", }, - 'llm': tool_bundle.summary or '' - }, - 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], - }) + "parameters": tool_bundle.parameters or [], + } + ) def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: """ - load bundled tools + load bundled tools - :param tools: the bundled tools - :return: the tools + :param tools: the bundled tools + :return: the tools """ self.tools = [self._parse_tool_bundle(tool) for tool in tools] @@ -147,22 +126,23 @@ class ApiToolProviderController(ToolProviderController): def get_tools(self, tenant_id: str) -> list[ApiTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - + tools: list[ApiTool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == self.identity.name - ).all() + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name) + .all() + ) if db_providers and len(db_providers) != 0: for db_provider in db_providers: @@ -170,16 +150,16 @@ class ApiToolProviderController(ToolProviderController): assistant_tool = self._parse_tool_bundle(tool) assistant_tool.is_team_authorization = True tools.append(assistant_tool) - + self.tools = tools return tools - + def get_tool(self, tool_name: str) -> ApiTool: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: self.get_tools(self.tenant_id) @@ -188,4 +168,4 @@ class ApiToolProviderController(ToolProviderController): if tool.identity.name == tool_name: return tool - raise ValueError(f'tool {tool_name} not found') \ No newline at end of file + raise ValueError(f"tool {tool_name} not found") diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py new file mode 100644 index 0000000000..09f328cd1f --- /dev/null +++ b/api/core/tools/provider/app_tool_provider.py @@ -0,0 +1,103 @@ +import logging +from typing import Any + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.tool import Tool +from extensions.ext_database import db +from models.model import App, AppModelConfig +from models.tools import PublishedAppTool + +logger = logging.getLogger(__name__) + + +class AppToolProviderEntity(ToolProviderController): + @property + def provider_type(self) -> ToolProviderType: + return ToolProviderType.APP + + def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: + pass + + def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None: + pass + + def get_tools(self, user_id: str) -> list[Tool]: + db_tools: list[PublishedAppTool] = ( + db.session.query(PublishedAppTool) + .filter( + PublishedAppTool.user_id == user_id, + ) + .all() + ) + + if not db_tools or len(db_tools) == 0: + return [] + + tools: list[Tool] = [] + + for db_tool in db_tools: + tool = { + "identity": { + "author": db_tool.author, + "name": db_tool.tool_name, + "label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name}, + "icon": "", + }, + "description": { + "human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans}, + "llm": db_tool.llm_description, + }, + "parameters": [], + } + # get app from db + app: App = db_tool.app + + if not app: + logger.error(f"app {db_tool.app_id} not found") + continue + + app_model_config: AppModelConfig = app.app_model_config + user_input_form_list = app_model_config.user_input_form_list + for input_form in user_input_form_list: + # get type + form_type = input_form.keys()[0] + default = input_form[form_type]["default"] + required = input_form[form_type]["required"] + label = input_form[form_type]["label"] + variable_name = input_form[form_type]["variable_name"] + options = input_form[form_type].get("options", []) + if form_type in {"paragraph", "text-input"}: + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.STRING, + required=required, + default=default, + ) + ) + elif form_type == "select": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.SELECT, + required=required, + default=default, + options=[ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options + ], + ) + ) + + tools.append(Tool(**tool)) + return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 062668fc5b..5c10f72fda 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -10,7 +10,7 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), "..")) def name_func(provider: UserToolProvider) -> str: return provider.name diff --git a/api/core/tools/provider/builtin/aippt/aippt.py b/api/core/tools/provider/builtin/aippt/aippt.py index 25133c51df..e0cbbd2992 100644 --- a/api/core/tools/provider/builtin/aippt/aippt.py +++ b/api/core/tools/provider/builtin/aippt/aippt.py @@ -6,6 +6,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class AIPPTProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__') + AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__") except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 8d6883a3b1..dd9371f70d 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -20,16 +20,16 @@ class AIPPTGenerateTool(BuiltinTool): A tool for generating a ppt """ - _api_base_url = URL('https://co.aippt.cn/api') + _api_base_url = URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock:Optional[Lock] = None + _api_token_cache_lock: Optional[Lock] = None _style_cache = {} - _style_cache_lock:Optional[Lock] = None + _style_cache_lock: Optional[Lock] = None _task = {} _task_type_map = { - 'auto': 1, - 'markdown': 7, + "auto": 1, + "markdown": 7, } def __init__(self, **kwargs: Any): @@ -46,67 +46,58 @@ class AIPPTGenerateTool(BuiltinTool): tool_parameters (dict[str, Any]): The parameters for the tool Returns: - ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, + which can be a single message or a list of messages. """ - title = tool_parameters.get('title', '') + title = tool_parameters.get("title", "") if not title: - return self.create_text_message('Please provide a title for the ppt') - - model = tool_parameters.get('model', 'aippt') + return self.create_text_message("Please provide a title for the ppt") + + model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message('Please provide a model for the ppt') - - outline = tool_parameters.get('outline', '') + return self.create_text_message("Please provide a model for the ppt") + + outline = tool_parameters.get("outline", "") # create task task_id = self._create_task( - type=self._task_type_map['auto' if not outline else 'markdown'], + type=self._task_type_map["auto" if not outline else "markdown"], title=title, content=outline, - user_id=user_id + user_id=user_id, ) # get suit - color = tool_parameters.get('color') - style = tool_parameters.get('style') + color = tool_parameters.get("color") + style = tool_parameters.get("style") - if color == '__default__': - color_id = '' + if color == "__default__": + color_id = "" else: - color_id = int(color.split('-')[1]) + color_id = int(color.split("-")[1]) - if style == '__default__': - style_id = '' + if style == "__default__": + style_id = "" else: - style_id = int(style.split('-')[1]) + style_id = int(style.split("-")[1]) suit_id = self._get_suit(style_id=style_id, colour_id=color_id) # generate outline if not outline: - self._generate_outline( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_outline(task_id=task_id, model=model, user_id=user_id) # generate content - self._generate_content( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_content(task_id=task_id, model=model, user_id=user_id) # generate ppt - _, ppt_url = self._generate_ppt( - task_id=task_id, - suit_id=suit_id, - user_id=user_id - ) + _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message('''the ppt has been created successfully,''' - f'''the ppt url is {ppt_url}''' - '''please give the ppt url to user and direct user to download it.''') + return self.create_text_message( + """the ppt has been created successfully,""" + f"""the ppt url is {ppt_url}""" + """please give the ppt url to user and direct user to download it.""" + ) def _create_task(self, type: int, title: str, content: str, user_id: str) -> str: """ @@ -119,129 +110,121 @@ class AIPPTGenerateTool(BuiltinTool): :return: the task ID """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'), + str(self._api_base_url / "ai" / "chat" / "v2" / "task"), headers=headers, - files={ - 'type': ('', str(type)), - 'title': ('', title), - 'content': ('', content) - } + files={"type": ("", str(type)), "title": ("", title), "content": ("", content)}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to create task: {response.get("msg")}') - return response.get('data', {}).get('id') - + return response.get("data", {}).get("id") + def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "outline" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "outline" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - outline = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + outline = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - outline += data.get('content', '') + outline += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate outline: {data}') - + elif event in {"error", "filter"}: + raise Exception(f"Failed to generate outline: {data}") + return outline - + def _generate_content(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'content' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "content" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "content" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - if model == 'aippt': - content = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + if model == "aippt": + content = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - content += data.get('content', '') + content += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate content: {data}') - + elif event in {"error", "filter"}: + raise Exception(f"Failed to generate content: {data}") + return content - elif model == 'wenxin': + elif model == "wenxin": response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate content: {response.get("msg")}') - - return response.get('data', '') - - return '' + + return response.get("data", "") + + return "" def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: """ @@ -252,83 +235,73 @@ class AIPPTGenerateTool(BuiltinTool): :return: the cover url of the ppt and the ppt url """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'design' / 'v2' / 'save'), + str(self._api_base_url / "design" / "v2" / "save"), headers=headers, - data={ - 'task_id': task_id, - 'template_id': suit_id - } + data={"task_id": task_id, "template_id": suit_id}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - id = response.get('data', {}).get('id') - cover_url = response.get('data', {}).get('cover_url') + + id = response.get("data", {}).get("id") + cover_url = response.get("data", {}).get("cover_url") response = post( - str(self._api_base_url / 'download' / 'export' / 'file'), + str(self._api_base_url / "download" / "export" / "file"), headers=headers, - data={ - 'id': id, - 'format': 'ppt', - 'files_to_zip': False, - 'edit': True - } + data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - export_code = response.get('data') + + export_code = response.get("data") if not export_code: - raise Exception('Failed to generate ppt, the export code is empty') - + raise Exception("Failed to generate ppt, the export code is empty") + current_iteration = 0 while current_iteration < 50: # get ppt url response = post( - str(self._api_base_url / 'download' / 'export' / 'file' / 'result'), + str(self._api_base_url / "download" / "export" / "file" / "result"), headers=headers, - data={ - 'task_key': export_code - } + data={"task_key": export_code}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - if response.get('msg') == '导出中': + + if response.get("msg") == "导出中": current_iteration += 1 sleep(2) continue - - ppt_url = response.get('data', []) + + ppt_url = response.get("data", []) if len(ppt_url) == 0: - raise Exception('Failed to generate ppt, the ppt url is empty') - + raise Exception("Failed to generate ppt, the ppt url is empty") + return cover_url, ppt_url[0] - - raise Exception('Failed to generate ppt, the export is timeout') - + + raise Exception("Failed to generate ppt, the export is timeout") + @classmethod def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: """ @@ -337,53 +310,43 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: the API token """ - access_key = credentials['aippt_access_key'] - secret_key = credentials['aippt_secret_key'] + access_key = credentials["aippt_access_key"] + secret_key = credentials["aippt_secret_key"] - cache_key = f'{access_key}#@#{user_id}' + cache_key = f"{access_key}#@#{user_id}" with cls._api_token_cache_lock: # clear expired tokens now = time() for key in list(cls._api_token_cache.keys()): - if cls._api_token_cache[key]['expire'] < now: + if cls._api_token_cache[key]["expire"] < now: del cls._api_token_cache[key] if cache_key in cls._api_token_cache: - return cls._api_token_cache[cache_key]['token'] - + return cls._api_token_cache[cache_key]["token"] + # get token headers = { - 'x-api-key': access_key, - 'x-timestamp': str(int(now)), - 'x-signature': cls._calculate_sign(access_key, secret_key, int(now)) + "x-api-key": access_key, + "x-timestamp": str(int(now)), + "x-signature": cls._calculate_sign(access_key, secret_key, int(now)), } - param = { - 'uid': user_id, - 'channel': '' - } + param = {"uid": user_id, "channel": ""} - response = get( - str(cls._api_base_url / 'grant' / 'token'), - params=param, - headers=headers - ) + response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') + raise Exception(f"Failed to connect to aippt: {response.text}") response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - token = response.get('data', {}).get('token') - expire = response.get('data', {}).get('time_expire') + + token = response.get("data", {}).get("token") + expire = response.get("data", {}).get("time_expire") with cls._api_token_cache_lock: - cls._api_token_cache[cache_key] = { - 'token': token, - 'expire': now + expire - } + cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire} return token @@ -391,11 +354,9 @@ class AIPPTGenerateTool(BuiltinTool): def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode('utf-8'), - msg=f'GET@/api/grant/token/@{timestamp}'.encode(), - digestmod=sha1 + key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1 ).digest() - ).decode('utf-8') + ).decode("utf-8") @classmethod def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: @@ -408,47 +369,46 @@ class AIPPTGenerateTool(BuiltinTool): # clear expired styles now = time() for key in list(cls._style_cache.keys()): - if cls._style_cache[key]['expire'] < now: + if cls._style_cache[key]["expire"] < now: del cls._style_cache[key] key = f'{credentials["aippt_access_key"]}#@#{user_id}' if key in cls._style_cache: - return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] + return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"] headers = { - 'x-channel': '', - 'x-api-key': credentials['aippt_access_key'], - 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) + "x-channel": "", + "x-api-key": credentials["aippt_access_key"], + "x-token": cls._get_api_token(credentials=credentials, user_id=user_id), } - response = get( - str(cls._api_base_url / 'template_component' / 'suit' / 'select'), - headers=headers - ) + response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - colors = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('name'), - 'en_name': item.get('en_name', item.get('name')), - } for item in response.get('data', {}).get('colour') or []] - styles = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('title'), - } for item in response.get('data', {}).get('suit_style') or []] + + colors = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("name"), + "en_name": item.get("en_name", item.get("name")), + } + for item in response.get("data", {}).get("colour") or [] + ] + styles = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("title"), + } + for item in response.get("data", {}).get("suit_style") or [] + ] with cls._style_cache_lock: - cls._style_cache[key] = { - 'colors': colors, - 'styles': styles, - 'expire': now + 60 * 60 - } + cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60} return colors, styles @@ -459,44 +419,39 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'): - raise Exception('Please provide aippt credentials') + if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"): + raise Exception("Please provide aippt credentials") return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) - + def _get_suit(self, style_id: int, colour_id: int) -> int: """ Get suit """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__') + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"), } response = get( - str(self._api_base_url / 'template_component' / 'suit' / 'search'), + str(self._api_base_url / "template_component" / "suit" / "search"), headers=headers, - params={ - 'style_id': style_id, - 'colour_id': colour_id, - 'page': 1, - 'page_size': 1 - } + params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - if len(response.get('data', {}).get('list') or []) > 0: - return response.get('data', {}).get('list')[0].get('id') - - raise Exception('Failed to get suit, the suit does not exist, please check the style and color') - + + if len(response.get("data", {}).get("list") or []) > 0: + return response.get("data", {}).get("list")[0].get("id") + + raise Exception("Failed to get suit, the suit does not exist, please check the style and color") + def get_runtime_parameters(self) -> list[ToolParameter]: """ Get runtime parameters @@ -504,43 +459,40 @@ class AIPPTGenerateTool(BuiltinTool): Override this method to add runtime parameters to the tool. """ try: - colors, styles = self.get_styles(user_id='__dify_system__') + colors, styles = self.get_styles(user_id="__dify_system__") except Exception as e: - colors, styles = [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ], [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ] + colors, styles = ( + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + ) return [ ToolParameter( - name='color', - label=I18nObject(zh_Hans='颜色', en_US='Color'), - human_description=I18nObject(zh_Hans='颜色', en_US='Color'), + name="color", + label=I18nObject(zh_Hans="颜色", en_US="Color"), + human_description=I18nObject(zh_Hans="颜色", en_US="Color"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=colors[0]['id'], + default=colors[0]["id"], options=[ ToolParameterOption( - value=color['id'], - label=I18nObject(zh_Hans=color['name'], en_US=color['en_name']) - ) for color in colors - ] + value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"]) + ) + for color in colors + ], ), ToolParameter( - name='style', - label=I18nObject(zh_Hans='风格', en_US='Style'), - human_description=I18nObject(zh_Hans='风格', en_US='Style'), + name="style", + label=I18nObject(zh_Hans="风格", en_US="Style"), + human_description=I18nObject(zh_Hans="风格", en_US="Style"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=styles[0]['id'], + default=styles[0]["id"], options=[ - ToolParameterOption( - value=style['id'], - label=I18nObject(zh_Hans=style['name'], en_US=style['name']) - ) for style in styles - ] + ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"])) + for style in styles + ], ), - ] \ No newline at end of file + ] diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.py b/api/core/tools/provider/builtin/alphavantage/alphavantage.py index 01f2acfb5b..a84630e5aa 100644 --- a/api/core/tools/provider/builtin/alphavantage/alphavantage.py +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.py @@ -13,7 +13,7 @@ class AlphaVantageProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "code": "AAPL", # Apple Inc. }, diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py index 5c379b746d..d06611acd0 100644 --- a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py @@ -9,17 +9,16 @@ ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query" class QueryStockTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - stock_code = tool_parameters.get('code', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + stock_code = tool_parameters.get("code", "") if not stock_code: - return self.create_text_message('Please tell me your stock code') + return self.create_text_message("Please tell me your stock code") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Alpha Vantage API key is required.") params = { @@ -27,7 +26,7 @@ class QueryStockTool(BuiltinTool): "symbol": stock_code, "outputsize": "compact", "datatype": "json", - "apikey": self.runtime.credentials['api_key'] + "apikey": self.runtime.credentials["api_key"], } response = requests.get(url=ALPHAVANTAGE_API_URL, params=params) response.raise_for_status() @@ -35,15 +34,15 @@ class QueryStockTool(BuiltinTool): return self.create_json_message(result) def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]: - result = response.get('Time Series (Daily)', {}) + result = response.get("Time Series (Daily)", {}) if not result: return {} stock_result = {} for k, v in result.items(): stock_result[k] = {} - stock_result[k]['open'] = v.get('1. open') - stock_result[k]['high'] = v.get('2. high') - stock_result[k]['low'] = v.get('3. low') - stock_result[k]['close'] = v.get('4. close') - stock_result[k]['volume'] = v.get('5. volume') + stock_result[k]["open"] = v.get("1. open") + stock_result[k]["high"] = v.get("2. high") + stock_result[k]["low"] = v.get("3. low") + stock_result[k]["close"] = v.get("4. close") + stock_result[k]["volume"] = v.get("5. volume") return stock_result diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 707fc69be3..ebb2d1a8c4 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -11,11 +11,10 @@ class ArxivProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index ce28373880..2d65ba2d6f 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -8,6 +8,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool logger = logging.getLogger(__name__) + + class ArxivAPIWrapper(BaseModel): """Wrapper around ArxivAPI. @@ -86,11 +88,13 @@ class ArxivAPIWrapper(BaseModel): class ArxivSearchInput(BaseModel): query: str = Field(..., description="Search query.") - + + class ArxivSearchTool(BuiltinTool): """ A tool for searching articles on Arxiv. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Arxiv search tool with the given user ID and tool parameters. @@ -100,15 +104,16 @@ class ArxivSearchTool(BuiltinTool): tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter. Returns: - ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, + which can be a single message or a list of messages. """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + arxiv = ArxivAPIWrapper() - + response = arxiv.run(query) - + return self.create_text_message(self.summary(user_id=user_id, content=response)) diff --git a/api/core/tools/provider/builtin/aws/aws.py b/api/core/tools/provider/builtin/aws/aws.py index 13ede96015..f81b5dbd27 100644 --- a/api/core/tools/provider/builtin/aws/aws.py +++ b/api/core/tools/provider/builtin/aws/aws.py @@ -11,15 +11,14 @@ class SageMakerProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "sagemaker_endpoint" : "", + "sagemaker_endpoint": "", "query": "misaka mikoto", - "candidate_texts" : "hello$$$hello world", - "topk" : 5, - "aws_region" : "" + "candidate_texts": "hello$$$hello world", + "topk": 5, + "aws_region": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 06fcf8a453..a04f5c0fe9 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -12,6 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class GuardrailParameters(BaseModel): guardrail_id: str = Field(..., description="The identifier of the guardrail") guardrail_version: str = Field(..., description="The version of the guardrail") @@ -19,35 +20,35 @@ class GuardrailParameters(BaseModel): text: str = Field(..., description="The text to apply the guardrail to") aws_region: str = Field(..., description="AWS region for the Bedrock client") + class ApplyGuardrailTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the ApplyGuardrail tool """ try: # Validate and parse input parameters params = GuardrailParameters(**tool_parameters) - + # Initialize AWS client - bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region) + bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region) # Apply guardrail response = bedrock_client.apply_guardrail( guardrailIdentifier=params.guardrail_id, guardrailVersion=params.guardrail_version, source=params.source, - content=[{"text": {"text": params.text}}] + content=[{"text": {"text": params.text}}], ) - + logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") # Check for empty response if not response: return self.create_text_message(text="Received empty response from AWS Bedrock.") - + # Process the result action = response.get("action", "No action specified") outputs = response.get("outputs", []) @@ -58,9 +59,12 @@ class ApplyGuardrailTool(BuiltinTool): formatted_assessments = [] for assessment in assessments: for policy_type, policy_data in assessment.items(): - if isinstance(policy_data, dict) and 'topics' in policy_data: - for topic in policy_data['topics']: - formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}") + if isinstance(policy_data, dict) and "topics" in policy_data: + for topic in policy_data["topics"]: + formatted_assessments.append( + f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}," + f" Action: {topic['action']}" + ) else: formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}") @@ -68,19 +72,19 @@ class ApplyGuardrailTool(BuiltinTool): result += f"Output: {output}\n " if formatted_assessments: result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n " -# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" + # result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" return self.create_text_message(text=result) except BotoCoreError as e: - error_message = f'AWS service error: {str(e)}' + error_message = f"AWS service error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except json.JSONDecodeError as e: - error_message = f'JSON parsing error: {str(e)}' + error_message = f"JSON parsing error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except Exception as e: - error_message = f'An unexpected error occurred: {str(e)}' + error_message = f"An unexpected error occurred: {str(e)}" logger.error(error_message, exc_info=True) - return self.create_text_message(text=error_message) \ No newline at end of file + return self.create_text_message(text=error_message) diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 005ba3deb5..48755753ac 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): - msg = { - "src_content":text_content, - "src_lang": src_lang, - "dest_lang":dest_lang, + msg = { + "src_content": text_content, + "src_lang": src_lang, + "dest_lang": dest_lang, "dictionary_id": dictionary_name, - "request_type" : request_type, - "model_id" : model_id + "request_type": request_type, + "model_id": model_id, } - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("unicode_escape") return response_str - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") line = 1 - text_content = tool_parameters.get('text_content', '') + text_content = tool_parameters.get("text_content", "") if not text_content: - return self.create_text_message('Please input text_content') - + return self.create_text_message("Please input text_content") + line = 2 - src_lang = tool_parameters.get('src_lang', '') + src_lang = tool_parameters.get("src_lang", "") if not src_lang: - return self.create_text_message('Please input src_lang') - + return self.create_text_message("Please input src_lang") + line = 3 - dest_lang = tool_parameters.get('dest_lang', '') + dest_lang = tool_parameters.get("dest_lang", "") if not dest_lang: - return self.create_text_message('Please input dest_lang') - + return self.create_text_message("Please input dest_lang") + line = 4 - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - + return self.create_text_message("Please input lambda_name") + line = 5 - request_type = tool_parameters.get('request_type', '') + request_type = tool_parameters.get("request_type", "") if not request_type: - return self.create_text_message('Please input request_type') - + return self.create_text_message("Please input request_type") + line = 6 - model_id = tool_parameters.get('model_id', '') + model_id = tool_parameters.get("model_id", "") if not model_id: - return self.create_text_message('Please input model_id') + return self.create_text_message("Please input model_id") line = 7 - dictionary_name = tool_parameters.get('dictionary_name', '') + dictionary_name = tool_parameters.get("dictionary_name", "") if not dictionary_name: - return self.create_text_message('Please input dictionary_name') - - result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name) + return self.create_text_message("Please input dictionary_name") + + result = self._invoke_lambda( + text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name + ) return self.create_text_message(text=result) except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml index a35c9f49fb..3bb133c7ec 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml @@ -10,7 +10,7 @@ description: human: en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag - pt_BR: A util tools for LLM translation, specfic Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag + pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag llm: A util tools for translation. parameters: - name: text_content diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py index bb7f6840b8..f43f3b6fe0 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -18,54 +18,53 @@ class LambdaYamlToJsonTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str: - msg = { - "body": yaml_content - } + msg = {"body": yaml_content} logger.info(json.dumps(msg)) - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("utf-8") resp_json = json.loads(response_str) logger.info(resp_json) - if resp_json['statusCode'] != 200: + if resp_json["statusCode"] != 200: raise Exception(f"Invalid status code: {response_str}") - return resp_json['body'] + return resp_json["body"] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region + aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") - yaml_content = tool_parameters.get('yaml_content', '') + yaml_content = tool_parameters.get("yaml_content", "") if not yaml_content: - return self.create_text_message('Please input yaml_content') + return self.create_text_message("Please input yaml_content") - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}') - + return self.create_text_message("Please input lambda_name") + logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}") + result = self._invoke_lambda(lambda_name, yaml_content) logger.debug(result) - + return self.create_text_message(result) except Exception as e: - return self.create_text_message(f'Exception: {str(e)}') + return self.create_text_message(f"Exception: {str(e)}") - console_handler.flush() \ No newline at end of file + console_handler.flush() diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index 2b3a3eaad6..bffcd058b5 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -1,4 +1,5 @@ import json +import operator from typing import Any, Union import boto3 @@ -9,37 +10,33 @@ from core.tools.tool.builtin_tool import BuiltinTool class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - topk:int = None + sagemaker_endpoint: str = None + topk: int = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -47,25 +44,25 @@ class SageMakerReRankTool(BuiltinTool): line = 1 if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 if not self.topk: - self.topk = tool_parameters.get('topk', 5) + self.topk = tool_parameters.get("topk", 5) line = 3 - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + line = 4 - candidate_texts = tool_parameters.get('candidate_texts') + candidate_texts = tool_parameters.get("candidate_texts") if not candidate_texts: - return self.create_text_message('Please input candidate_texts') - + return self.create_text_message("Please input candidate_texts") + line = 5 candidate_docs = json.loads(candidate_texts) - docs = [ item.get('content') for item in candidate_docs ] + docs = [item.get("content") for item in candidate_docs] line = 6 scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint) @@ -75,10 +72,10 @@ class SageMakerReRankTool(BuiltinTool): candidate_docs[idx]["score"] = scores[idx] line = 8 - sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) line = 9 - return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ] - + return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py index a100e62230..bceeaab745 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -14,82 +14,88 @@ class TTSModelType(Enum): CloneVoice_CrossLingual = "CloneVoice_CrossLingual" InstructVoice = "InstructVoice" + class SageMakerTTSTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - s3_client : Any = None - comprehend_client : Any = None + sagemaker_endpoint: str = None + s3_client: Any = None + comprehend_client: Any = None - def _detect_lang_code(self, content:str, map_dict:dict=None): - map_dict = { - "zh" : "<|zh|>", - "en" : "<|en|>", - "ja" : "<|jp|>", - "zh-TW" : "<|yue|>", - "ko" : "<|ko|>" - } + def _detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} response = self.comprehend_client.detect_dominant_language(Text=content) - language_code = response['Languages'][0]['LanguageCode'] - return map_dict.get(language_code, '<|zh|>') + language_code = response["Languages"][0]["LanguageCode"] + return map_dict.get(language_code, "<|zh|>") - def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): if model_type == TTSModelType.PresetVoice.value and model_role: - return { "tts_text" : content_text, "role" : model_role } + return {"tts_text": content_text, "role": model_role} if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: - return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } - if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: lang_tag = self._detect_lang_code(content_text) - return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } - if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: - return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} raise RuntimeError(f"Invalid params for {model_type}") - def _invoke_sagemaker(self, payload:dict, endpoint:str): + def _invoke_sagemaker(self, payload: dict, endpoint: str): response_model = self.sagemaker_client.invoke_endpoint( EndpointName=endpoint, Body=json.dumps(payload), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) return json_obj - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - self.comprehend_client = boto3.client('comprehend') + self.comprehend_client = boto3.client("comprehend") if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") - tts_text = tool_parameters.get('tts_text') - tts_infer_type = tool_parameters.get('tts_infer_type') + tts_text = tool_parameters.get("tts_text") + tts_infer_type = tool_parameters.get("tts_infer_type") - voice = tool_parameters.get('voice') - mock_voice_audio = tool_parameters.get('mock_voice_audio') - mock_voice_text = tool_parameters.get('mock_voice_text') - voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt') - payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt) + voice = tool_parameters.get("voice") + mock_voice_audio = tool_parameters.get("mock_voice_audio") + mock_voice_text = tool_parameters.get("mock_voice_text") + voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt") + payload = self._build_tts_payload( + tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt + ) result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) - return self.create_text_message(text=result['s3_presign_url']) - + return self.create_text_message(text=result["s3_presign_url"]) + except Exception as e: - return self.create_text_message(f'Exception {str(e)}') \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py index 2981a54d3c..1fab0d03a2 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py @@ -13,12 +13,8 @@ class AzureDALLEProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "square", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 2ffdd38b72..cfa3cfb092 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -9,47 +9,48 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ client = AzureOpenAI( - api_version=self.runtime.credentials['azure_openai_api_version'], - azure_endpoint=self.runtime.credentials['azure_openai_base_url'], - api_key=self.runtime.credentials['azure_openai_api_key'], + api_version=self.runtime.credentials["azure_openai_api_version"], + azure_endpoint=self.runtime.credentials["azure_openai_base_url"], + api_key=self.runtime.credentials["azure_openai_api_key"], ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in {"standard", "hd"}: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in {"natural", "vivid"}: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} # call openapi dalle3 - model = self.runtime.credentials['azure_openai_api_model_name'] + model = self.runtime.credentials["azure_openai_api_model_name"] response = client.images.generate( prompt=prompt, model=model, @@ -58,21 +59,25 @@ class DallE3Tool(BuiltinTool): extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value)) - result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}')) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ) + ) + result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index f85a5ed472..8bed2c556c 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -8,142 +8,135 @@ from core.tools.tool.builtin_tool import BuiltinTool class BingSearchTool(BuiltinTool): - url: str = 'https://api.bing.microsoft.com/v7.0/search' + url: str = "https://api.bing.microsoft.com/v7.0/search" - def _invoke_bing(self, - user_id: str, - server_url: str, - subscription_key: str, query: str, limit: int, - result_type: str, market: str, lang: str, - filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke_bing( + self, + user_id: str, + server_url: str, + subscription_key: str, + query: str, + limit: int, + result_type: str, + market: str, + lang: str, + filters: list[str], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke bing search + invoke bing search """ - market_code = f'{lang}-{market}' - accept_language = f'{lang},{market_code};q=0.9' - headers = { - 'Ocp-Apim-Subscription-Key': subscription_key, - 'Accept-Language': accept_language - } + market_code = f"{lang}-{market}" + accept_language = f"{lang},{market_code};q=0.9" + headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language} query = quote(query) server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') - - response = response.json() - search_results = response['webPages']['value'][:limit] if 'webPages' in response else [] - related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else [] - entities = response['entities']['value'] if 'entities' in response else [] - news = response['news']['value'] if 'news' in response else [] - computation = response['computation']['value'] if 'computation' in response else None + raise Exception(f"Error {response.status_code}: {response.text}") - if result_type == 'link': + response = response.json() + search_results = response["webPages"]["value"][:limit] if "webPages" in response else [] + related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else [] + entities = response["entities"]["value"] if "entities" in response else [] + news = response["news"]["value"] if "news" in response else [] + computation = response["computation"]["value"] if "computation" in response else None + + if result_type == "link": results = [] if search_results: for result in search_results: url = f': {result["url"]}' if "url" in result else "" - results.append(self.create_text_message( - text=f'{result["name"]}{url}' - )) - + results.append(self.create_text_message(text=f'{result["name"]}{url}')) if entities: for entity in entities: url = f': {entity["url"]}' if "url" in entity else "" - results.append(self.create_text_message( - text=f'{entity.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}')) if news: for news_item in news: url = f': {news_item["url"]}' if "url" in news_item else "" - results.append(self.create_text_message( - text=f'{news_item.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}')) if related_searches: for related in related_searches: url = f': {related["displayText"]}' if "displayText" in related else "" - results.append(self.create_text_message( - text=f'{related.get("displayText", "")}{url}' - )) - + results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}')) + return results else: # construct text - text = '' + text = "" if search_results: for i, result in enumerate(search_results): - text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n' + text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n' - if computation and 'expression' in computation and 'value' in computation: - text += '\nComputation:\n' + if computation and "expression" in computation and "value" in computation: + text += "\nComputation:\n" text += f'{computation["expression"]} = {computation["value"]}\n' if entities: - text += '\nEntities:\n' + text += "\nEntities:\n" for entity in entities: url = f'- {entity["url"]}' if "url" in entity else "" text += f'{entity.get("name", "")}{url}\n' if news: - text += '\nNews:\n' + text += "\nNews:\n" for news_item in news: url = f'- {news_item["url"]}' if "url" in news_item else "" text += f'{news_item.get("name", "")}{url}\n' if related_searches: - text += '\n\nRelated Searches:\n' + text += "\n\nRelated Searches:\n" for related in related_searches: url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else "" text += f'{related.get("displayText", "")}{url}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: - key = credentials.get('subscription_key') + key = credentials.get("subscription_key") if not key: - raise Exception('subscription_key is required') - - server_url = credentials.get('server_url') + raise Exception("subscription_key is required") + + server_url = credentials.get("server_url") if not server_url: server_url = self.url - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' + raise Exception("query is required") - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if credentials.get('allow_entities', False): - filter.append('Entities') + if credentials.get("allow_entities", False): + filter.append("Entities") - if credentials.get('allow_computation', False): - filter.append('Computation') + if credentials.get("allow_computation", False): + filter.append("Computation") - if credentials.get('allow_news', False): - filter.append('News') + if credentials.get("allow_news", False): + filter.append("News") - if credentials.get('allow_related_searches', False): - filter.append('RelatedSearches') + if credentials.get("allow_related_searches", False): + filter.append("RelatedSearches") - if credentials.get('allow_web_pages', False): - filter.append('WebPages') + if credentials.get("allow_web_pages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + self._invoke_bing( - user_id='test', + user_id="test", server_url=server_url, subscription_key=key, query=query, @@ -151,50 +144,51 @@ class BingSearchTool(BuiltinTool): result_type=result_type, market=market, lang=lang, - filters=filter + filters=filter, ) - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - key = self.runtime.credentials.get('subscription_key', None) + key = self.runtime.credentials.get("subscription_key", None) if not key: - raise Exception('subscription_key is required') - - server_url = self.runtime.credentials.get('server_url', None) + raise Exception("subscription_key is required") + + server_url = self.runtime.credentials.get("server_url", None) if not server_url: server_url = self.url - - query = tool_parameters.get('query') + + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' - - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + raise Exception("query is required") + + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if tool_parameters.get('enable_computation', False): - filter.append('Computation') - if tool_parameters.get('enable_entities', False): - filter.append('Entities') - if tool_parameters.get('enable_news', False): - filter.append('News') - if tool_parameters.get('enable_related_search', False): - filter.append('RelatedSearches') - if tool_parameters.get('enable_webpages', False): - filter.append('WebPages') + if tool_parameters.get("enable_computation", False): + filter.append("Computation") + if tool_parameters.get("enable_entities", False): + filter.append("Entities") + if tool_parameters.get("enable_news", False): + filter.append("News") + if tool_parameters.get("enable_related_search", False): + filter.append("RelatedSearches") + if tool_parameters.get("enable_webpages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + return self._invoke_bing( user_id=user_id, server_url=server_url, @@ -204,5 +198,5 @@ class BingSearchTool(BuiltinTool): result_type=result_type, market=market, lang=lang, - filters=filter - ) \ No newline at end of file + filters=filter, + ) diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py index e5eada80ee..c24ee67334 100644 --- a/api/core/tools/provider/builtin/brave/brave.py +++ b/api/core/tools/provider/builtin/brave/brave.py @@ -13,11 +13,10 @@ class BraveProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py index 21cbf2c7da..94a4d92844 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.py +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py @@ -37,7 +37,7 @@ class BraveSearchWrapper(BaseModel): for item in web_search_results ] return json.dumps(final_results) - + def _search_request(self, query: str) -> list[dict]: headers = { "X-Subscription-Token": self.api_key, @@ -55,6 +55,7 @@ class BraveSearchWrapper(BaseModel): return response.json().get("web", {}).get("results", []) + class BraveSearch(BaseModel): """Tool that queries the BraveSearch.""" @@ -67,9 +68,7 @@ class BraveSearch(BaseModel): search_wrapper: BraveSearchWrapper @classmethod - def from_api_key( - cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any - ) -> "BraveSearch": + def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch": """Create a tool from an api key. Args: @@ -90,6 +89,7 @@ class BraveSearch(BaseModel): """Use the tool.""" return self.search_wrapper.run(query) + class BraveSearchTool(BuiltinTool): """ Tool for performing a search using Brave search engine. @@ -106,12 +106,12 @@ class BraveSearchTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - query = tool_parameters.get('query', '') - count = tool_parameters.get('count', 3) - api_key = self.runtime.credentials['brave_search_api_key'] + query = tool_parameters.get("query", "") + count = tool_parameters.get("count", 3) + api_key = self.runtime.credentials["brave_search_api_key"] if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) @@ -121,4 +121,3 @@ class BraveSearchTool(BuiltinTool): return self.create_text_message(f"No results found for '{query}' in Tavily") else: return self.create_text_message(text=results) - diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index 0865bc700a..8a24d33428 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -7,16 +7,34 @@ from core.tools.provider.builtin.chart.tools.line import LinearChartTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController # use a business theme -plt.style.use('seaborn-v0_8-darkgrid') -plt.rcParams['axes.unicode_minus'] = False +plt.style.use("seaborn-v0_8-darkgrid") +plt.rcParams["axes.unicode_minus"] = False + def init_fonts(): fonts = findSystemFonts() popular_unicode_fonts = [ - 'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif', - 'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans', - 'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono' + "Arial Unicode MS", + "DejaVu Sans", + "DejaVu Sans Mono", + "DejaVu Serif", + "FreeMono", + "FreeSans", + "FreeSerif", + "Liberation Mono", + "Liberation Sans", + "Liberation Serif", + "Noto Mono", + "Noto Sans", + "Noto Serif", + "Open Sans", + "Roboto", + "Source Code Pro", + "Source Sans Pro", + "Source Serif Pro", + "Ubuntu", + "Ubuntu Mono", ] supported_fonts = [] @@ -25,21 +43,23 @@ def init_fonts(): try: font = TTFont(font_path) # get family name - family_name = font['name'].getName(1, 3, 1).toUnicode() + family_name = font["name"].getName(1, 3, 1).toUnicode() if family_name in popular_unicode_fonts: supported_fonts.append(family_name) except: pass - plt.rcParams['font.family'] = 'sans-serif' + plt.rcParams["font.family"] = "sans-serif" # sort by order of popular_unicode_fonts for font in popular_unicode_fonts: if font in supported_fonts: - plt.rcParams['font.sans-serif'] = font + plt.rcParams["font.sans-serif"] = font break - + + init_fonts() + class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: @@ -48,11 +68,10 @@ class ChartProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "data": "1,3,5,7,9,2,4,6,8,10", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index 749ec761c6..3a47c0cfc0 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -8,12 +8,13 @@ from core.tools.tool.builtin_tool import BuiltinTool class BarChartTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -21,29 +22,27 @@ class BarChartTool(BuiltinTool): else: data = [float(i) for i in data] - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') + axis = [label[:10] + "..." if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") ax.bar(axis, data) else: ax.bar(range(len(data)), data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the bar chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the bar chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py index 608bd6623c..39e8caac7e 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.py +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -8,18 +8,19 @@ from core.tools.tool.builtin_tool import BuiltinTool class LinearChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None @@ -32,20 +33,18 @@ class LinearChartTool(BuiltinTool): flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') + axis = [label[:10] + "..." if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") ax.plot(axis, data) else: ax.plot(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the linear chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the linear chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py index 4c551229e9..2c3b8a733e 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.py +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -8,15 +8,16 @@ from core.tools.tool.builtin_tool import BuiltinTool class PieChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') - categories = tool_parameters.get('categories') or None + return self.create_text_message("Please input data") + data = data.split(";") + categories = tool_parameters.get("categories") or None # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -27,7 +28,7 @@ class PieChartTool(BuiltinTool): flg, ax = plt.subplots() if categories: - categories = categories.split(';') + categories = categories.split(";") if len(categories) != len(data): categories = None @@ -37,12 +38,11 @@ class PieChartTool(BuiltinTool): ax.pie(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the pie chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) - ] \ No newline at end of file + self.create_text_message("the pie chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), + ] diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index 37645bf0d0..632c9fc7f1 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -8,15 +8,15 @@ from core.tools.tool.builtin_tool import BuiltinTool class SimpleCode(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ - invoke simple code + invoke simple code """ - language = tool_parameters.get('language', CodeLanguage.PYTHON3) - code = tool_parameters.get('code', '') + language = tool_parameters.get("language", CodeLanguage.PYTHON3) + code = tool_parameters.get("code", "") - if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: - raise ValueError(f'Only python3 and javascript are supported, not {language}') - - result = CodeExecutor.execute_code(language, '', code) + if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}: + raise ValueError(f"Only python3 and javascript are supported, not {language}") - return self.create_text_message(result) \ No newline at end of file + result = CodeExecutor.execute_code(language, "", code) + + return self.create_text_message(result) diff --git a/api/core/tools/provider/builtin/cogview/cogview.py b/api/core/tools/provider/builtin/cogview/cogview.py index 801817ec06..6941ce8649 100644 --- a/api/core/tools/provider/builtin/cogview/cogview.py +++ b/api/core/tools/provider/builtin/cogview/cogview.py @@ -1,4 +1,5 @@ -""" Provide the input parameters type for the cogview provider class """ +"""Provide the input parameters type for the cogview provider class""" + from typing import Any from core.tools.errors import ToolProviderCredentialValidationError @@ -7,7 +8,8 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class COGVIEWProvider(BuiltinToolProviderController): - """ cogview provider """ + """cogview provider""" + def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CogView3Tool().fork_tool_runtime( @@ -15,13 +17,12 @@ class COGVIEWProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。", "size": "square", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) from e - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 89ffcf3347..9039708588 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class CogView3Tool(BuiltinTool): - """ CogView3 Tool """ + """CogView3 Tool""" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke CogView3 tool """ client = ZhipuAI( - base_url=self.runtime.credentials['zhipuai_base_url'], - api_key=self.runtime.credentials['zhipuai_api_key'], + base_url=self.runtime.credentials["zhipuai_base_url"], + api_key=self.runtime.credentials["zhipuai_api_key"], ) size_mapping = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = size_mapping[tool_parameters.get('size', 'square')] + size = size_mapping[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in {"standard", "hd"}: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in {"natural", "vivid"}: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} response = client.images.generations( prompt=prompt, model="cogview-3", @@ -52,18 +51,22 @@ class CogView3Tool(BuiltinTool): extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/crossref/crossref.py b/api/core/tools/provider/builtin/crossref/crossref.py index 404e483e0d..8ba3c1b48a 100644 --- a/api/core/tools/provider/builtin/crossref/crossref.py +++ b/api/core/tools/provider/builtin/crossref/crossref.py @@ -11,9 +11,9 @@ class CrossRefProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "doi": '10.1007/s00894-022-05373-8', + "doi": "10.1007/s00894-022-05373-8", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.py b/api/core/tools/provider/builtin/crossref/tools/query_doi.py index a43c0989e4..746139dd69 100644 --- a/api/core/tools/provider/builtin/crossref/tools/query_doi.py +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.py @@ -11,15 +11,18 @@ class CrossRefQueryDOITool(BuiltinTool): """ Tool for querying the metadata of a publication using its DOI. """ - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - doi = tool_parameters.get('doi') + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + doi = tool_parameters.get("doi") if not doi: - raise ToolParameterValidationError('doi is required.') + raise ToolParameterValidationError("doi is required.") # doc: https://github.com/CrossRef/rest-api-doc url = f"https://api.crossref.org/works/{doi}" response = requests.get(url) response.raise_for_status() response = response.json() - message = response.get('message', {}) + message = response.get("message", {}) return self.create_json_message(message) diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.py b/api/core/tools/provider/builtin/crossref/tools/query_title.py index 946aa6dc94..e245238183 100644 --- a/api/core/tools/provider/builtin/crossref/tools/query_title.py +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.py @@ -12,16 +12,16 @@ def convert_time_str_to_seconds(time_str: str) -> int: Convert a time string to seconds. example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430 """ - time_str = time_str.lower().strip().replace(' ', '') + time_str = time_str.lower().strip().replace(" ", "") seconds = 0 - if 'h' in time_str: - hours, time_str = time_str.split('h') + if "h" in time_str: + hours, time_str = time_str.split("h") seconds += int(hours) * 3600 - if 'm' in time_str: - minutes, time_str = time_str.split('m') + if "m" in time_str: + minutes, time_str = time_str.split("m") seconds += int(minutes) * 60 - if 's' in time_str: - seconds += int(time_str.replace('s', '')) + if "s" in time_str: + seconds += int(time_str.replace("s", "")) return seconds @@ -30,6 +30,7 @@ class CrossRefQueryTitleAPI: Tool for querying the metadata of a publication using its title. Crossref API doc: https://github.com/CrossRef/rest-api-doc """ + query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}" rate_limit: int = 50 rate_interval: float = 1 @@ -38,7 +39,15 @@ class CrossRefQueryTitleAPI: def __init__(self, mailto: str): self.mailto = mailto - def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + def _query( + self, + query: str, + rows: int = 5, + offset: int = 0, + sort: str = "relevance", + order: str = "desc", + fuzzy_query: bool = False, + ) -> list[dict]: """ Query the metadata of a publication using its title. :param query: the title of the publication @@ -47,33 +56,37 @@ class CrossRefQueryTitleAPI: :param order: the sort order :param fuzzy_query: whether to return all items that match the query """ - url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto) + url = self.query_url_template.format( + query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto + ) response = requests.get(url) response.raise_for_status() - rate_limit = int(response.headers['x-ratelimit-limit']) + rate_limit = int(response.headers["x-ratelimit-limit"]) # convert time string to seconds - rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval']) + rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"]) self.rate_limit = rate_limit self.rate_interval = rate_interval response = response.json() - if response['status'] != 'ok': + if response["status"] != "ok": return [] - message = response['message'] + message = response["message"] if fuzzy_query: # fuzzy query return all items - return message['items'] + return message["items"] else: - for paper in message['items']: - title = paper['title'][0] + for paper in message["items"]: + title = paper["title"][0] if title.lower() != query.lower(): continue return [paper] return [] - def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + def query( + self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False + ) -> list[dict]: """ Query the metadata of a publication using its title. :param query: the title of the publication @@ -89,7 +102,14 @@ class CrossRefQueryTitleAPI: results = [] for i in range(query_times): - result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query) + result = self._query( + query, + rows=self.rate_limit, + offset=i * self.rate_limit, + sort=sort, + order=order, + fuzzy_query=fuzzy_query, + ) if fuzzy_query: results.extend(result) else: @@ -107,13 +127,16 @@ class CrossRefQueryTitleTool(BuiltinTool): """ Tool for querying the metadata of a publication using its title. """ - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters.get('query') - fuzzy_query = tool_parameters.get('fuzzy_query', False) - rows = tool_parameters.get('rows', 3) - sort = tool_parameters.get('sort', 'relevance') - order = tool_parameters.get('order', 'desc') - mailto = self.runtime.credentials['mailto'] + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query") + fuzzy_query = tool_parameters.get("fuzzy_query", False) + rows = tool_parameters.get("rows", 3) + sort = tool_parameters.get("sort", "relevance") + order = tool_parameters.get("order", "desc") + mailto = self.runtime.credentials["mailto"] result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query) diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py index 1c8019364d..5bd16e49e8 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.py +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -13,13 +13,8 @@ class DALLEProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "small", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index 9e9f32d429..fbd7397292 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -9,59 +9,58 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE2Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organization_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'small': '256x256', - 'medium': '512x512', - 'large': '1024x1024', + "small": "256x256", + "medium": "512x512", + "large": "1024x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') - + return self.create_text_message("Please input prompt") + # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'large')] + size = SIZE_MAPPING[tool_parameters.get("size", "large")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # call openapi dalle2 - response = client.images.generate( - prompt=prompt, - model='dall-e-2', - size=size, - n=n, - response_format='b64_json' - ) + response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json") result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ) + ) return result diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 4f5033dd7f..a8c647d71e 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -10,69 +10,64 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organization_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in {"standard", "hd"}: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in {"natural", "vivid"}: + return self.create_text_message("Invalid style") # call openapi dalle3 response = client.images.generate( - prompt=prompt, - model='dall-e-3', - size=size, - n=n, - style=style, - quality=quality, - response_format='b64_json' + prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json" ) result = [] for image in response.data: mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) - blob_message = self.create_blob_message(blob=blob_image, - meta={'mime_type': mime_type}, - save_as=self.VARIABLE_KEY.IMAGE.value) + blob_message = self.create_blob_message( + blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value + ) result.append(blob_message) return result @@ -86,7 +81,7 @@ class DallE3Tool(BuiltinTool): :return: A tuple containing the MIME type and the decoded image bytes """ if DallE3Tool._is_plain_base64(base64_image): - return 'image/png', base64.b64decode(base64_image) + return "image/png", base64.b64decode(base64_image) else: return DallE3Tool._extract_mime_and_data(base64_image) @@ -98,7 +93,7 @@ class DallE3Tool(BuiltinTool): :param encoded_str: Base64 encoded image string :return: True if the string is plain base64, False otherwise """ - return not encoded_str.startswith('data:image') + return not encoded_str.startswith("data:image") @staticmethod def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: @@ -108,13 +103,13 @@ class DallE3Tool(BuiltinTool): :param encoded_str: Base64 encoded image string with MIME type prefix :return: A tuple containing the MIME type and the decoded image bytes """ - mime_type = encoded_str.split(';')[0].split(':')[1] - image_data_base64 = encoded_str.split(',')[1] + mime_type = encoded_str.split(";")[0].split(":")[1] + image_data_base64 = encoded_str.split(",")[1] decoded_data = base64.b64decode(image_data_base64) return mime_type, decoded_data @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.py b/api/core/tools/provider/builtin/devdocs/devdocs.py index 95d7939d0d..446c1e5489 100644 --- a/api/core/tools/provider/builtin/devdocs/devdocs.py +++ b/api/core/tools/provider/builtin/devdocs/devdocs.py @@ -11,7 +11,7 @@ class DevDocsProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "doc": "python~3.12", "topic": "library/code", @@ -19,4 +19,3 @@ class DevDocsProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py index 1a244c5db3..57cf6d7a30 100644 --- a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py +++ b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py @@ -13,7 +13,9 @@ class SearchDevDocsInput(BaseModel): class SearchDevDocsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the DevDocs search tool with the given user ID and tool parameters. @@ -22,15 +24,16 @@ class SearchDevDocsTool(BuiltinTool): tool_parameters (dict[str, Any]): The parameters for the tool, including 'doc' and 'topic'. Returns: - ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, + which can be a single message or a list of messages. """ - doc = tool_parameters.get('doc', '') - topic = tool_parameters.get('topic', '') + doc = tool_parameters.get("doc", "") + topic = tool_parameters.get("topic", "") if not doc: - return self.create_text_message('Please provide the documentation name.') + return self.create_text_message("Please provide the documentation name.") if not topic: - return self.create_text_message('Please provide the topic path.') + return self.create_text_message("Please provide the topic path.") url = f"https://documents.devdocs.io/{doc}/{topic}.html" response = requests.get(url) @@ -39,4 +42,6 @@ class SearchDevDocsTool(BuiltinTool): content = response.text return self.create_text_message(self.summary(user_id=user_id, content=content)) else: - return self.create_text_message(f"Failed to retrieve the documentation. Status code: {response.status_code}") \ No newline at end of file + return self.create_text_message( + f"Failed to retrieve the documentation. Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/did/did.py b/api/core/tools/provider/builtin/did/did.py index b4bf172131..5af78794f6 100644 --- a/api/core/tools/provider/builtin/did/did.py +++ b/api/core/tools/provider/builtin/did/did.py @@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the D-ID talks tool - TalksTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png", "text_input": "Hello, welcome to use D-ID tool in Dify", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py index 964e82b729..c68878630d 100644 --- a/api/core/tools/provider/builtin/did/did_appx.py +++ b/api/core/tools/provider/builtin/did/did_appx.py @@ -12,14 +12,14 @@ logger = logging.getLogger(__name__) class DIDApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.d-id.com' + self.base_url = base_url or "https://api.d-id.com" if not self.api_key: - raise ValueError('API key is required') + raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'} + headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( @@ -44,44 +44,44 @@ class DIDApp: return None def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/talks' + endpoint = f"{self.base_url}/talks" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval) + return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval) return id def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/animations' + endpoint = f"{self.base_url}/animations" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval) + return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval) return id def check_did_status(self, target: str, id: str): - endpoint = f'{self.base_url}/{target}/{id}' + endpoint = f"{self.base_url}/{target}/{id}" headers = self._prepare_headers() - response = self._request('GET', endpoint, headers=headers) + response = self._request("GET", endpoint, headers=headers) if response is None: - raise HTTPError(f'Failed to check status for talks {id} after multiple retries') + raise HTTPError(f"Failed to check status for talks {id} after multiple retries") return response def _monitor_job_status(self, target: str, id: str, poll_interval: int): while True: status = self.check_did_status(target=target, id=id) - if status['status'] == 'done': + if status["status"] == "done": return status - elif status['status'] == 'error' or status['status'] == 'rejected': - raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}') + elif status["status"] == "error" or status["status"] == "rejected": + raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}') time.sleep(poll_interval) diff --git a/api/core/tools/provider/builtin/did/tools/animations.py b/api/core/tools/provider/builtin/did/tools/animations.py index e1d9de603f..bc9d17e40d 100644 --- a/api/core/tools/provider/builtin/did/tools/animations.py +++ b/api/core/tools/provider/builtin/did/tools/animations.py @@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None config = { - 'stitch': tool_parameters.get('stitch', True), - 'mute': tool_parameters.get('mute'), - 'result_format': tool_parameters.get('result_format') or 'mp4', + "stitch": tool_parameters.get("stitch", True), + "mute": tool_parameters.get("mute"), + "result_format": tool_parameters.get("result_format") or "mp4", } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if config.get('logo_url'): - if not config.get('logo_x'): - raise ValueError('Logo X position is required when logo URL is provided') - if not config.get('logo_y'): - raise ValueError('Logo Y position is required when logo URL is provided') + if config.get("logo_url"): + if not config.get("logo_x"): + raise ValueError("Logo X position is required when logo URL is provided") + if not config.get("logo_y"): + raise ValueError("Logo Y position is required when logo URL is provided") animations_result = app.animations(params=options, wait=True) @@ -44,6 +44,6 @@ class AnimationsTool(BuiltinTool): animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4) if not animations_result: - return self.create_text_message('D-ID animations request failed.') + return self.create_text_message("D-ID animations request failed.") return self.create_text_message(animations_result) diff --git a/api/core/tools/provider/builtin/did/tools/talks.py b/api/core/tools/provider/builtin/did/tools/talks.py index 06b2c4cb2f..d6f0c7ff17 100644 --- a/api/core/tools/provider/builtin/did/tools/talks.py +++ b/api/core/tools/provider/builtin/did/tools/talks.py @@ -10,49 +10,49 @@ class TalksTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None script = { - 'type': tool_parameters.get('script_type') or 'text', - 'input': tool_parameters.get('text_input'), - 'audio_url': tool_parameters.get('audio_url'), - 'reduce_noise': tool_parameters.get('audio_reduce_noise', False), + "type": tool_parameters.get("script_type") or "text", + "input": tool_parameters.get("text_input"), + "audio_url": tool_parameters.get("audio_url"), + "reduce_noise": tool_parameters.get("audio_reduce_noise", False), } - script = {k: v for k, v in script.items() if v is not None and v != ''} + script = {k: v for k, v in script.items() if v is not None and v != ""} config = { - 'stitch': tool_parameters.get('stitch', True), - 'sharpen': tool_parameters.get('sharpen'), - 'fluent': tool_parameters.get('fluent'), - 'result_format': tool_parameters.get('result_format') or 'mp4', - 'pad_audio': tool_parameters.get('pad_audio'), - 'driver_expressions': driver_expressions, + "stitch": tool_parameters.get("stitch", True), + "sharpen": tool_parameters.get("sharpen"), + "fluent": tool_parameters.get("fluent"), + "result_format": tool_parameters.get("result_format") or "mp4", + "pad_audio": tool_parameters.get("pad_audio"), + "driver_expressions": driver_expressions, } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'script': script, - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "script": script, + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if script.get('type') == 'audio': - script.pop('input', None) - if not script.get('audio_url'): - raise ValueError('Audio URL is required for audio script type') + if script.get("type") == "audio": + script.pop("input", None) + if not script.get("audio_url"): + raise ValueError("Audio URL is required for audio script type") - if script.get('type') == 'text': - script.pop('audio_url', None) - script.pop('reduce_noise', None) - if not script.get('input'): - raise ValueError('Text input is required for text script type') + if script.get("type") == "text": + script.pop("audio_url", None) + script.pop("reduce_noise", None) + if not script.get("input"): + raise ValueError("Text input is required for text script type") talks_result = app.talks(params=options, wait=True) @@ -60,6 +60,6 @@ class TalksTool(BuiltinTool): talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4) if not talks_result: - return self.create_text_message('D-ID talks request failed.') + return self.create_text_message("D-ID talks request failed.") return self.create_text_message(talks_result) diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py index c247c3bd6b..f33ad5be59 100644 --- a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py +++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py @@ -13,38 +13,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class DingTalkGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - Dingtalk custom group robot API docs: - https://open.dingtalk.com/document/orgapp/custom-robot-access + invoke tools + Dingtalk custom group robot API docs: + https://open.dingtalk.com/document/orgapp/custom-robot-access """ - content = tool_parameters.get('content') + content = tool_parameters.get("content") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - access_token = tool_parameters.get('access_token') + access_token = tool_parameters.get("access_token") if not access_token: - return self.create_text_message('Invalid parameter access_token. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter access_token. " + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - sign_secret = tool_parameters.get('sign_secret') + sign_secret = tool_parameters.get("sign_secret") if not sign_secret: - return self.create_text_message('Invalid parameter sign_secret. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter sign_secret. " + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - msgtype = 'text' - api_url = 'https://oapi.dingtalk.com/robot/send' + msgtype = "text" + api_url = "https://oapi.dingtalk.com/robot/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'access_token': access_token, + "access_token": access_token, } self._apply_security_mechanism(params, sign_secret) @@ -53,7 +58,7 @@ class DingTalkGroupBotTool(BuiltinTool): "msgtype": msgtype, "text": { "content": content, - } + }, } try: @@ -62,7 +67,8 @@ class DingTalkGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) @@ -70,14 +76,14 @@ class DingTalkGroupBotTool(BuiltinTool): def _apply_security_mechanism(params: dict[str, Any], sign_secret: str): try: timestamp = str(round(time.time() * 1000)) - secret_enc = sign_secret.encode('utf-8') - string_to_sign = f'{timestamp}\n{sign_secret}' - string_to_sign_enc = string_to_sign.encode('utf-8') + secret_enc = sign_secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{sign_secret}" + string_to_sign_enc = string_to_sign.encode("utf-8") hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest() sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - params['timestamp'] = timestamp - params['sign'] = sign + params["timestamp"] = timestamp + params["sign"] = sign except Exception: msg = "Failed to apply security mechanism to the request." logging.exception(msg) diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py index 2292e89fa6..8269167127 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py @@ -11,11 +11,10 @@ class DuckDuckGoProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py index 878b0d8645..8bdd638f4a 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py @@ -13,8 +13,8 @@ class DuckDuckGoAITool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "model": tool_parameters.get('model'), + "keywords": tool_parameters.get("query"), + "model": tool_parameters.get("model"), } response = DDGS().chat(**query_dict) return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index bca53f6b4b..396570248a 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -14,18 +14,17 @@ class DuckDuckGoImageSearchTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: query_dict = { - "keywords": tool_parameters.get('query'), - "timelimit": tool_parameters.get('timelimit'), - "size": tool_parameters.get('size'), - "max_results": tool_parameters.get('max_results'), + "keywords": tool_parameters.get("query"), + "timelimit": tool_parameters.get("timelimit"), + "size": tool_parameters.get("size"), + "max_results": tool_parameters.get("max_results"), } response = DDGS().images(**query_dict) result = [] for res in response: - res['transfer_method'] = FileTransferMethod.REMOTE_URL - msg = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=res.get('image'), - save_as='', - meta=res) + res["transfer_method"] = FileTransferMethod.REMOTE_URL + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res + ) result.append(msg) return result diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py index dfaeb734d8..cbd65d2e77 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py @@ -21,10 +21,11 @@ class DuckDuckGoSearchTool(BuiltinTool): """ Tool for performing a search using DuckDuckGo search engine. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: - query = tool_parameters.get('query') - max_results = tool_parameters.get('max_results', 5) - require_summary = tool_parameters.get('require_summary', False) + query = tool_parameters.get("query") + max_results = tool_parameters.get("max_results", 5) + require_summary = tool_parameters.get("require_summary", False) response = DDGS().text(query, max_results=max_results) if require_summary: results = "\n".join([res.get("body") for res in response]) @@ -34,7 +35,11 @@ class DuckDuckGoSearchTool(BuiltinTool): def summary_results(self, user_id: str, content: str, query: str) -> str: prompt = SUMMARY_PROMPT.format(query=query, content=content) - summary = self.invoke_model(user_id=user_id, prompt_messages=[ - SystemPromptMessage(content=prompt), - ], stop=[]) + summary = self.invoke_model( + user_id=user_id, + prompt_messages=[ + SystemPromptMessage(content=prompt), + ], + stop=[], + ) return summary.message.content diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py index 9822b37cf0..396ce21b18 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py @@ -13,8 +13,8 @@ class DuckDuckGoTranslateTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "to": tool_parameters.get('translate_to'), + "keywords": tool_parameters.get("query"), + "to": tool_parameters.get("translate_to"), } - response = DDGS().translate(**query_dict)[0].get('translated', 'Unable to translate!') + response = DDGS().translate(**query_dict)[0].get("translated", "Unable to translate!") return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py index e8ab02f55e..e82da8ca53 100644 --- a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py +++ b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py @@ -8,35 +8,35 @@ from core.tools.utils.uuid_utils import is_valid_uuid class FeishuGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot + invoke tools + API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot """ url = "https://open.feishu.cn/open-apis/bot/v2/hook" - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - msg_type = 'text' - api_url = f'{url}/{hook_key}' + msg_type = "text" + api_url = f"{url}/{hook_key}" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { "msg_type": msg_type, "content": { "text": content, - } + }, } try: @@ -45,6 +45,7 @@ class FeishuGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.py b/api/core/tools/provider/builtin/feishu_base/feishu_base.py index febb769ff8..04056af53b 100644 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.py +++ b/api/core/tools/provider/builtin/feishu_base/feishu_base.py @@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class FeishuBaseProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: GetTenantAccessTokenTool() - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py index be43b43ce4..4a605fbffe 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py @@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class AddBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to add base record, status code: {res.status_code}, response: {res.text}") + f"Failed to add base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to add base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py index 639644e7f0..6b755e2007 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py @@ -8,28 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class CreateBaseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - name = tool_parameters.get('name', '') - folder_token = tool_parameters.get('folder_token', '') + name = tool_parameters.get("name", "") + folder_token = tool_parameters.get("folder_token", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "name": name, - "folder_token": folder_token - } + payload = {"name": name, "folder_token": folder_token} try: res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) @@ -38,6 +35,7 @@ class CreateBaseTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base, status code: {res.status_code}, response: {res.text}") + f"Failed to create base, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py index e9062e8730..b05d700113 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py @@ -8,37 +8,32 @@ from core.tools.tool.builtin_tool import BuiltinTool class CreateBaseTableTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - name = tool_parameters.get('name', '') + name = tool_parameters.get("name", "") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table": { - "name": name, - "fields": json.loads(fields) - } - } + payload = {"table": {"name": name, "fields": json.loads(fields)}} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -47,6 +42,7 @@ class CreateBaseTableTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base table, status code: {res.status_code}, response: {res.text}") + f"Failed to create base table, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base table. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py index aa13aad6fa..862eb2171b 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py @@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_ids = tool_parameters.get('record_ids', '') + record_ids = tool_parameters.get("record_ids", "") if not record_ids: - return self.create_text_message('Invalid parameter record_ids') + return self.create_text_message("Invalid parameter record_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "records": json.loads(record_ids) - } + payload = {"records": json.loads(record_ids)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base records, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py index c4280ebc21..f512186303 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py @@ -8,32 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_ids = tool_parameters.get('table_ids', '') + table_ids = tool_parameters.get("table_ids", "") if not table_ids: - return self.create_text_message('Invalid parameter table_ids') + return self.create_text_message("Invalid parameter table_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table_ids": json.loads(table_ids) - } + payload = {"table_ids": json.loads(table_ids)} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -42,6 +40,7 @@ class DeleteBaseTablesTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py index de70f2ed93..f664bbeed0 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py @@ -8,22 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetBaseInfoTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: @@ -33,6 +33,7 @@ class GetBaseInfoTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get base info, status code: {res.status_code}, response: {res.text}") + f"Failed to get base info, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get base info. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py index 88507bda60..2ea61d0068 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py @@ -8,27 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetTenantAccessTokenTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" - app_id = tool_parameters.get('app_id', '') + app_id = tool_parameters.get("app_id", "") if not app_id: - return self.create_text_message('Invalid parameter app_id') + return self.create_text_message("Invalid parameter app_id") - app_secret = tool_parameters.get('app_secret', '') + app_secret = tool_parameters.get("app_secret", "") if not app_secret: - return self.create_text_message('Invalid parameter app_secret') + return self.create_text_message("Invalid parameter app_secret") headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} """ { @@ -45,6 +42,7 @@ class GetTenantAccessTokenTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}") + f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get tenant access token. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py index 2a4229f137..e579d02f69 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py @@ -8,31 +8,31 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') - sort_condition = tool_parameters.get('sort_condition', '') - filter_condition = tool_parameters.get('filter_condition', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") + sort_condition = tool_parameters.get("sort_condition", "") + filter_condition = tool_parameters.get("filter_condition", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -40,22 +40,26 @@ class ListBaseRecordsTool(BuiltinTool): "page_size": page_size, } - payload = { - "automatic_fields": True - } + payload = {"automatic_fields": True} if sort_condition: payload["sort"] = json.loads(sort_condition) if filter_condition: payload["filter"] = json.loads(filter_condition) try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base records, status code: {res.status_code}, response: {res.text}") + f"Failed to list base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py index 6d82490eb3..4ec9a476bc 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py @@ -8,25 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -41,6 +41,7 @@ class ListBaseTablesTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to list base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py index bb4bd6c3a6..fb818f8380 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py @@ -8,40 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class ReadBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: - res = httpx.get(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - timeout=30) + res = httpx.get( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30 + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to read base record, status code: {res.status_code}, response: {res.text}") + f"Failed to read base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to read base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py index 6551053ce2..6d7e33f3ff 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py @@ -8,49 +8,53 @@ from core.tools.tool.builtin_tool import BuiltinTool class UpdateBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.put(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - params=params, json=payload, timeout=30) + res = httpx.put( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to update base record, status code: {res.status_code}, response: {res.text}") + f"Failed to update base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to update base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.py b/api/core/tools/provider/builtin/feishu_document/feishu_document.py index c4f8f26e2c..b0a1e393eb 100644 --- a/api/core/tools/provider/builtin/feishu_document/feishu_document.py +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.py @@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class FeishuDocumentProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py index 0ff82e621b..090a0828e8 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - title = tool_parameters.get('title') - content = tool_parameters.get('content') - folder_token = tool_parameters.get('folder_token') + title = tool_parameters.get("title") + content = tool_parameters.get("content") + folder_token = tool_parameters.get("folder_token") res = client.create_document(title, content, folder_token) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py index 16ef90908b..83073e0822 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py @@ -7,11 +7,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class GetDocumentRawContentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') + document_id = tool_parameters.get("document_id") res = client.get_document_raw_content(document_id) - return self.create_json_message(res) \ No newline at end of file + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py index 97d17bdb04..8c0c4a3c97 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class ListDocumentBlockTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - page_size = tool_parameters.get('page_size', 500) - page_token = tool_parameters.get('page_token', '') + document_id = tool_parameters.get("document_id") + page_size = tool_parameters.get("page_size", 500) + page_token = tool_parameters.get("page_token", "") res = client.list_document_block(document_id, page_token, page_size) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py index 914a44dce6..6061250e48 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - content = tool_parameters.get('content') - position = tool_parameters.get('position') + document_id = tool_parameters.get("document_id") + content = tool_parameters.get("content") + position = tool_parameters.get("position") res = client.write_document(document_id, content, position) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.py b/api/core/tools/provider/builtin/feishu_message/feishu_message.py index 6d7fed330c..7b3adb9293 100644 --- a/api/core/tools/provider/builtin/feishu_message/feishu_message.py +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.py @@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class FeishuMessageProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py index 74f6866ba3..1dd315d0e2 100644 --- a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py @@ -7,14 +7,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class SendBotMessageTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - receive_id_type = tool_parameters.get('receive_id_type') - receive_id = tool_parameters.get('receive_id') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + receive_id_type = tool_parameters.get("receive_id_type") + receive_id = tool_parameters.get("receive_id") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py index 7159f59ffa..44e70e0a15 100644 --- a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py @@ -6,14 +6,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class SendWebhookMessageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - webhook = tool_parameters.get('webhook') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + webhook = tool_parameters.get("webhook") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_webhook_message(webhook, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index 24dc35759d..01455d7206 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -7,15 +7,8 @@ class FirecrawlProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the ScrapeTool, only scraping title for minimize content - ScrapeTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', - tool_parameters={ - "url": "https://google.com", - "onlyIncludeTags": 'title' - } + ScrapeTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={"url": "https://google.com", "onlyIncludeTags": "title"} ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index 3b3f78731b..d9fb6f04bc 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -13,85 +13,83 @@ logger = logging.getLogger(__name__) class FirecrawlApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' + self.base_url = base_url or "https://api.firecrawl.dev" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( - self, - method: str, - url: str, - data: Mapping[str, Any] | None = None, - headers: Mapping[str, str] | None = None, - retries: int = 3, - backoff_factor: float = 0.3, + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, ) -> Mapping[str, Any] | None: if not headers: headers = self._prepare_headers() for i in range(retries): try: response = requests.request(method, url, json=data, headers=headers) - response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: if i < retries - 1: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None def scrape_url(self, url: str, **kwargs): - endpoint = f'{self.base_url}/v0/scrape' - data = {'url': url, **kwargs} + endpoint = f"{self.base_url}/v1/scrape" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to scrape URL after multiple retries") return response - def search(self, query: str, **kwargs): - endpoint = f'{self.base_url}/v0/search' - data = {'query': query, **kwargs} + def map(self, url: str, **kwargs): + endpoint = f"{self.base_url}/v1/map" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: - raise HTTPError("Failed to perform search after multiple retries") + raise HTTPError("Failed to perform map after multiple retries") return response def crawl_url( - self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs + self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): - endpoint = f'{self.base_url}/v0/crawl' + endpoint = f"{self.base_url}/v1/crawl" headers = self._prepare_headers(idempotency_key) - data = {'url': url, **kwargs} + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") - job_id: str = response['jobId'] + elif response.get("success") == False: + raise HTTPError(f'Failed to crawl: {response.get("error")}') + job_id: str = response["id"] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return response def check_crawl_status(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/status/{job_id}' - response = self._request('GET', endpoint) + endpoint = f"{self.base_url}/v1/crawl/{job_id}" + response = self._request("GET", endpoint) if response is None: raise HTTPError(f"Failed to check status for job {job_id} after multiple retries") return response def cancel_crawl_job(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}' - response = self._request('DELETE', endpoint) + endpoint = f"{self.base_url}/v1/crawl/{job_id}" + response = self._request("DELETE", endpoint) if response is None: raise HTTPError(f"Failed to cancel job {job_id} after multiple retries") return response @@ -99,9 +97,9 @@ class FirecrawlApp: def _monitor_job_status(self, job_id: str, poll_interval: int): while True: status = self.check_crawl_status(job_id) - if status['status'] == 'completed': + if status["status"] == "completed": return status - elif status['status'] == 'failed': + elif status["status"] == "failed": raise HTTPError(f'Job {job_id} failed: {status["error"]}') time.sleep(poll_interval) @@ -109,7 +107,7 @@ class FirecrawlApp: def get_array_params(tool_parameters: dict[str, Any], key): param = tool_parameters.get(key) if param: - return param.split(',') + return param.split(",") def get_json_params(tool_parameters: dict[str, Any], key): @@ -119,6 +117,6 @@ def get_json_params(tool_parameters: dict[str, Any], key): # support both single quotes and double quotes param = param.replace("'", '"') param = json.loads(param) - except: + except Exception: raise ValueError(f"Invalid {key} format.") return param diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 08c40a4064..9675b8eb91 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -8,41 +8,38 @@ from core.tools.tool.builtin_tool import BuiltinTool class CrawlTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: """ - the crawlerOptions and pageOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/crawl """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - crawlerOptions = {} - pageOptions = {} - - wait_for_results = tool_parameters.get('wait_for_results', True) - - crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes') - crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes') - crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False) - crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth') - crawlerOptions['mode'] = tool_parameters.get('mode') - crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False) - crawlerOptions['limit'] = tool_parameters.get('limit', 5) - crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False) - crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False) - - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] = tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) - - crawl_result = app.crawl_url( - url=tool_parameters['url'], - wait=wait_for_results, - crawlerOptions=crawlerOptions, - pageOptions=pageOptions + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] ) + scrapeOptions = {} + payload = {} + + wait_for_results = tool_parameters.get("wait_for_results", True) + + payload["excludePaths"] = get_array_params(tool_parameters, "excludePaths") + payload["includePaths"] = get_array_params(tool_parameters, "includePaths") + payload["maxDepth"] = tool_parameters.get("maxDepth") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) + payload["limit"] = tool_parameters.get("limit", 5) + payload["allowBackwardLinks"] = tool_parameters.get("allowBackwardLinks", False) + payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False) + payload["webhook"] = tool_parameters.get("webhook") + + scrapeOptions["formats"] = get_array_params(tool_parameters, "formats") + scrapeOptions["headers"] = get_json_params(tool_parameters, "headers") + scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags") + scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + scrapeOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + scrapeOptions["waitFor"] = tool_parameters.get("waitFor", 0) + scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in {None, ""}} + payload["scrapeOptions"] = scrapeOptions or None + + payload = {k: v for k, v in payload.items() if v not in {None, ""}} + + crawl_result = app.crawl_url(url=tool_parameters["url"], wait=wait_for_results, **payload) + return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml index 0c5399f973..0d7dbcac20 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml @@ -31,8 +31,21 @@ parameters: en_US: If you choose not to wait, it will directly return a job ID. You can use this job ID to check the crawling results or cancel the crawling task, which is usually very useful for a large-scale crawling task. zh_Hans: 如果选择不等待,则会直接返回一个job_id,可以通过job_id查询爬取结果或取消爬取任务,这通常对于一个大型爬取任务来说非常有用。 form: form -############## Crawl Options ####################### - - name: includes +############## Payload ####################### + - name: excludePaths + type: string + label: + en_US: URL patterns to exclude + zh_Hans: 要排除的URL模式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Pages matching these patterns will be skipped. Example: blog/*, about/* + zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* + form: form + - name: includePaths type: string required: false label: @@ -46,30 +59,6 @@ parameters: Only pages matching these patterns will be crawled. Example: blog/*, about/* zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例:blog/*, about/* form: form - - name: excludes - type: string - label: - en_US: URL patterns to exclude - zh_Hans: 要排除的URL模式 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - human_description: - en_US: | - Pages matching these patterns will be skipped. Example: blog/*, about/* - zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* - form: form - - name: returnOnlyUrls - type: boolean - default: false - label: - en_US: return Only Urls - zh_Hans: 仅返回URL - human_description: - en_US: | - If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents. - zh_Hans: 只返回爬取到的网页链接,而不是网页内容本身。 - form: form - name: maxDepth type: number label: @@ -80,27 +69,10 @@ parameters: zh_Hans: 相对于输入的URL,爬取的最大深度。maxDepth为0时,仅抓取输入的URL。maxDepth为1时,抓取输入的URL以及所有一级深层页面。maxDepth为2时,抓取输入的URL以及所有两级深层页面。更高值遵循相同模式。 form: form min: 0 - - name: mode - type: select - required: false - form: form - options: - - value: default - label: - en_US: default - - value: fast - label: - en_US: fast - default: default - label: - en_US: Crawl Mode - zh_Hans: 爬取模式 - human_description: - en_US: The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites. - zh_Hans: 使用fast模式将不会使用其站点地图,比普通模式快4倍,但是可能不够准确,也不适用于大量js渲染的网站。 + default: 2 - name: ignoreSitemap type: boolean - default: false + default: true label: en_US: ignore Sitemap zh_Hans: 忽略站点地图 @@ -120,7 +92,7 @@ parameters: form: form min: 1 default: 5 - - name: allowBackwardCrawling + - name: allowBackwardLinks type: boolean default: false label: @@ -130,7 +102,7 @@ parameters: en_US: Enables the crawler to navigate from a specific URL to previously linked pages. For instance, from 'example.com/product/123' back to 'example.com/product' zh_Hans: 使爬虫能够从特定URL导航到之前链接的页面。例如,从'example.com/product/123'返回到'example.com/product' form: form - - name: allowExternalContentLinks + - name: allowExternalLinks type: boolean default: false label: @@ -140,7 +112,30 @@ parameters: en_US: Allows the crawler to follow links to external websites. zh_Hans: form: form -############## Page Options ####################### + - name: webhook + type: string + label: + en_US: webhook + human_description: + en_US: | + The URL to send the webhook to. This will trigger for crawl started (crawl.started) ,every page crawled (crawl.page) and when the crawl is completed (crawl.completed or crawl.failed). The response will be the same as the /scrape endpoint. + zh_Hans: 发送Webhook的URL。这将在开始爬取(crawl.started)、每爬取一个页面(crawl.page)以及爬取完成(crawl.completed或crawl.failed)时触发。响应将与/scrape端点相同。 + form: form +############## Scrape Options ####################### + - name: formats + type: string + label: + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot + zh_Hans: | + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot + form: form - name: headers type: string label: @@ -155,30 +150,10 @@ parameters: en_US: Please enter an object that can be serialized in JSON zh_Hans: 请输入可以json序列化的对象 form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form - - name: onlyIncludeTags + - name: includeTags type: string label: - en_US: only Include Tags + en_US: Include Tags zh_Hans: 仅抓取这些标签 placeholder: en_US: Use commas to separate multiple tags @@ -189,6 +164,20 @@ parameters: zh_Hans: | 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer form: form + - name: excludeTags + type: string + label: + en_US: Exclude Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form - name: onlyMainContent type: boolean default: false @@ -199,40 +188,6 @@ parameters: en_US: Only return the main content of the page excluding headers, navs, footers, etc. zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 form: form - - name: removeTags - type: string - label: - en_US: remove Tags - zh_Hans: 要移除这些标签 - human_description: - en_US: | - Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot - type: boolean - default: false - label: - en_US: screenshot - zh_Hans: 截图 - human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 - form: form - name: waitFor type: number min: 0 diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py index fa6c1f87ee..0d2486c7ca 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py @@ -7,14 +7,15 @@ from core.tools.tool.builtin_tool import BuiltinTool class CrawlJobTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - operation = tool_parameters.get('operation', 'get') - if operation == 'get': - result = app.check_crawl_status(job_id=tool_parameters['job_id']) - elif operation == 'cancel': - result = app.cancel_crawl_job(job_id=tool_parameters['job_id']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + operation = tool_parameters.get("operation", "get") + if operation == "get": + result = app.check_crawl_status(job_id=tool_parameters["job_id"]) + elif operation == "cancel": + result = app.cancel_crawl_job(job_id=tool_parameters["job_id"]) else: - raise ValueError(f'Invalid operation: {operation}') + raise ValueError(f"Invalid operation: {operation}") return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.py b/api/core/tools/provider/builtin/firecrawl/tools/map.py new file mode 100644 index 0000000000..bdfb5faeb8 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.py @@ -0,0 +1,25 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class MapTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + """ + the api doc: + https://docs.firecrawl.dev/api-reference/endpoint/map + """ + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + payload = {} + payload["search"] = tool_parameters.get("search") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", True) + payload["includeSubdomains"] = tool_parameters.get("includeSubdomains", False) + payload["limit"] = tool_parameters.get("limit", 5000) + + map_result = app.map(url=tool_parameters["url"], **payload) + + return self.create_json_message(map_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.yaml b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml new file mode 100644 index 0000000000..9913756983 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml @@ -0,0 +1,59 @@ +identity: + name: map + author: hjlarry + label: + en_US: Map + zh_Hans: 地图式快爬 +description: + human: + en_US: Input a website and get all the urls on the website - extremly fast + zh_Hans: 输入一个网站,快速获取网站上的所有网址。 + llm: Input a website and get all the urls on the website - extremly fast +parameters: + - name: url + type: string + required: true + label: + en_US: Start URL + zh_Hans: 起始URL + human_description: + en_US: The base URL to start crawling from. + zh_Hans: 要爬取网站的起始URL。 + llm_description: The URL of the website that needs to be crawled. This is a required parameter. + form: llm + - name: search + type: string + label: + en_US: search + zh_Hans: 搜索查询 + human_description: + en_US: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + zh_Hans: 用于映射的搜索查询。在Alpha阶段,搜索功能的“智能”部分限制为最多100个搜索结果。然而,如果地图找到了更多结果,则不施加任何限制。 + llm_description: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + form: llm +############## Page Options ####################### + - name: ignoreSitemap + type: boolean + default: true + label: + en_US: ignore Sitemap + zh_Hans: 忽略站点地图 + human_description: + en_US: Ignore the website sitemap when crawling. + zh_Hans: 爬取时忽略网站站点地图。 + form: form + - name: includeSubdomains + type: boolean + default: false + label: + en_US: include Subdomains + zh_Hans: 包含子域名 + form: form + - name: limit + type: number + min: 0 + default: 5000 + label: + en_US: Maximum results + zh_Hans: 最大结果数量 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index 91412da548..538b4a1fcb 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -6,34 +6,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: """ - the pageOptions and extractorOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/scrape """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) - pageOptions = {} - extractorOptions = {} + payload = {} + extract = {} - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] = tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) + payload["formats"] = get_array_params(tool_parameters, "formats") + payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True) + payload["includeTags"] = get_array_params(tool_parameters, "includeTags") + payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + payload["headers"] = get_json_params(tool_parameters, "headers") + payload["waitFor"] = tool_parameters.get("waitFor", 0) + payload["timeout"] = tool_parameters.get("timeout", 30000) - extractorOptions['mode'] = tool_parameters.get('mode', '') - extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '') - extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema') + extract["schema"] = get_json_params(tool_parameters, "schema") + extract["systemPrompt"] = tool_parameters.get("systemPrompt") + extract["prompt"] = tool_parameters.get("prompt") + extract = {k: v for k, v in extract.items() if v not in {None, ""}} + payload["extract"] = extract or None - crawl_result = app.scrape_url(url=tool_parameters['url'], - pageOptions=pageOptions, - extractorOptions=extractorOptions) + payload = {k: v for k, v in payload.items() if v not in {None, ""}} - return self.create_json_message(crawl_result) + crawl_result = app.scrape_url(url=tool_parameters["url"], **payload) + markdown_result = crawl_result.get("data", {}).get("markdown", "") + return [self.create_text_message(markdown_result), self.create_json_message(crawl_result)] diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml index 598429de5e..8f1f1348a4 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml @@ -6,8 +6,8 @@ identity: zh_Hans: 单页面抓取 description: human: - en_US: Extract data from a single URL. - zh_Hans: 从单个URL抓取数据。 + en_US: Turn any url into clean data. + zh_Hans: 将任何网址转换为干净的数据。 llm: This tool is designed to scrape URL and output the content in Markdown format. parameters: - name: url @@ -21,7 +21,59 @@ parameters: zh_Hans: 要抓取并提取数据的网站URL。 llm_description: The URL of the website that needs to be crawled. This is a required parameter. form: llm -############## Page Options ####################### +############## Payload ####################### + - name: formats + type: string + label: + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + zh_Hans: | + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + form: form + - name: onlyMainContent + type: boolean + default: false + label: + en_US: only Main Content + zh_Hans: 仅抓取主要内容 + human_description: + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 + form: form + - name: includeTags + type: string + label: + en_US: Include Tags + zh_Hans: 仅抓取这些标签 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + form: form + - name: excludeTags + type: string + label: + en_US: Exclude Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form - name: headers type: string label: @@ -36,87 +88,10 @@ parameters: en_US: Please enter an object that can be serialized in JSON zh_Hans: 请输入可以json序列化的对象 form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form - - name: onlyIncludeTags - type: string - label: - en_US: only Include Tags - zh_Hans: 仅抓取这些标签 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - human_description: - en_US: | - Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - form: form - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: removeTags - type: string - label: - en_US: remove Tags - zh_Hans: 要移除这些标签 - human_description: - en_US: | - Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot - type: boolean - default: false - label: - en_US: screenshot - zh_Hans: 截图 - human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 - form: form - name: waitFor type: number min: 0 + default: 0 label: en_US: wait For zh_Hans: 等待时间 @@ -124,57 +99,54 @@ parameters: en_US: Wait x amount of milliseconds for the page to load to fetch content. zh_Hans: 等待x毫秒以使页面加载并获取内容。 form: form + - name: timeout + type: number + min: 0 + default: 30000 + label: + en_US: Timeout + human_description: + en_US: Timeout in milliseconds for the request. + zh_Hans: 请求的超时时间(以毫秒为单位)。 + form: form ############## Extractor Options ####################### - - name: mode - type: select - options: - - value: markdown - label: - en_US: markdown - - value: llm-extraction - label: - en_US: llm-extraction - - value: llm-extraction-from-raw-html - label: - en_US: llm-extraction-from-raw-html - - value: llm-extraction-from-markdown - label: - en_US: llm-extraction-from-markdown - label: - en_US: Extractor Mode - zh_Hans: 提取模式 - human_description: - en_US: | - The extraction mode to use. 'markdown': Returns the scraped markdown content, does not perform LLM extraction. 'llm-extraction': Extracts information from the cleaned and parsed content using LLM. - zh_Hans: 使用的提取模式。“markdown”:返回抓取的markdown内容,不执行LLM提取。“llm-extractioin”:使用LLM按Extractor Schema从内容中提取信息。 - form: form - - name: extractionPrompt - type: string - label: - en_US: Extractor Prompt - zh_Hans: 提取时的提示词 - human_description: - en_US: A prompt describing what information to extract from the page, applicable for LLM extraction modes. - zh_Hans: 当使用LLM提取模式时,用于给LLM描述提取规则。 - form: form - - name: extractionSchema + - name: schema type: string label: en_US: Extractor Schema zh_Hans: 提取时的结构 placeholder: en_US: Please enter an object that can be serialized in JSON + zh_Hans: 请输入可以json序列化的对象 human_description: en_US: | - The schema for the data to be extracted, required only for LLM extraction modes. Example: { + The schema for the data to be extracted. Example: { "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } zh_Hans: | - 当使用LLM提取模式时,使用该结构去提取,示例:{ + 使用该结构去提取,示例:{ "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } form: form + - name: systemPrompt + type: string + label: + en_US: Extractor System Prompt + zh_Hans: 提取时的系统提示词 + human_description: + en_US: The system prompt to use for the extraction. + zh_Hans: 用于提取的系统提示。 + form: form + - name: prompt + type: string + label: + en_US: Extractor Prompt + zh_Hans: 提取时的提示词 + human_description: + en_US: The prompt to use for the extraction without a schema. + zh_Hans: 用于无schema时提取的提示词 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py deleted file mode 100644 index e2b2ac6b4d..0000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp -from core.tools.tool.builtin_tool import BuiltinTool - - -class SearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - """ - the pageOptions and searchOptions comes from doc here: - https://docs.firecrawl.dev/api-reference/endpoint/search - """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - pageOptions = {} - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True) - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - searchOptions = {'limit': tool_parameters.get('limit')} - search_result = app.search( - query=tool_parameters['keyword'], - pageOptions=pageOptions, - searchOptions=searchOptions - ) - - return self.create_json_message(search_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml b/api/core/tools/provider/builtin/firecrawl/tools/search.yaml deleted file mode 100644 index 29df0cfaaa..0000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml +++ /dev/null @@ -1,75 +0,0 @@ -identity: - name: search - author: ahasasjeb - label: - en_US: Search - zh_Hans: 搜索 -description: - human: - en_US: Search, and output in Markdown format - zh_Hans: 搜索,并且以Markdown格式输出 - llm: This tool can perform online searches and convert the results to Markdown format. -parameters: - - name: keyword - type: string - required: true - label: - en_US: keyword - zh_Hans: 关键词 - human_description: - en_US: Input keywords to use Firecrawl API for search. - zh_Hans: 输入关键词即可使用Firecrawl API进行搜索。 - llm_description: Efficiently extract keywords from user text. - form: llm -############## Page Options ####################### - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: fetchPageContent - type: boolean - default: true - label: - en_US: fetch Page Content - zh_Hans: 抓取页面内容 - human_description: - en_US: Fetch the content of each page. If false, defaults to a basic fast serp API. - zh_Hans: 获取每个页面的内容。如果为否,则使用基本的快速搜索结果页面API。 - form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form -############## Search Options ####################### - - name: limit - type: number - min: 0 - label: - en_US: Maximum results - zh_Hans: 最大结果数量 - human_description: - en_US: Maximum number of results. Max is 20 during beta. - zh_Hans: 最大结果数量。在测试阶段,最大为20。 - form: form diff --git a/api/core/tools/provider/builtin/gaode/gaode.py b/api/core/tools/provider/builtin/gaode/gaode.py index b55d93e07b..49a8e537fb 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.py +++ b/api/core/tools/provider/builtin/gaode/gaode.py @@ -9,17 +9,19 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GaodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'api_key' not in credentials or not credentials.get('api_key'): + if "api_key" not in credentials or not credentials.get("api_key"): raise ToolProviderCredentialValidationError("Gaode API key is required.") try: - response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}" - "".format(address=urllib.parse.quote('广东省广州市天河区广州塔'), - apikey=credentials.get('api_key'))) - if response.status_code == 200 and (response.json()).get('info') == 'OK': + response = requests.get( + url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}".format( + address=urllib.parse.quote("广东省广州市天河区广州塔"), apikey=credentials.get("api_key") + ) + ) + if response.status_code == 200 and (response.json()).get("info") == "OK": pass else: - raise ToolProviderCredentialValidationError((response.json()).get('info')) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py index efd11cedce..ea06e2ce61 100644 --- a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py +++ b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py @@ -8,50 +8,57 @@ from core.tools.tool.builtin_tool import BuiltinTool class GaodeRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - city = tool_parameters.get('city', '') + city = tool_parameters.get("city", "") if not city: - return self.create_text_message('Please tell me your city') + return self.create_text_message("Please tell me your city") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Gaode API key is required.") try: s = requests.session() - api_domain = 'https://restapi.amap.com/v3' - city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"}, - url="{url}/config/district?keywords={keywords}" - "&subdistrict=0&extensions=base&key={apikey}" - "".format(url=api_domain, keywords=city, - apikey=self.runtime.credentials.get('api_key'))) + api_domain = "https://restapi.amap.com/v3" + city_response = s.request( + method="GET", + headers={"Content-Type": "application/json; charset=utf-8"}, + url="{url}/config/district?keywords={keywords}&subdistrict=0&extensions=base&key={apikey}".format( + url=api_domain, keywords=city, apikey=self.runtime.credentials.get("api_key") + ), + ) City_data = city_response.json() - if city_response.status_code == 200 and City_data.get('info') == 'OK': - if len(City_data.get('districts')) > 0: - CityCode = City_data['districts'][0]['adcode'] - weatherInfo_response = s.request(method='GET', - url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" - "".format(url=api_domain, citycode=CityCode, - apikey=self.runtime.credentials.get('api_key'))) + if city_response.status_code == 200 and City_data.get("info") == "OK": + if len(City_data.get("districts")) > 0: + CityCode = City_data["districts"][0]["adcode"] + weatherInfo_response = s.request( + method="GET", + url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" + "".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")), + ) weatherInfo_data = weatherInfo_response.json() - if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': + if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK": contents = [] - if len(weatherInfo_data.get('forecasts')) > 0: - for item in weatherInfo_data['forecasts'][0]['casts']: + if len(weatherInfo_data.get("forecasts")) > 0: + for item in weatherInfo_data["forecasts"][0]["casts"]: content = {} - content['date'] = item.get('date') - content['week'] = item.get('week') - content['dayweather'] = item.get('dayweather') - content['daytemp_float'] = item.get('daytemp_float') - content['daywind'] = item.get('daywind') - content['nightweather'] = item.get('nightweather') - content['nighttemp_float'] = item.get('nighttemp_float') + content["date"] = item.get("date") + content["week"] = item.get("week") + content["dayweather"] = item.get("dayweather") + content["daytemp_float"] = item.get("daytemp_float") + content["daywind"] = item.get("daywind") + content["nightweather"] = item.get("nightweather") + content["nighttemp_float"] = item.get("nighttemp_float") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) s.close() - return self.create_text_message(f'No weather information for {city} was found.') + return self.create_text_message(f"No weather information for {city} was found.") except Exception as e: return self.create_text_message("Gaode API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/getimgai/getimgai.py b/api/core/tools/provider/builtin/getimgai/getimgai.py index c81d5fa333..bbd07d120f 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai.py @@ -7,16 +7,13 @@ class GetImgAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the text2image tool - Text2ImageTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + Text2ImageTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "prompt": "A fire egg", "response_format": "url", "style": "photorealism", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py index e28c57649c..0e95a5f654 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py @@ -8,18 +8,16 @@ from requests.exceptions import HTTPError logger = logging.getLogger(__name__) + class GetImgAIApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.getimg.ai/v1' + self.base_url = base_url or "https://api.getimg.ai/v1" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return headers def _request( @@ -38,22 +36,20 @@ class GetImgAIApp: return response.json() except requests.exceptions.RequestException as e: if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None - def text2image( - self, mode: str, **kwargs - ): - data = kwargs['params'] - if not data.get('prompt'): + def text2image(self, mode: str, **kwargs): + data = kwargs["params"] + if not data.get("prompt"): raise ValueError("Prompt is required") - endpoint = f'{self.base_url}/{mode}/text-to-image' + endpoint = f"{self.base_url}/{mode}/text-to-image" headers = self._prepare_headers() logger.debug(f"Send request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate getimg.ai after multiple retries") return response diff --git a/api/core/tools/provider/builtin/getimgai/tools/text2image.py b/api/core/tools/provider/builtin/getimgai/tools/text2image.py index dad7314479..c556749552 100644 --- a/api/core/tools/provider/builtin/getimgai/tools/text2image.py +++ b/api/core/tools/provider/builtin/getimgai/tools/text2image.py @@ -7,28 +7,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class Text2ImageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url']) + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = GetImgAIApp( + api_key=self.runtime.credentials["getimg_api_key"], base_url=self.runtime.credentials["base_url"] + ) options = { - 'style': tool_parameters.get('style'), - 'prompt': tool_parameters.get('prompt'), - 'aspect_ratio': tool_parameters.get('aspect_ratio'), - 'output_format': tool_parameters.get('output_format', 'jpeg'), - 'response_format': tool_parameters.get('response_format', 'url'), - 'width': tool_parameters.get('width'), - 'height': tool_parameters.get('height'), - 'steps': tool_parameters.get('steps'), - 'negative_prompt': tool_parameters.get('negative_prompt'), - 'prompt_2': tool_parameters.get('prompt_2'), + "style": tool_parameters.get("style"), + "prompt": tool_parameters.get("prompt"), + "aspect_ratio": tool_parameters.get("aspect_ratio"), + "output_format": tool_parameters.get("output_format", "jpeg"), + "response_format": tool_parameters.get("response_format", "url"), + "width": tool_parameters.get("width"), + "height": tool_parameters.get("height"), + "steps": tool_parameters.get("steps"), + "negative_prompt": tool_parameters.get("negative_prompt"), + "prompt_2": tool_parameters.get("prompt_2"), } options = {k: v for k, v in options.items() if v} - text2image_result = app.text2image( - mode=tool_parameters.get('mode', 'essential-v2'), - params=options, - wait=True - ) + text2image_result = app.text2image(mode=tool_parameters.get("mode", "essential-v2"), params=options, wait=True) if not isinstance(text2image_result, str): text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4) diff --git a/api/core/tools/provider/builtin/github/github.py b/api/core/tools/provider/builtin/github/github.py index b19f0896f8..87a34ac3e8 100644 --- a/api/core/tools/provider/builtin/github/github.py +++ b/api/core/tools/provider/builtin/github/github.py @@ -7,25 +7,25 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GithubProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Github API Access Tokens is required.") - if 'api_version' not in credentials or not credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in credentials or not credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = credentials.get('api_version') + api_version = credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } response = requests.get( - url="https://api.github.com/search/users?q={account}".format(account='charli117'), - headers=headers) + url="https://api.github.com/search/users?q={account}".format(account="charli117"), headers=headers + ) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/github/tools/github_repositories.py b/api/core/tools/provider/builtin/github/tools/github_repositories.py index 305bf08ce8..32f9922e65 100644 --- a/api/core/tools/provider/builtin/github/tools/github_repositories.py +++ b/api/core/tools/provider/builtin/github/tools/github_repositories.py @@ -10,53 +10,61 @@ from core.tools.tool.builtin_tool import BuiltinTool class GithubRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - top_n = tool_parameters.get('top_n', 5) - query = tool_parameters.get('query', '') + top_n = tool_parameters.get("top_n", 5) + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input symbol') + return self.create_text_message("Please input symbol") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Github API Access Tokens is required.") - if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in self.runtime.credentials or not self.runtime.credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = self.runtime.credentials.get('api_version') + api_version = self.runtime.credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } s = requests.session() - api_domain = 'https://api.github.com' - response = s.request(method='GET', headers=headers, - url=f"{api_domain}/search/repositories?" - f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") + api_domain = "https://api.github.com" + response = s.request( + method="GET", + headers=headers, + url=f"{api_domain}/search/repositories?q={quote(query)}&sort=stars&per_page={top_n}&order=desc", + ) response_data = response.json() - if response.status_code == 200 and isinstance(response_data.get('items'), list): + if response.status_code == 200 and isinstance(response_data.get("items"), list): contents = [] - if len(response_data.get('items')) > 0: - for item in response_data.get('items'): + if len(response_data.get("items")) > 0: + for item in response_data.get("items"): content = {} - updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") - content['owner'] = item['owner']['login'] - content['name'] = item['name'] - content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description'] - content['url'] = item['html_url'] - content['star'] = item['watchers'] - content['forks'] = item['forks'] - content['updated'] = updated_at_object.strftime("%Y-%m-%d") + updated_at_object = datetime.strptime(item["updated_at"], "%Y-%m-%dT%H:%M:%SZ") + content["owner"] = item["owner"]["login"] + content["name"] = item["name"] + content["description"] = ( + item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"] + ) + content["url"] = item["html_url"] + content["star"] = item["watchers"] + content["forks"] = item["forks"] + content["updated"] = updated_at_object.strftime("%Y-%m-%d") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) else: - return self.create_text_message(f'No items related to {query} were found.') + return self.create_text_message(f"No items related to {query} were found.") else: - return self.create_text_message((response.json()).get('message')) + return self.create_text_message((response.json()).get("message")) except Exception as e: return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py index 0c13ec662a..9bd4a0bd52 100644 --- a/api/core/tools/provider/builtin/gitlab/gitlab.py +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -9,13 +9,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GitlabProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.") - - if 'site_url' not in credentials or not credentials.get('site_url'): - site_url = 'https://gitlab.com' + + if "site_url" not in credentials or not credentials.get("site_url"): + site_url = "https://gitlab.com" else: - site_url = credentials.get('site_url') + site_url = credentials.get("site_url") try: headers = { @@ -23,12 +23,10 @@ class GitlabProvider(BuiltinToolProviderController): "Authorization": f"Bearer {credentials.get('access_tokens')}", } - response = requests.get( - url= f"{site_url}/api/v4/user", - headers=headers) + response = requests.get(url=f"{site_url}/api/v4/user", headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e)) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index 0824eb3a26..45ab15f437 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -1,4 +1,5 @@ import json +import urllib.parse from datetime import datetime, timedelta from typing import Any, Union @@ -9,103 +10,133 @@ from core.tools.tool.builtin_tool import BuiltinTool class GitlabCommitsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") + employee = tool_parameters.get("employee", "") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + change_type = tool_parameters.get("change_type", "all") - project = tool_parameters.get('project', '') - employee = tool_parameters.get('employee', '') - start_time = tool_parameters.get('start_time', '') - end_time = tool_parameters.get('end_time', '') - change_type = tool_parameters.get('change_type', 'all') - - if not project: - return self.create_text_message('Project is required') + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not start_time: start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() if not end_time: end_time = datetime.utcnow().isoformat() - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" + # Get commit content - result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) + if repository: + result = self.fetch_commits( + site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True + ) + else: + result = self.fetch_commits( + site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False + ) return [self.create_json_message(item) for item in result] - - def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]: + + def fetch_commits( + self, + site_url: str, + access_token: str, + identifier: str, + employee: str, + start_time: str, + end_time: str, + change_type: str, + is_repository: bool, + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] try: - # Get all of projects - url = f"{domain}/api/v4/projects" - response = requests.get(url, headers=headers) - response.raise_for_status() - projects = response.json() + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits" + else: + # Get all projects + url = f"{domain}/api/v4/projects" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() - filtered_projects = [p for p in projects if project == "*" or p['name'] == project] + filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier] - for project in filtered_projects: - project_id = project['id'] - project_name = project['name'] - print(f"Project: {project_name}") + for project in filtered_projects: + project_id = project["id"] + project_name = project["name"] + print(f"Project: {project_name}") - # Get all of project commits - commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - params = { - 'since': start_time, - 'until': end_time - } - if employee: - params['author'] = employee + commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - commits_response = requests.get(commits_url, headers=headers, params=params) - commits_response.raise_for_status() - commits = commits_response.json() + params = {"since": start_time, "until": end_time} + if employee: + params["author"] = employee - for commit in commits: - commit_sha = commit['id'] - author_name = commit['author_name'] + commits_response = requests.get(commits_url, headers=headers, params=params) + commits_response.raise_for_status() + commits = commits_response.json() + for commit in commits: + commit_sha = commit["id"] + author_name = commit["author_name"] + + if is_repository: + diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff" + else: diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" - diff_response = requests.get(diff_url, headers=headers) - diff_response.raise_for_status() - diffs = diff_response.json() - - for diff in diffs: - # Calculate code lines of changed - added_lines = diff['diff'].count('\n+') - removed_lines = diff['diff'].count('\n-') - total_changes = added_lines + removed_lines - if change_type == "new": - if added_lines > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code - }) - else: - if total_changes > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')]) - final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code_escaped - }) + diff_response = requests.get(diff_url, headers=headers) + diff_response.raise_for_status() + diffs = diff_response.json() + + for diff in diffs: + # Calculate code lines of changes + added_lines = diff["diff"].count("\n+") + removed_lines = diff["diff"].count("\n-") + total_changes = added_lines + removed_lines + + if change_type == "new": + if added_lines > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if line.startswith("+") and not line.startswith("+++") + ] + ) + results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}) + else: + if total_changes > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if (line.startswith("+") or line.startswith("-")) + and not line.startswith("+++") + and not line.startswith("---") + ] + ) + final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} + ) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml index d38d943958..669378ac97 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -21,9 +21,20 @@ parameters: zh_Hans: 员工用户名 llm_description: User name for GitLab form: llm + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. + form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目名 diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 7fa1d0d112..1e77f3c6df 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -1,3 +1,4 @@ +import urllib.parse from typing import Any, Union import requests @@ -7,47 +8,85 @@ from core.tools.tool.builtin_tool import BuiltinTool class GitlabFilesTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - branch = tool_parameters.get('branch', '') - path = tool_parameters.get('path', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") + branch = tool_parameters.get("branch", "") + path = tool_parameters.get("path", "") - - if not project: - return self.create_text_message('Project is required') + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not branch: - return self.create_text_message('Branch is required') - + return self.create_text_message("Branch is required") if not path: - return self.create_text_message('Path is required') + return self.create_text_message("Path is required") - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - - # Get project ID from project name - project_id = self.get_project_id(site_url, access_token, project) - if not project_id: - return self.create_text_message(f"Project '{project}' not found.") + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" - # Get commit content - result = self.fetch(user_id, project_id, site_url, access_token, branch, path) + # Get file content + if repository: + result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True) + else: + result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False) return [self.create_json_message(item) for item in result] - - def extract_project_name_and_path(self, path: str) -> tuple[str, str]: - parts = path.split('/', 1) - if len(parts) < 2: - return None, None - return parts[0], parts[1] + + def fetch_files( + self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool + ) -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}" + else: + # Get project ID from project name + project_id = self.get_project_id(site_url, access_token, identifier) + if not project_id: + return self.create_text_message(f"Project '{identifier}' not found.") + tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" + + response = requests.get(tree_url, headers=headers) + response.raise_for_status() + items = response.json() + + for item in items: + item_path = item["path"] + if item["type"] == "tree": # It's a directory + results.extend( + self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository) + ) + else: # It's a file + if is_repository: + file_url = ( + f"{domain}/api/v4/projects/{encoded_identifier}/repository/files" + f"/{item_path}/raw?ref={branch}" + ) + else: + file_url = ( + f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + ) + + file_response = requests.get(file_url, headers=headers) + file_response.raise_for_status() + file_content = file_response.text + results.append({"path": item_path, "branch": branch, "content": file_content}) + except requests.RequestException as e: + print(f"Error fetching data from GitLab: {e}") + + return results def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: headers = {"PRIVATE-TOKEN": access_token} @@ -57,39 +96,8 @@ class GitlabFilesTool(BuiltinTool): response.raise_for_status() projects = response.json() for project in projects: - if project['name'] == project_name: - return project['id'] + if project["name"] == project_name: + return project["id"] except requests.RequestException as e: print(f"Error fetching project ID from GitLab: {e}") return None - - def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]: - domain = site_url - headers = {"PRIVATE-TOKEN": access_token} - results = [] - - try: - # List files and directories in the given path - url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" - response = requests.get(url, headers=headers) - response.raise_for_status() - items = response.json() - - for item in items: - item_path = item['path'] - if item['type'] == 'tree': # It's a directory - results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) - else: # It's a file - file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" - file_response = requests.get(file_url, headers=headers) - file_response.raise_for_status() - file_content = file_response.text - results.append({ - "path": item_path, - "branch": branch, - "content": file_content - }) - except requests.RequestException as e: - print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml index d99b6254c1..4c733673f1 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml @@ -10,9 +10,20 @@ description: zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。 llm: A tool for query GitLab files, Input should be a exists file or directory path. parameters: + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. + form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目 diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 8f4b9a4a4e..6b5395f9d3 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -13,12 +13,8 @@ class GoogleProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 09d0326fb4..a9f65925d8 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -9,7 +9,6 @@ SERP_API_URL = "https://serpapi.com/search" class GoogleSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledge_graph" in response: @@ -17,25 +16,23 @@ class GoogleSearchTool(BuiltinTool): result["description"] = response["knowledge_graph"].get("description", "") if "organic_results" in response: result["organic_results"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic_results"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: params = { - "api_key": self.runtime.credentials['serpapi_api_key'], - "q": tool_parameters['query'], + "api_key": self.runtime.credentials["serpapi_api_key"], + "q": tool_parameters["query"], "engine": "google", "google_domain": "google.com", "gl": "us", - "hl": "en" + "hl": "en", } response = requests.get(url=SERP_API_URL, params=params) response.raise_for_status() diff --git a/api/core/tools/provider/builtin/google_translate/google_translate.py b/api/core/tools/provider/builtin/google_translate/google_translate.py index f6e1d65834..ea53aa4eeb 100644 --- a/api/core/tools/provider/builtin/google_translate/google_translate.py +++ b/api/core/tools/provider/builtin/google_translate/google_translate.py @@ -8,10 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - GoogleTranslate().invoke(user_id='', - tool_parameters={ - "content": "这是一段测试文本", - "dest": "en" - }) + GoogleTranslate().invoke(user_id="", tool_parameters={"content": "这是一段测试文本", "dest": "en"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/google_translate/tools/translate.py b/api/core/tools/provider/builtin/google_translate/tools/translate.py index 4314182b06..ea3f2077d5 100644 --- a/api/core/tools/provider/builtin/google_translate/tools/translate.py +++ b/api/core/tools/provider/builtin/google_translate/tools/translate.py @@ -7,46 +7,41 @@ from core.tools.tool.builtin_tool import BuiltinTool class GoogleTranslate(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - dest = tool_parameters.get('dest', '') + dest = tool_parameters.get("dest", "") if not dest: - return self.create_text_message('Invalid parameter destination language') + return self.create_text_message("Invalid parameter destination language") try: result = self._translate(content, dest) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Translation service error, please check the network') + return self.create_text_message("Translation service error, please check the network") def _translate(self, content: str, dest: str) -> str: try: url = "https://translate.googleapis.com/translate_a/single" - params = { - "client": "gtx", - "sl": "auto", - "tl": dest, - "dt": "t", - "q": content - } + params = {"client": "gtx", "sl": "auto", "tl": dest, "dt": "t", "q": content} headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" } - response_json = requests.get( - url, params=params, headers=headers).json() + response_json = requests.get(url, params=params, headers=headers).json() result = response_json[0] - translated_text = ''.join([item[0] for item in result if item[0]]) + translated_text = "".join([item[0] for item in result if item[0]]) return str(translated_text) except Exception as e: return str(e) diff --git a/api/core/tools/provider/builtin/hap/hap.py b/api/core/tools/provider/builtin/hap/hap.py index e0a48e05a5..cbdf950465 100644 --- a/api/core/tools/provider/builtin/hap/hap.py +++ b/api/core/tools/provider/builtin/hap/hap.py @@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class HapProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py index 0e101dc67d..597adc91db 100644 --- a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py @@ -8,41 +8,40 @@ from core.tools.tool.builtin_tool import BuiltinTool class AddWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Worksheet ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/addRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to add the new record. {res_json['error_msg']}") return self.create_text_message(f"New record added successfully. The record ID is {res_json['data']}.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py index ba25952c9f..5d42af4c49 100644 --- a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py @@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/deleteRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=30) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to delete the record. {res_json['error_msg']}") return self.create_text_message("Successfully deleted the record.") except httpx.RequestError as e: return self.create_text_message(f"Failed to delete the record, request error: {e}") except Exception as e: - return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") \ No newline at end of file + return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 2c46d9dd4e..6887b8b4e9 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -8,43 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetWorksheetFieldsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Worksheet ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to get the worksheet information. {res_json['error_msg']}") - - fields_json, fields_table = self.get_controls(res_json['data']['controls']) - result_type = tool_parameters.get('result_type', 'table') + + fields_json, fields_table = self.get_controls(res_json["data"]["controls"]) + result_type = tool_parameters.get("result_type", "table") return self.create_text_message( - text=json.dumps(fields_json, ensure_ascii=False) if result_type == 'json' else fields_table + text=json.dumps(fields_json, ensure_ascii=False) if result_type == "json" else fields_table ) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet information, request error: {e}") @@ -88,61 +87,66 @@ class GetWorksheetFieldsTool(BuiltinTool): 50: "Text", 51: "Query Record", } - return field_type_map.get(field_type_id, '') + return field_type_map.get(field_type_id, "") def get_controls(self, controls: list) -> dict: fields = [] - fields_list = ['|fieldId|fieldName|fieldType|fieldTypeId|description|options|','|'+'---|'*6] + fields_list = ["|fieldId|fieldName|fieldType|fieldTypeId|description|options|", "|" + "---|" * 6] for control in controls: - if control['type'] in self._get_ignore_types(): + if control["type"] in self._get_ignore_types(): continue - field_type_id = control['type'] - field_type = self.get_field_type_by_id(control['type']) + field_type_id = control["type"] + field_type = self.get_field_type_by_id(control["type"]) if field_type_id == 30: - source_type = control['sourceControl']['type'] + source_type = control["sourceControl"]["type"] if source_type in self._get_ignore_types(): continue else: field_type_id = source_type field_type = self.get_field_type_by_id(source_type) field = { - 'id': control['controlId'], - 'name': control['controlName'], - 'type': field_type, - 'typeId': field_type_id, - 'description': control['remark'].replace('\n', ' ').replace('\t', ' '), - 'options': self._extract_options(control), + "id": control["controlId"], + "name": control["controlName"], + "type": field_type, + "typeId": field_type_id, + "description": control["remark"].replace("\n", " ").replace("\t", " "), + "options": self._extract_options(control), } fields.append(field) - fields_list.append(f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|") + fields_list.append( + f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}" + f"|{field['options'] or ''}|" + ) - fields.append({ - 'id': 'ctime', - 'name': 'Created Time', - 'type': self.get_field_type_by_id(16), - 'typeId': 16, - 'description': '', - 'options': [] - }) + fields.append( + { + "id": "ctime", + "name": "Created Time", + "type": self.get_field_type_by_id(16), + "typeId": 16, + "description": "", + "options": [], + } + ) fields_list.append("|ctime|Created Time|Date|16|||") - return fields, '\n'.join(fields_list) + return fields, "\n".join(fields_list) def _extract_options(self, control: dict) -> list: options = [] - if control['type'] in [9, 10, 11]: - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) - elif control['type'] in [28, 36]: - itemnames = control['advancedSetting'].get('itemnames') - if itemnames and itemnames.startswith('[{'): + if control["type"] in {9, 10, 11}: + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) + elif control["type"] in {28, 36}: + itemnames = control["advancedSetting"].get("itemnames") + if itemnames and itemnames.startswith("[{"): try: options = json.loads(itemnames) except json.JSONDecodeError: pass - elif control['type'] == 30: - source_type = control['sourceControl']['type'] + elif control["type"] == 30: + source_type = control["sourceControl"]["type"] if source_type not in self._get_ignore_types(): - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) return options - + def _get_ignore_types(self): - return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} \ No newline at end of file + return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py index 6bf1caa65e..26d7116869 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py @@ -8,64 +8,66 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetWorksheetPivotDataTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - x_column_fields = tool_parameters.get('x_column_fields', '') - if not x_column_fields or not x_column_fields.startswith('['): - return self.create_text_message('Invalid parameter Column Fields') - y_row_fields = tool_parameters.get('y_row_fields', '') - if y_row_fields and not y_row_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Row Fields') + return self.create_text_message("Invalid parameter Worksheet ID") + x_column_fields = tool_parameters.get("x_column_fields", "") + if not x_column_fields or not x_column_fields.startswith("["): + return self.create_text_message("Invalid parameter Column Fields") + y_row_fields = tool_parameters.get("y_row_fields", "") + if y_row_fields and not y_row_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Row Fields") elif not y_row_fields: - y_row_fields = '[]' - value_fields = tool_parameters.get('value_fields', '') - if not value_fields or not value_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Value Fields') - - host = tool_parameters.get('host', '') + y_row_fields = "[]" + value_fields = tool_parameters.get("value_fields", "") + if not value_fields or not value_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Value Fields") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/report/getPivotData" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "options": {"showTotal": True}} try: x_column_fields = json.loads(x_column_fields) - payload['columns'] = x_column_fields + payload["columns"] = x_column_fields y_row_fields = json.loads(y_row_fields) - if y_row_fields: payload['rows'] = y_row_fields + if y_row_fields: + payload["rows"] = y_row_fields value_fields = json.loads(value_fields) - payload['values'] = value_fields - sort_fields = tool_parameters.get('sort_fields', '') - if not sort_fields: sort_fields = '[]' + payload["values"] = value_fields + sort_fields = tool_parameters.get("sort_fields", "") + if not sort_fields: + sort_fields = "[]" sort_fields = json.loads(sort_fields) - if sort_fields: payload['options']['sort'] = sort_fields + if sort_fields: + payload["options"]["sort"] = sort_fields res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('status') != 1: + if res_json.get("status") != 1: return self.create_text_message(f"Failed to get the worksheet pivot data. {res_json['msg']}") - - pivot_json = self.generate_pivot_json(res_json['data']) - pivot_table = self.generate_pivot_table(res_json['data']) - result_type = tool_parameters.get('result_type', '') - text = pivot_table if result_type == 'table' else json.dumps(pivot_json, ensure_ascii=False) + + pivot_json = self.generate_pivot_json(res_json["data"]) + pivot_table = self.generate_pivot_table(res_json["data"]) + result_type = tool_parameters.get("result_type", "") + text = pivot_table if result_type == "table" else json.dumps(pivot_json, ensure_ascii=False) return self.create_text_message(text) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet pivot data, request error: {e}") @@ -75,27 +77,31 @@ class GetWorksheetPivotDataTool(BuiltinTool): return self.create_text_message(f"Failed to get the worksheet pivot data, unexpected error: {e}") def generate_pivot_table(self, data: dict[str, Any]) -> str: - columns = data['metadata']['columns'] - rows = data['metadata']['rows'] - values = data['metadata']['values'] + columns = data["metadata"]["columns"] + rows = data["metadata"]["rows"] + values = data["metadata"]["values"] - rows_data = data['data'] + rows_data = data["data"] - header = ([row['displayName'] for row in rows] if rows else []) + [column['displayName'] for column in columns] + [value['displayName'] for value in values] - line = (['---'] * len(rows) if rows else []) + ['---'] * len(columns) + ['--:'] * len(values) + header = ( + ([row["displayName"] for row in rows] if rows else []) + + [column["displayName"] for column in columns] + + [value["displayName"] for value in values] + ) + line = (["---"] * len(rows) if rows else []) + ["---"] * len(columns) + ["--:"] * len(values) table = [header, line] for row in rows_data: - row_data = [self.replace_pipe(row['rows'][r['controlId']]) for r in rows] if rows else [] - row_data.extend([self.replace_pipe(row['columns'][column['controlId']]) for column in columns]) - row_data.extend([self.replace_pipe(str(row['values'][value['controlId']])) for value in values]) + row_data = [self.replace_pipe(row["rows"][r["controlId"]]) for r in rows] if rows else [] + row_data.extend([self.replace_pipe(row["columns"][column["controlId"]]) for column in columns]) + row_data.extend([self.replace_pipe(str(row["values"][value["controlId"]])) for value in values]) table.append(row_data) - return '\n'.join([('|'+'|'.join(row) +'|') for row in table]) - + return "\n".join([("|" + "|".join(row) + "|") for row in table]) + def replace_pipe(self, text: str) -> str: - return text.replace('|', '▏').replace('\n', ' ') - + return text.replace("|", "▏").replace("\n", " ") + def generate_pivot_json(self, data: dict[str, Any]) -> dict: fields = { "x-axis": [ @@ -103,13 +109,14 @@ class GetWorksheetPivotDataTool(BuiltinTool): for column in data["metadata"]["columns"] ], "y-axis": [ - {"fieldId": row["controlId"], "fieldName": row["displayName"]} - for row in data["metadata"]["rows"] - ] if data["metadata"]["rows"] else [], + {"fieldId": row["controlId"], "fieldName": row["displayName"]} for row in data["metadata"]["rows"] + ] + if data["metadata"]["rows"] + else [], "values": [ {"fieldId": value["controlId"], "fieldName": value["displayName"]} for value in data["metadata"]["values"] - ] + ], } # fields = ([ # {"fieldId": row["controlId"], "fieldName": row["displayName"]} @@ -123,8 +130,8 @@ class GetWorksheetPivotDataTool(BuiltinTool): # ] rows = [] for row in data["data"]: - row_data = row["rows"] if row["rows"] else {} + row_data = row["rows"] or {} row_data.update(row["columns"]) row_data.update(row["values"]) rows.append(row_data) - return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} \ No newline at end of file + return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index dddc041cc1..d6ac3688b7 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -9,191 +9,213 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListWorksheetRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') + return self.create_text_message("Invalid parameter App Key") - sign = tool_parameters.get('sign', '') + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') + return self.create_text_message("Invalid parameter Sign") - worksheet_id = tool_parameters.get('worksheet_id', '') + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') + return self.create_text_message("Invalid parameter Worksheet ID") - host = tool_parameters.get('host', '') + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" - + host = f"{host.removesuffix('/')}/api" + url_fields = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} - field_ids = tool_parameters.get('field_ids', '') + field_ids = tool_parameters.get("field_ids", "") try: res = httpx.post(url_fields, headers=headers, json=payload, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the worksheet information. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to get the worksheet information. {}".format(res_json["error_msg"]) + ) else: - worksheet_name = res_json['data']['name'] - fields, schema, table_header = self.get_schema(res_json['data']['controls'], field_ids) + worksheet_name = res_json["data"]["name"] + fields, schema, table_header = self.get_schema(res_json["data"]["controls"], field_ids) else: return self.create_text_message( - f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}") + f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to get the worksheet information, something went wrong: {}".format(e)) + return self.create_text_message( + "Failed to get the worksheet information, something went wrong: {}".format(e) + ) if field_ids: - payload['controls'] = [v.strip() for v in field_ids.split(',')] if field_ids else [] - filters = tool_parameters.get('filters', '') + payload["controls"] = [v.strip() for v in field_ids.split(",")] if field_ids else [] + filters = tool_parameters.get("filters", "") if filters: - payload['filters'] = json.loads(filters) - sort_id = tool_parameters.get('sort_id', '') - sort_is_asc = tool_parameters.get('sort_is_asc', False) + payload["filters"] = json.loads(filters) + sort_id = tool_parameters.get("sort_id", "") + sort_is_asc = tool_parameters.get("sort_is_asc", False) if sort_id: - payload['sortId'] = sort_id - payload['isAsc'] = sort_is_asc - limit = tool_parameters.get('limit', 50) - payload['pageSize'] = limit - page_index = tool_parameters.get('page_index', 1) - payload['pageIndex'] = page_index - payload['useControlId'] = True - payload['listType'] = 1 + payload["sortId"] = sort_id + payload["isAsc"] = sort_is_asc + limit = tool_parameters.get("limit", 50) + payload["pageSize"] = limit + page_index = tool_parameters.get("page_index", 1) + payload["pageIndex"] = page_index + payload["useControlId"] = True + payload["listType"] = 1 url = f"{host}/v2/open/worksheet/getFilterRows" try: res = httpx.post(url, headers=headers, json=payload, timeout=90) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the records. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message("Failed to get the records. {}".format(res_json["error_msg"])) else: result = { "fields": fields, "rows": [], "total": res_json.get("data", {}).get("total"), - "payload": {key: payload[key] for key in ['worksheetId', 'controls', 'filters', 'sortId', 'isAsc', 'pageSize', 'pageIndex'] if key in payload} + "payload": { + key: payload[key] + for key in [ + "worksheetId", + "controls", + "filters", + "sortId", + "isAsc", + "pageSize", + "pageIndex", + ] + if key in payload + }, } rows = res_json.get("data", {}).get("rows", []) - result_type = tool_parameters.get('result_type', '') - if not result_type: result_type = 'table' - if result_type == 'json': + result_type = tool_parameters.get("result_type", "") + if not result_type: + result_type = "table" + if result_type == "json": for row in rows: - result['rows'].append(self.get_row_field_value(row, schema)) + result["rows"].append(self.get_row_field_value(row, schema)) return self.create_text_message(json.dumps(result, ensure_ascii=False)) else: result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." - if result['total'] > 0: - result_text += f" The following are {result['total'] if result['total'] < limit else limit} pieces of data presented in a table format:\n\n{table_header}" + if result["total"] > 0: + result_text += ( + f" The following are {min(limit, result['total'])}" + f" pieces of data presented in a table format:\n\n{table_header}" + ) for row in rows: result_values = [] for f in fields: - result_values.append(self.handle_value_type(row[f['fieldId']], schema[f['fieldId']])) - result_text += '\n|'+'|'.join(result_values)+'|' + result_values.append( + self.handle_value_type(row[f["fieldId"]], schema[f["fieldId"]]) + ) + result_text += "\n|" + "|".join(result_values) + "|" return self.create_text_message(result_text) else: return self.create_text_message( - f"Failed to get the records, status code: {res.status_code}, response: {res.text}") + f"Failed to get the records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get the records, something went wrong: {}".format(e)) - def get_row_field_value(self, row: dict, schema: dict): row_value = {"rowid": row["rowid"]} for field in schema: row_value[field] = self.handle_value_type(row[field], schema[field]) return row_value - - def get_schema(self, controls: list, fieldids: str): - allow_fields = {v.strip() for v in fieldids.split(',')} if fieldids else set() + def get_schema(self, controls: list, fieldids: str): + allow_fields = {v.strip() for v in fieldids.split(",")} if fieldids else set() fields = [] schema = {} field_names = [] for control in controls: control_type_id = self.get_real_type_id(control) - if (control_type_id in self._get_ignore_types()) or (allow_fields and not control['controlId'] in allow_fields): + if (control_type_id in self._get_ignore_types()) or ( + allow_fields and control["controlId"] not in allow_fields + ): continue else: - fields.append({'fieldId': control['controlId'], 'fieldName': control['controlName']}) - schema[control['controlId']] = {'typeId': control_type_id, 'options': self.set_option(control)} - field_names.append(control['controlName']) - if (not allow_fields or ('ctime' in allow_fields)): - fields.append({'fieldId': 'ctime', 'fieldName': 'Created Time'}) - schema['ctime'] = {'typeId': 16, 'options': {}} + fields.append({"fieldId": control["controlId"], "fieldName": control["controlName"]}) + schema[control["controlId"]] = {"typeId": control_type_id, "options": self.set_option(control)} + field_names.append(control["controlName"]) + if not allow_fields or ("ctime" in allow_fields): + fields.append({"fieldId": "ctime", "fieldName": "Created Time"}) + schema["ctime"] = {"typeId": 16, "options": {}} field_names.append("Created Time") - fields.append({'fieldId':'rowid', 'fieldName': 'Record Row ID'}) - schema['rowid'] = {'typeId': 2, 'options': {}} + fields.append({"fieldId": "rowid", "fieldName": "Record Row ID"}) + schema["rowid"] = {"typeId": 2, "options": {}} field_names.append("Record Row ID") - return fields, schema, '|'+'|'.join(field_names)+'|\n|'+'---|'*len(field_names) - + return fields, schema, "|" + "|".join(field_names) + "|\n|" + "---|" * len(field_names) + def get_real_type_id(self, control: dict) -> int: - return control['sourceControlType'] if control['type'] == 30 else control['type'] - + return control["sourceControlType"] if control["type"] == 30 else control["type"] + def set_option(self, control: dict) -> dict: options = {} - if control.get('options'): - options = {option['key']: option['value'] for option in control['options']} - elif control.get('advancedSetting', {}).get('itemnames'): + if control.get("options"): + options = {option["key"]: option["value"] for option in control["options"]} + elif control.get("advancedSetting", {}).get("itemnames"): try: - itemnames = json.loads(control['advancedSetting']['itemnames']) - options = {item['key']: item['value'] for item in itemnames} + itemnames = json.loads(control["advancedSetting"]["itemnames"]) + options = {item["key"]: item["value"] for item in itemnames} except json.JSONDecodeError: pass return options def _get_ignore_types(self): return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} - + def handle_value_type(self, value, field): type_id = field.get("typeId") if type_id == 10: value = value if isinstance(value, str) else "、".join(value) - elif type_id in [28, 36]: + elif type_id in {28, 36}: value = field.get("options", {}).get(value, value) - elif type_id in [26, 27, 48, 14]: + elif type_id in {26, 27, 48, 14}: value = self.process_value(value) - elif type_id in [35, 29]: + elif type_id in {35, 29}: value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) - return self.rich_text_to_plain_text(value) if value else '' + return self.rich_text_to_plain_text(value) if value else "" def process_value(self, value): if isinstance(value, str): - if value.startswith("[{\"accountId\""): + if value.startswith('[{"accountId"'): value = json.loads(value) - value = ', '.join([item['fullname'] for item in value]) - elif value.startswith("[{\"departmentId\""): + value = ", ".join([item["fullname"] for item in value]) + elif value.startswith('[{"departmentId"'): value = json.loads(value) - value = '、'.join([item['departmentName'] for item in value]) - elif value.startswith("[{\"organizeId\""): + value = "、".join([item["departmentName"] for item in value]) + elif value.startswith('[{"organizeId"'): value = json.loads(value) - value = '、'.join([item['organizeName'] for item in value]) - elif value.startswith("[{\"file_id\""): - value = '' - elif value == '[]': - value = '' - elif hasattr(value, 'accountId'): - value = value['fullname'] + value = "、".join([item["organizeName"] for item in value]) + elif value.startswith('[{"file_id"') or value == "[]": + value = "" + elif hasattr(value, "accountId"): + value = value["fullname"] return value def parse_cascade_or_associated(self, field, value): - if (field['typeId'] == 35 and value.startswith('[')) or (field['typeId'] == 29 and value.startswith('[{')): + if (field["typeId"] == 35 and value.startswith("[")) or (field["typeId"] == 29 and value.startswith("[{")): value = json.loads(value) - value = value[0]['name'] if len(value) > 0 else '' + value = value[0]["name"] if len(value) > 0 else "" else: - value = '' + value = "" return value def parse_location(self, value): @@ -205,5 +227,5 @@ class ListWorksheetRecordsTool(BuiltinTool): return value def rich_text_to_plain_text(self, rich_text): - text = re.sub(r'<[^>]+>', '', rich_text) if '<' in rich_text else rich_text - return text.replace("|", "▏").replace("\n", " ") \ No newline at end of file + text = re.sub(r"<[^>]+>", "", rich_text) if "<" in rich_text else rich_text + return text.replace("|", "▏").replace("\n", " ") diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index 960cbd10ac..4e852c0028 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -8,75 +8,76 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListWorksheetsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Sign") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v1/open/app/get" - result_type = tool_parameters.get('result_type', '') + result_type = tool_parameters.get("result_type", "") if not result_type: - result_type = 'table' + result_type = "table" - headers = { 'Content-Type': 'application/json' } - params = { "appKey": appkey, "sign": sign, } + headers = {"Content-Type": "application/json"} + params = { + "appKey": appkey, + "sign": sign, + } try: res = httpx.get(url, headers=headers, params=params, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to access the application. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to access the application. {}".format(res_json["error_msg"]) + ) else: - if result_type == 'json': + if result_type == "json": worksheets = [] - for section in res_json['data']['sections']: + for section in res_json["data"]["sections"]: worksheets.extend(self._extract_worksheets(section, result_type)) return self.create_text_message(text=json.dumps(worksheets, ensure_ascii=False)) else: - worksheets = '|worksheetId|worksheetName|description|\n|---|---|---|' - for section in res_json['data']['sections']: + worksheets = "|worksheetId|worksheetName|description|\n|---|---|---|" + for section in res_json["data"]["sections"]: worksheets += self._extract_worksheets(section, result_type) return self.create_text_message(worksheets) else: return self.create_text_message( - f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}") + f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list worksheets, something went wrong: {}".format(e)) def _extract_worksheets(self, section, type): items = [] - tables = '' - for item in section.get('items', []): - if item.get('type') == 0 and (not 'notes' in item or item.get('notes') != 'NO'): - if type == 'json': - filtered_item = { - 'id': item['id'], - 'name': item['name'], - 'notes': item.get('notes', '') - } + tables = "" + for item in section.get("items", []): + if item.get("type") == 0 and ("notes" not in item or item.get("notes") != "NO"): + if type == "json": + filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")} items.append(filtered_item) else: tables += f"\n|{item['id']}|{item['name']}|{item.get('notes', '')}|" - for child_section in section.get('childSections', []): - if type == 'json': - items.extend(self._extract_worksheets(child_section, 'json')) + for child_section in section.get("childSections", []): + if type == "json": + items.extend(self._extract_worksheets(child_section, "json")) else: - tables += self._extract_worksheets(child_section, 'table') - - return items if type == 'json' else tables \ No newline at end of file + tables += self._extract_worksheets(child_section, "table") + + return items if type == "json" else tables diff --git a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py index 6ca1b98d90..971f3d37f6 100644 --- a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py @@ -8,44 +8,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class UpdateWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Record Row ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/editRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to update the record. {res_json['error_msg']}") return self.create_text_message("Record updated successfully.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py index 12e5058cdc..154e15db01 100644 --- a/api/core/tools/provider/builtin/jina/jina.py +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -10,27 +10,29 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if credentials['api_key'] is None: - credentials['api_key'] = '' + if credentials["api_key"] is None: + credentials["api_key"] = "" else: - result = JinaReaderTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - "url": "https://example.com", - }, - )[0] + result = ( + JinaReaderTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "url": "https://example.com", + }, + )[0] + ) message = json.loads(result.message) - if message['code'] != 200: - raise ToolProviderCredentialValidationError(message['message']) + if message["code"] != 200: + raise ToolProviderCredentialValidationError(message["message"]) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - + def _get_tool_labels(self) -> list[ToolLabelEnum]: - return [ - ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY - ] \ No newline at end of file + return [ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY] diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index cee46cee23..0dd55c6529 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -9,26 +9,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class JinaReaderTool(BuiltinTool): - _jina_reader_endpoint = 'https://r.jina.ai/' + _jina_reader_endpoint = "https://r.jina.ai/" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - url = tool_parameters['url'] + url = tool_parameters["url"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - request_params = tool_parameters.get('request_params') - if request_params is not None and request_params != '': + request_params = tool_parameters.get("request_params") + if request_params is not None and request_params != "": try: request_params = json.loads(request_params) if not isinstance(request_params, dict): @@ -36,40 +35,40 @@ class JinaReaderTool(BuiltinTool): except (json.JSONDecodeError, ValueError) as e: raise ValueError(f"Invalid request_params: {e}") - target_selector = tool_parameters.get('target_selector') - if target_selector is not None and target_selector != '': - headers['X-Target-Selector'] = target_selector + target_selector = tool_parameters.get("target_selector") + if target_selector is not None and target_selector != "": + headers["X-Target-Selector"] = target_selector - wait_for_selector = tool_parameters.get('wait_for_selector') - if wait_for_selector is not None and wait_for_selector != '': - headers['X-Wait-For-Selector'] = wait_for_selector + wait_for_selector = tool_parameters.get("wait_for_selector") + if wait_for_selector is not None and wait_for_selector != "": + headers["X-Wait-For-Selector"] = wait_for_selector - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, params=request_params, timeout=(10, 60), - max_retries=max_retries + max_retries=max_retries, ) - if tool_parameters.get('summary', False): + if tool_parameters.get("summary", False): return self.create_text_message(self.summary(user_id, response.text)) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index d4a81cd096..30af6de783 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -8,44 +8,39 @@ from core.tools.tool.builtin_tool import BuiltinTool class JinaSearchTool(BuiltinTool): - _jina_search_endpoint = 'https://s.jina.ai/' + _jina_search_endpoint = "https://s.jina.ai/" def _invoke( self, user_id: str, tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters['query'] + query = tool_parameters["query"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( - str(URL(self._jina_search_endpoint + query)), - headers=headers, - timeout=(10, 60), - max_retries=max_retries + str(URL(self._jina_search_endpoint + query)), headers=headers, timeout=(10, 60), max_retries=max_retries ) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py index 0d018e3ca2..06dabcc9c2 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py @@ -6,33 +6,29 @@ from core.tools.tool.builtin_tool import BuiltinTool class JinaTokenizerTool(BuiltinTool): - _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/' + _jina_tokenizer_endpoint = "https://tokenize.jina.ai/" def _invoke( self, user_id: str, tool_parameters: dict[str, Any], ) -> ToolInvokeMessage: - content = tool_parameters['content'] - body = { - "content": content - } + content = tool_parameters["content"] + body = {"content": content} - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - if tool_parameters.get('return_chunks', False): - body['return_chunks'] = True - - if tool_parameters.get('return_tokens', False): - body['return_tokens'] = True - - if tokenizer := tool_parameters.get('tokenizer'): - body['tokenizer'] = tokenizer + if tool_parameters.get("return_chunks", False): + body["return_chunks"] = True + + if tool_parameters.get("return_tokens", False): + body["return_tokens"] = True + + if tokenizer := tool_parameters.get("tokenizer"): + body["tokenizer"] = tokenizer response = ssrf_proxy.post( self._jina_tokenizer_endpoint, diff --git a/api/core/tools/provider/builtin/json_process/json_process.py b/api/core/tools/provider/builtin/json_process/json_process.py index f6eed3c628..10746210b5 100644 --- a/api/core/tools/provider/builtin/json_process/json_process.py +++ b/api/core/tools/provider/builtin/json_process/json_process.py @@ -8,10 +8,9 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - JSONParseTool().invoke(user_id='', - tool_parameters={ - 'content': '{"name": "John", "age": 30, "city": "New York"}', - 'json_filter': '$.name' - }) + JSONParseTool().invoke( + user_id="", + tool_parameters={"content": '{"name": "John", "age": 30, "city": "New York"}', "json_filter": "$.name"}, + ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index 1b49cfe2f3..fcab3d71a9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -8,34 +8,35 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONDeleteTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the JSON delete tool """ # Get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # Get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._delete(content, query, ensure_ascii) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to delete JSON content: {str(e)}') + return self.create_text_message(f"Failed to delete JSON content: {str(e)}") def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str: try: input_data = json.loads(origin_json) - expr = parse('$.' + query.lstrip('$.')) # Ensure query path starts with $ + expr = parse("$." + query.lstrip("$.")) # Ensure query path starts with $ matches = expr.find(input_data) diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 48d1bdcab4..793c74e5f9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -8,46 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get new value - new_value = tool_parameters.get('new_value', '') + new_value = tool_parameters.get("new_value", "") if not new_value: - return self.create_text_message('Invalid parameter new_value') + return self.create_text_message("Invalid parameter new_value") # get insert position - index = tool_parameters.get('index') + index = tool_parameters.get("index") # get create path - create_path = tool_parameters.get('create_path', False) + create_path = tool_parameters.get("create_path", False) # get value decode. # if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._insert(content, query, new_value, ensure_ascii, value_decode, index, create_path) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to insert JSON content') + return self.create_text_message("Failed to insert JSON content") - def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False): + def _insert( + self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False + ): try: input_data = json.loads(origin_json) expr = parse(query) @@ -61,13 +64,13 @@ class JSONParseTool(BuiltinTool): if not matches and create_path: # create new path - path_parts = query.strip('$').strip('.').split('.') + path_parts = query.strip("$").strip(".").split(".") current = input_data for i, part in enumerate(path_parts): - if '[' in part and ']' in part: + if "[" in part and "]" in part: # process array index - array_name, index = part.split('[') - index = int(index.rstrip(']')) + array_name, index = part.split("[") + index = int(index.rstrip("]")) if array_name not in current: current[array_name] = [] while len(current[array_name]) <= index: diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index ecd39113ae..37cae40153 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -8,29 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get json filter - json_filter = tool_parameters.get('json_filter', '') + json_filter = tool_parameters.get("json_filter", "") if not json_filter: - return self.create_text_message('Invalid parameter json_filter') + return self.create_text_message("Invalid parameter json_filter") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._extract(content, json_filter, ensure_ascii) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to extract JSON content') + return self.create_text_message("Failed to extract JSON content") # Extract data from JSON content def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str: diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index b19198aa93..383825c2d0 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -8,55 +8,60 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONReplaceTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get replace value - replace_value = tool_parameters.get('replace_value', '') + replace_value = tool_parameters.get("replace_value", "") if not replace_value: - return self.create_text_message('Invalid parameter replace_value') + return self.create_text_message("Invalid parameter replace_value") # get replace model - replace_model = tool_parameters.get('replace_model', '') + replace_model = tool_parameters.get("replace_model", "") if not replace_model: - return self.create_text_message('Invalid parameter replace_model') + return self.create_text_message("Invalid parameter replace_model") # get value decode. # if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: - if replace_model == 'pattern': + if replace_model == "pattern": # get replace pattern - replace_pattern = tool_parameters.get('replace_pattern', '') + replace_pattern = tool_parameters.get("replace_pattern", "") if not replace_pattern: - return self.create_text_message('Invalid parameter replace_pattern') - result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii, value_decode) - elif replace_model == 'key': + return self.create_text_message("Invalid parameter replace_pattern") + result = self._replace_pattern( + content, query, replace_pattern, replace_value, ensure_ascii, value_decode + ) + elif replace_model == "key": result = self._replace_key(content, query, replace_value, ensure_ascii) - elif replace_model == 'value': + elif replace_model == "value": result = self._replace_value(content, query, replace_value, ensure_ascii, value_decode) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to replace JSON content') + return self.create_text_message("Failed to replace JSON content") # Replace pattern - def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_pattern( + self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) @@ -102,7 +107,9 @@ class JSONReplaceTool(BuiltinTool): return str(e) # Replace value - def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_value( + self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py index bac6576797..50db74dd9e 100644 --- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py @@ -13,7 +13,7 @@ class Judge0CEProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "source_code": "print('hello world')", "language_id": 71, @@ -21,4 +21,3 @@ class Judge0CEProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py index 6031687c03..b8d654ff63 100644 --- a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py +++ b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py @@ -9,11 +9,13 @@ from core.tools.tool.builtin_tool import BuiltinTool class ExecuteCodeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ - api_key = self.runtime.credentials['X-RapidAPI-Key'] + api_key = self.runtime.credentials["X-RapidAPI-Key"] url = "https://judge0-ce.p.rapidapi.com/submissions" @@ -22,15 +24,15 @@ class ExecuteCodeTool(BuiltinTool): headers = { "Content-Type": "application/json", "X-RapidAPI-Key": api_key, - "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com" + "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com", } payload = { - "language_id": tool_parameters['language_id'], - "source_code": tool_parameters['source_code'], - "stdin": tool_parameters.get('stdin', ''), - "expected_output": tool_parameters.get('expected_output', ''), - "additional_files": tool_parameters.get('additional_files', ''), + "language_id": tool_parameters["language_id"], + "source_code": tool_parameters["source_code"], + "stdin": tool_parameters.get("stdin", ""), + "expected_output": tool_parameters.get("expected_output", ""), + "additional_files": tool_parameters.get("additional_files", ""), } response = post(url, data=json.dumps(payload), headers=headers, params=querystring) @@ -38,22 +40,22 @@ class ExecuteCodeTool(BuiltinTool): if response.status_code != 201: raise Exception(response.text) - token = response.json()['token'] + token = response.json()["token"] url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}" - headers = { - "X-RapidAPI-Key": api_key - } - + headers = {"X-RapidAPI-Key": api_key} + response = requests.get(url, headers=headers) if response.status_code == 200: result = response.json() - return self.create_text_message(text=f"stdout: {result.get('stdout', '')}\n" - f"stderr: {result.get('stderr', '')}\n" - f"compile_output: {result.get('compile_output', '')}\n" - f"message: {result.get('message', '')}\n" - f"status: {result['status']['description']}\n" - f"time: {result.get('time', '')} seconds\n" - f"memory: {result.get('memory', '')} bytes") + return self.create_text_message( + text=f"stdout: {result.get('stdout', '')}\n" + f"stderr: {result.get('stderr', '')}\n" + f"compile_output: {result.get('compile_output', '')}\n" + f"message: {result.get('message', '')}\n" + f"status: {result['status']['description']}\n" + f"time: {result.get('time', '')} seconds\n" + f"memory: {result.get('memory', '')} bytes" + ) else: - return self.create_text_message(text=f"Error retrieving submission details: {response.text}") \ No newline at end of file + return self.create_text_message(text=f"Error retrieving submission details: {response.text}") diff --git a/api/core/tools/provider/builtin/maths/maths.py b/api/core/tools/provider/builtin/maths/maths.py index 7226a5c168..d4b449ec87 100644 --- a/api/core/tools/provider/builtin/maths/maths.py +++ b/api/core/tools/provider/builtin/maths/maths.py @@ -9,9 +9,9 @@ class MathsProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: EvaluateExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'expression': '1+(2+3)*4', + "expression": "1+(2+3)*4", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index bf73ed6918..0c5b5e41cb 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -8,22 +8,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class EvaluateExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - expression = tool_parameters.get('expression', '').strip() + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = ne.evaluate(expression) result_str = str(result) except Exception as e: - logging.exception(f'Error evaluating expression: {expression}') - return self.create_text_message(f'Invalid expression: {expression}, error: {str(e)}') - return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') \ No newline at end of file + logging.exception(f"Error evaluating expression: {expression}") + return self.create_text_message(f"Invalid expression: {expression}, error: {str(e)}") + return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') diff --git a/api/core/tools/provider/builtin/nominatim/nominatim.py b/api/core/tools/provider/builtin/nominatim/nominatim.py index b6f29b5feb..5a24bed750 100644 --- a/api/core/tools/provider/builtin/nominatim/nominatim.py +++ b/api/core/tools/provider/builtin/nominatim/nominatim.py @@ -8,16 +8,20 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class NominatimProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NominatimSearchTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'query': 'London', - 'limit': 1, - }, + result = ( + NominatimSearchTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "query": "London", + "limit": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py index e21ce14f54..ffa8ad0fcc 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py @@ -8,40 +8,33 @@ from core.tools.tool.builtin_tool import BuiltinTool class NominatimLookupTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - osm_ids = tool_parameters.get('osm_ids', '') - - if not osm_ids: - return self.create_text_message('Please provide OSM IDs') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + osm_ids = tool_parameters.get("osm_ids", "") - params = { - 'osm_ids': osm_ids, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'lookup', params) + if not osm_ids: + return self.create_text_message("Please provide OSM IDs") + + params = {"osm_ids": osm_ids, "format": "json", "addressdetails": 1} + + return self._make_request(user_id, "lookup", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py index 438d5219e9..f46691e1a3 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py @@ -8,42 +8,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class NominatimReverseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - lat = tool_parameters.get('lat') - lon = tool_parameters.get('lon') - - if lat is None or lon is None: - return self.create_text_message('Please provide both latitude and longitude') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + lat = tool_parameters.get("lat") + lon = tool_parameters.get("lon") - params = { - 'lat': lat, - 'lon': lon, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'reverse', params) + if lat is None or lon is None: + return self.create_text_message("Please provide both latitude and longitude") + + params = {"lat": lat, "lon": lon, "format": "json", "addressdetails": 1} + + return self._make_request(user_id, "reverse", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py index 983cbc0e34..34851d86dc 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py @@ -8,42 +8,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class NominatimSearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters.get('query', '') - limit = tool_parameters.get('limit', 10) - - if not query: - return self.create_text_message('Please input a search query') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query", "") + limit = tool_parameters.get("limit", 10) - params = { - 'q': query, - 'format': 'json', - 'limit': limit, - 'addressdetails': 1 - } - - return self._make_request(user_id, 'search', params) + if not query: + return self.create_text_message("Please input a search query") + + params = {"q": query, "format": "json", "limit": limit, "addressdetails": 1} + + return self._make_request(user_id, "search", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index b753be4791..762e158459 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -12,10 +12,10 @@ class NovitaAiToolBase: if not loras_str: return [] - loras_ori_list = lora_str.strip().split(';') + loras_ori_list = lora_str.strip().split(";") result_list = [] for lora_str in loras_ori_list: - lora_info = lora_str.strip().split(',') + lora_info = lora_str.strip().split(",") lora = Txt2ImgV3LoRA( model_name=lora_info[0].strip(), strength=float(lora_info[1]), @@ -28,43 +28,39 @@ class NovitaAiToolBase: if not embeddings_str: return [] - embeddings_ori_list = embeddings_str.strip().split(';') + embeddings_ori_list = embeddings_str.strip().split(";") result_list = [] for embedding_str in embeddings_ori_list: - embedding = Txt2ImgV3Embedding( - model_name=embedding_str.strip() - ) + embedding = Txt2ImgV3Embedding(model_name=embedding_str.strip()) result_list.append(embedding) return result_list def _extract_hires_fix(self, hires_fix_str: str): - hires_fix_info = hires_fix_str.strip().split(',') - if 'upscaler' in hires_fix_info: + hires_fix_info = hires_fix_str.strip().split(",") + if "upscaler" in hires_fix_info: hires_fix = Txt2ImgV3HiresFix( target_width=int(hires_fix_info[0]), target_height=int(hires_fix_info[1]), strength=float(hires_fix_info[2]), - upscaler=hires_fix_info[3].strip() + upscaler=hires_fix_info[3].strip(), ) else: hires_fix = Txt2ImgV3HiresFix( target_width=int(hires_fix_info[0]), target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]) + strength=float(hires_fix_info[2]), ) return hires_fix def _extract_refiner(self, switch_at: str): - refiner = Txt2ImgV3Refiner( - switch_at=float(switch_at) - ) + refiner = Txt2ImgV3Refiner(switch_at=float(switch_at)) return refiner def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: """ - is hit nsfw + is hit nsfw """ if image.nsfw_detection_result is None: return False diff --git a/api/core/tools/provider/builtin/novitaai/novitaai.py b/api/core/tools/provider/builtin/novitaai/novitaai.py index 1e7d9757c3..d5e32eff29 100644 --- a/api/core/tools/provider/builtin/novitaai/novitaai.py +++ b/api/core/tools/provider/builtin/novitaai/novitaai.py @@ -8,23 +8,27 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class NovitaAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NovitaAiTxt2ImgTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors', - 'prompt': 'a futuristic city with flying cars', - 'negative_prompt': '', - 'width': 128, - 'height': 128, - 'image_num': 1, - 'guidance_scale': 7.5, - 'seed': -1, - 'steps': 1, - }, + result = ( + NovitaAiTxt2ImgTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "model_name": "cinenautXLATRUE_cinenautV10_392434.safetensors", + "prompt": "a futuristic city with flying cars", + "negative_prompt": "", + "width": 128, + "height": 128, + "image_num": 1, + "guidance_scale": 7.5, + "seed": -1, + "steps": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index e63c891957..0b4f2edff3 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -12,17 +12,18 @@ from core.tools.tool.builtin_tool import BuiltinTool class NovitaAiCreateTileTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -30,21 +31,23 @@ class NovitaAiCreateTileTool(BuiltinTool): results = [] results.append( - self.create_blob_message(blob=b64decode(client_result.image_file), - meta={'mime_type': f'image/{client_result.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(client_result.image_file), + meta={"mime_type": f"image/{client_result.image_type}"}, + save_as=self.VariableKey.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index ec2927675e..a200ee8123 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -12,127 +12,137 @@ from core.tools.tool.builtin_tool import BuiltinTool class NovitaAiModelQueryTool(BuiltinTool): - _model_query_endpoint = 'https://api.novita.ai/v3/model' + _model_query_endpoint = "https://api.novita.ai/v3/model" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') - headers = { - 'Content-Type': 'application/json', - 'Authorization': "Bearer " + api_key - } + api_key = self.runtime.credentials.get("api_key") + headers = {"Content-Type": "application/json", "Authorization": "Bearer " + api_key} params = self._process_parameters(tool_parameters) - result_type = params.get('result_type') - del params['result_type'] + result_type = params.get("result_type") + del params["result_type"] models_data = self._query_models( models_data=[], headers=headers, params=params, - recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True + recursive=result_type not in {"first sd_name", "first name sd_name pair"}, ) - result_str = '' - if result_type == 'first sd_name': - result_str = models_data[0]['sd_name_in_api'] if len(models_data) > 0 else '' - elif result_type == 'first name sd_name pair': - result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) if len(models_data) > 0 else '' - elif result_type == 'sd_name array': - sd_name_array = [model['sd_name_in_api'] for model in models_data] if len(models_data) > 0 else [] + result_str = "" + if result_type == "first sd_name": + result_str = models_data[0]["sd_name_in_api"] if len(models_data) > 0 else "" + elif result_type == "first name sd_name pair": + result_str = ( + json.dumps({"name": models_data[0]["name"], "sd_name": models_data[0]["sd_name_in_api"]}) + if len(models_data) > 0 + else "" + ) + elif result_type == "sd_name array": + sd_name_array = [model["sd_name_in_api"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(sd_name_array) - elif result_type == 'name array': - name_array = [model['name'] for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name array": + name_array = [model["name"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(name_array) - elif result_type == 'name sd_name pair array': - name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} - for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name sd_name pair array": + name_sd_name_pair_array = ( + [{"name": model["name"], "sd_name": model["sd_name_in_api"]} for model in models_data] + if len(models_data) > 0 + else [] + ) result_str = json.dumps(name_sd_name_pair_array) - elif result_type == 'whole info array': + elif result_type == "whole info array": result_str = json.dumps(models_data) else: raise NotImplementedError return self.create_text_message(result_str) - def _query_models(self, models_data: list, headers: dict[str, Any], - params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list: + def _query_models( + self, + models_data: list, + headers: dict[str, Any], + params: dict[str, Any], + pagination_cursor: str = "", + recursive: bool = True, + ) -> list: """ - query models + query models """ inside_params = deepcopy(params) - if pagination_cursor != '': - inside_params['pagination.cursor'] = pagination_cursor + if pagination_cursor != "": + inside_params["pagination.cursor"] = pagination_cursor response = ssrf_proxy.get( - url=str(URL(self._model_query_endpoint)), - headers=headers, - params=params, - timeout=(10, 60) + url=str(URL(self._model_query_endpoint)), headers=headers, params=params, timeout=(10, 60) ) res_data = response.json() - models_data.extend(res_data['models']) + models_data.extend(res_data["models"]) - res_data_len = len(res_data['models']) - if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False: + res_data_len = len(res_data["models"]) + if res_data_len == 0 or res_data_len < int(params["pagination.limit"]) or recursive is False: # deduplicate df = DataFrame.from_dict(models_data) - df_unique = df.drop_duplicates(subset=['id']) - models_data = df_unique.to_dict('records') + df_unique = df.drop_duplicates(subset=["id"]) + models_data = df_unique.to_dict("records") return models_data return self._query_models( models_data=models_data, headers=headers, params=inside_params, - pagination_cursor=res_data['pagination']['next_cursor'] + pagination_cursor=res_data["pagination"]["next_cursor"], ) def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ process_parameters = deepcopy(parameters) res_parameters = {} # delete none or empty - keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ""] for k in keys_to_delete: del process_parameters[k] - if 'query' in process_parameters and process_parameters.get('query') != 'unspecified': - res_parameters['filter.query'] = process_parameters['query'] + if "query" in process_parameters and process_parameters.get("query") != "unspecified": + res_parameters["filter.query"] = process_parameters["query"] - if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified': - res_parameters['filter.visibility'] = process_parameters['visibility'] + if "visibility" in process_parameters and process_parameters.get("visibility") != "unspecified": + res_parameters["filter.visibility"] = process_parameters["visibility"] - if 'source' in process_parameters and process_parameters.get('source') != 'unspecified': - res_parameters['filter.source'] = process_parameters['source'] + if "source" in process_parameters and process_parameters.get("source") != "unspecified": + res_parameters["filter.source"] = process_parameters["source"] - if 'type' in process_parameters and process_parameters.get('type') != 'unspecified': - res_parameters['filter.types'] = process_parameters['type'] + if "type" in process_parameters and process_parameters.get("type") != "unspecified": + res_parameters["filter.types"] = process_parameters["type"] - if 'is_sdxl' in process_parameters: - if process_parameters['is_sdxl'] == 'true': - res_parameters['filter.is_sdxl'] = True - elif process_parameters['is_sdxl'] == 'false': - res_parameters['filter.is_sdxl'] = False + if "is_sdxl" in process_parameters: + if process_parameters["is_sdxl"] == "true": + res_parameters["filter.is_sdxl"] = True + elif process_parameters["is_sdxl"] == "false": + res_parameters["filter.is_sdxl"] = False - res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name') + res_parameters["result_type"] = process_parameters.get("result_type", "first sd_name") - res_parameters['pagination.limit'] = 1 \ - if res_parameters.get('result_type') == 'first sd_name' \ - or res_parameters.get('result_type') == 'first name sd_name pair'\ + res_parameters["pagination.limit"] = ( + 1 + if res_parameters.get("result_type") == "first sd_name" + or res_parameters.get("result_type") == "first name sd_name pair" else 100 + ) return res_parameters diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 5fef3d2da7..9c61eab9f9 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -13,17 +13,18 @@ from core.tools.tool.builtin_tool import BuiltinTool class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -32,56 +33,58 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): results = [] for image_encoded, image in zip(client_result.images_encoded, client_result.images): if self._is_hit_nsfw_detection(image, 0.8): - results = self.create_text_message(text='NSFW detected!') + results = self.create_text_message(text="NSFW detected!") break results.append( - self.create_blob_message(blob=b64decode(image_encoded), - meta={'mime_type': f'image/{image.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(image_encoded), + meta={"mime_type": f"image/{image.image_type}"}, + save_as=self.VariableKey.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] - if 'clip_skip' in res_parameters and res_parameters.get('clip_skip') == 0: - del res_parameters['clip_skip'] + if "clip_skip" in res_parameters and res_parameters.get("clip_skip") == 0: + del res_parameters["clip_skip"] - if 'refiner_switch_at' in res_parameters and res_parameters.get('refiner_switch_at') == 0: - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters and res_parameters.get("refiner_switch_at") == 0: + del res_parameters["refiner_switch_at"] - if 'enabled_enterprise_plan' in res_parameters: - res_parameters['enterprise_plan'] = {'enabled': res_parameters['enabled_enterprise_plan']} - del res_parameters['enabled_enterprise_plan'] + if "enabled_enterprise_plan" in res_parameters: + res_parameters["enterprise_plan"] = {"enabled": res_parameters["enabled_enterprise_plan"]} + del res_parameters["enabled_enterprise_plan"] - if 'nsfw_detection_level' in res_parameters: - res_parameters['nsfw_detection_level'] = int(res_parameters['nsfw_detection_level']) + if "nsfw_detection_level" in res_parameters: + res_parameters["nsfw_detection_level"] = int(res_parameters["nsfw_detection_level"]) # process loras - if 'loras' in res_parameters: - res_parameters['loras'] = self._extract_loras(res_parameters.get('loras')) + if "loras" in res_parameters: + res_parameters["loras"] = self._extract_loras(res_parameters.get("loras")) # process embeddings - if 'embeddings' in res_parameters: - res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings')) + if "embeddings" in res_parameters: + res_parameters["embeddings"] = self._extract_embeddings(res_parameters.get("embeddings")) # process hires_fix - if 'hires_fix' in res_parameters: - res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix')) + if "hires_fix" in res_parameters: + res_parameters["hires_fix"] = self._extract_hires_fix(res_parameters.get("hires_fix")) # process refiner - if 'refiner_switch_at' in res_parameters: - res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at')) - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters: + res_parameters["refiner"] = self._extract_refiner(res_parameters.get("refiner_switch_at")) + del res_parameters["refiner_switch_at"] return res_parameters diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py index 42f321e919..b8e5ed24d6 100644 --- a/api/core/tools/provider/builtin/onebot/onebot.py +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -5,8 +5,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class OneBotProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - if not credentials.get("ob11_http_url"): - raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.') + raise ToolProviderCredentialValidationError("OneBot HTTP URL is required.") diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py index 2a1a9f86de..9c95bbc2ae 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -11,54 +11,29 @@ class SendGroupMsg(BuiltinTool): """OneBot v11 Tool: Send Group Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_group_id = tool_parameters.get('group_id', '') - - message = tool_parameters.get('message', '') + send_group_id = tool_parameters.get("group_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' - } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_group_msg" resp = requests.post( url, - json={ - 'group_id': send_group_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"group_id": send_group_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send group message: {resp.text}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {resp.text}"}) - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send group message: {e}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {e}"}) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py index 8ef4d72ab6..1174c7f07d 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -11,54 +11,29 @@ class SendPrivateMsg(BuiltinTool): """OneBot v11 Tool: Send Private Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_user_id = tool_parameters.get('user_id', '') - - message = tool_parameters.get('message', '') + send_user_id = tool_parameters.get("user_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' - } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_private_msg" resp = requests.post( url, - json={ - 'user_id': send_user_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"user_id": send_user_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send private message: {resp.text}' - } - ) - - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"error": f"Failed to send private message: {resp.text}"}) + + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send private message: {e}' - } - ) \ No newline at end of file + return self.create_json_message({"error": f"Failed to send private message: {e}"}) diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py index a2827177a3..9e40249aba 100644 --- a/api/core/tools/provider/builtin/openweather/openweather.py +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -5,7 +5,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None): - url = "https://api.openweathermap.org/data/2.5/weather" params = {"q": city, "appid": api_key, "units": units, "lang": language} @@ -16,21 +15,15 @@ class OpenweatherProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: if "api_key" not in credentials or not credentials.get("api_key"): - raise ToolProviderCredentialValidationError( - "Open weather API key is required." - ) + raise ToolProviderCredentialValidationError("Open weather API key is required.") apikey = credentials.get("api_key") try: response = query_weather(api_key=apikey) if response.status_code == 200: pass else: - raise ToolProviderCredentialValidationError( - (response.json()).get("info") - ) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: - raise ToolProviderCredentialValidationError( - "Open weather API Key is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("Open weather API Key is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py index d6c49a230f..ed4ec487fa 100644 --- a/api/core/tools/provider/builtin/openweather/tools/weather.py +++ b/api/core/tools/provider/builtin/openweather/tools/weather.py @@ -17,10 +17,7 @@ class OpenweatherTool(BuiltinTool): city = tool_parameters.get("city", "") if not city: return self.create_text_message("Please tell me your city") - if ( - "api_key" not in self.runtime.credentials - or not self.runtime.credentials.get("api_key") - ): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("OpenWeather API key is required.") units = tool_parameters.get("units", "metric") @@ -39,12 +36,9 @@ class OpenweatherTool(BuiltinTool): response = requests.get(url, params=params) if response.status_code == 200: - data = response.json() return self.create_text_message( - self.summary( - user_id=user_id, content=json.dumps(data, ensure_ascii=False) - ) + self.summary(user_id=user_id, content=json.dumps(data, ensure_ascii=False)) ) else: error_message = { @@ -55,6 +49,4 @@ class OpenweatherTool(BuiltinTool): return json.dumps(error_message) except Exception as e: - return self.create_text_message( - "Openweather API Key is invalid. {}".format(e) - ) + return self.create_text_message("Openweather API Key is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.py b/api/core/tools/provider/builtin/perplexity/perplexity.py index ff91edf18d..80518853fb 100644 --- a/api/core/tools/provider/builtin/perplexity/perplexity.py +++ b/api/core/tools/provider/builtin/perplexity/perplexity.py @@ -11,34 +11,26 @@ class PerplexityProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: headers = { "Authorization": f"Bearer {credentials.get('perplexity_api_key')}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + payload = { "model": "llama-3.1-sonar-small-128k-online", "messages": [ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello" - } + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, ], "max_tokens": 5, "temperature": 0.1, "top_p": 0.9, - "stream": False + "stream": False, } try: response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers) response.raise_for_status() except requests.RequestException as e: - raise ToolProviderCredentialValidationError( - f"Failed to validate Perplexity API key: {str(e)}" - ) + raise ToolProviderCredentialValidationError(f"Failed to validate Perplexity API key: {str(e)}") if response.status_code != 200: raise ToolProviderCredentialValidationError( diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py index 5b1a263f9b..5ed4b9ca99 100644 --- a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py @@ -8,65 +8,60 @@ from core.tools.tool.builtin_tool import BuiltinTool PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions" + class PerplexityAITool(BuiltinTool): def _parse_response(self, response: dict) -> dict: """Parse the response from Perplexity AI API""" - if 'choices' in response and len(response['choices']) > 0: - message = response['choices'][0]['message'] + if "choices" in response and len(response["choices"]) > 0: + message = response["choices"][0]["message"] return { - 'content': message.get('content', ''), - 'role': message.get('role', ''), - 'citations': response.get('citations', []) + "content": message.get("content", ""), + "role": message.get("role", ""), + "citations": response.get("citations", []), } else: - return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []} + return {"content": "Unable to get a valid response", "role": "assistant", "citations": []} - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: headers = { "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + payload = { - "model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'), + "model": tool_parameters.get("model", "llama-3.1-sonar-small-128k-online"), "messages": [ - { - "role": "system", - "content": "Be precise and concise." - }, - { - "role": "user", - "content": tool_parameters['query'] - } + {"role": "system", "content": "Be precise and concise."}, + {"role": "user", "content": tool_parameters["query"]}, ], - "max_tokens": tool_parameters.get('max_tokens', 4096), - "temperature": tool_parameters.get('temperature', 0.7), - "top_p": tool_parameters.get('top_p', 1), - "top_k": tool_parameters.get('top_k', 5), - "presence_penalty": tool_parameters.get('presence_penalty', 0), - "frequency_penalty": tool_parameters.get('frequency_penalty', 1), - "stream": False + "max_tokens": tool_parameters.get("max_tokens", 4096), + "temperature": tool_parameters.get("temperature", 0.7), + "top_p": tool_parameters.get("top_p", 1), + "top_k": tool_parameters.get("top_k", 5), + "presence_penalty": tool_parameters.get("presence_penalty", 0), + "frequency_penalty": tool_parameters.get("frequency_penalty", 1), + "stream": False, } - - if 'search_recency_filter' in tool_parameters: - payload['search_recency_filter'] = tool_parameters['search_recency_filter'] - if 'return_citations' in tool_parameters: - payload['return_citations'] = tool_parameters['return_citations'] - if 'search_domain_filter' in tool_parameters: - if isinstance(tool_parameters['search_domain_filter'], str): - payload['search_domain_filter'] = [tool_parameters['search_domain_filter']] - elif isinstance(tool_parameters['search_domain_filter'], list): - payload['search_domain_filter'] = tool_parameters['search_domain_filter'] - + + if "search_recency_filter" in tool_parameters: + payload["search_recency_filter"] = tool_parameters["search_recency_filter"] + if "return_citations" in tool_parameters: + payload["return_citations"] = tool_parameters["return_citations"] + if "search_domain_filter" in tool_parameters: + if isinstance(tool_parameters["search_domain_filter"], str): + payload["search_domain_filter"] = [tool_parameters["search_domain_filter"]] + elif isinstance(tool_parameters["search_domain_filter"], list): + payload["search_domain_filter"] = tool_parameters["search_domain_filter"] response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) - + return [ self.create_json_message(valuable_res), - self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)) + self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)), ] diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 05cd171b87..ea3a477c30 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -11,11 +11,10 @@ class PubMedProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py index 58811d65e6..3a4f374ea0 100644 --- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py +++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py @@ -51,17 +51,12 @@ class PubMedAPIWrapper(BaseModel): try: # Retrieve the top-k results for the query docs = [ - f"Published: {result['pub_date']}\nTitle: {result['title']}\n" - f"Summary: {result['summary']}" + f"Published: {result['pub_date']}\nTitle: {result['title']}\nSummary: {result['summary']}" for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH]) ] # Join the results and limit the character count - return ( - "\n\n".join(docs)[:self.doc_content_chars_max] - if docs - else "No good PubMed Result was found" - ) + return "\n\n".join(docs)[: self.doc_content_chars_max] if docs else "No good PubMed Result was found" except Exception as ex: return f"PubMed exception: {ex}" @@ -91,13 +86,7 @@ class PubMedAPIWrapper(BaseModel): return articles def retrieve_article(self, uid: str, webenv: str) -> dict: - url = ( - self.base_url_efetch - + "db=pubmed&retmode=xml&id=" - + uid - + "&webenv=" - + webenv - ) + url = self.base_url_efetch + "db=pubmed&retmode=xml&id=" + uid + "&webenv=" + webenv retry = 0 while True: @@ -108,10 +97,7 @@ class PubMedAPIWrapper(BaseModel): if e.code == 429 and retry < self.max_retry: # Too Many Requests error # wait for an exponentially increasing amount of time - print( - f"Too Many Requests, " - f"waiting for {self.sleep_time:.2f} seconds..." - ) + print(f"Too Many Requests, waiting for {self.sleep_time:.2f} seconds...") time.sleep(self.sleep_time) self.sleep_time *= 2 retry += 1 @@ -125,27 +111,21 @@ class PubMedAPIWrapper(BaseModel): if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - title = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + title = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get abstract abstract = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - abstract = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + abstract = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get publication date pub_date = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - pub_date = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + pub_date = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Return article as dictionary article = { @@ -182,6 +162,7 @@ class PubmedQueryRun(BaseModel): class PubMedInput(BaseModel): query: str = Field(..., description="Search query.") + class PubMedSearchTool(BuiltinTool): """ Tool for performing a search using PubMed search engine. @@ -198,14 +179,13 @@ class PubMedSearchTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") tool = PubmedQueryRun(args_schema=PubMedInput) result = tool._run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.py b/api/core/tools/provider/builtin/qrcode/qrcode.py index 9fa7d01265..8466b9a26b 100644 --- a/api/core/tools/provider/builtin/qrcode/qrcode.py +++ b/api/core/tools/provider/builtin/qrcode/qrcode.py @@ -8,9 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class QRCodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - QRCodeGeneratorTool().invoke(user_id='', - tool_parameters={ - 'content': 'Dify 123 😊' - }) + QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index 5eede98f5e..d8ca20bde6 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -13,43 +13,44 @@ from core.tools.tool.builtin_tool import BuiltinTool class QRCodeGeneratorTool(BuiltinTool): error_correction_levels: dict[str, int] = { - 'L': ERROR_CORRECT_L, # <=7% - 'M': ERROR_CORRECT_M, # <=15% - 'Q': ERROR_CORRECT_Q, # <=25% - 'H': ERROR_CORRECT_H, # <=30% + "L": ERROR_CORRECT_L, # <=7% + "M": ERROR_CORRECT_M, # <=15% + "Q": ERROR_CORRECT_Q, # <=25% + "H": ERROR_CORRECT_H, # <=30% } - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get text content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get border size - border = tool_parameters.get('border', 0) + border = tool_parameters.get("border", 0) if border < 0 or border > 100: - return self.create_text_message('Invalid parameter border') + return self.create_text_message("Invalid parameter border") # get error_correction - error_correction = tool_parameters.get('error_correction', '') - if error_correction not in self.error_correction_levels.keys(): - return self.create_text_message('Invalid parameter error_correction') + error_correction = tool_parameters.get("error_correction", "") + if error_correction not in self.error_correction_levels: + return self.create_text_message("Invalid parameter error_correction") try: image = self._generate_qrcode(content, border, error_correction) image_bytes = self._image_to_byte_array(image) - return self.create_blob_message(blob=image_bytes, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + return self.create_blob_message( + blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) except Exception: - logging.exception(f'Failed to generate QR code for content: {content}') - return self.create_text_message('Failed to generate QR code') + logging.exception(f"Failed to generate QR code for content: {content}") + return self.create_text_message("Failed to generate QR code") def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage: qr = QRCode( diff --git a/api/core/tools/provider/builtin/regex/regex.py b/api/core/tools/provider/builtin/regex/regex.py index d38ae1b292..c498105979 100644 --- a/api/core/tools/provider/builtin/regex/regex.py +++ b/api/core/tools/provider/builtin/regex/regex.py @@ -9,10 +9,10 @@ class RegexProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: RegexExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'content': '1+(2+3)*4', - 'expression': r'(\d+)', + "content": "1+(2+3)*4", + "expression": r"(\d+)", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/regex/tools/regex_extract.py b/api/core/tools/provider/builtin/regex/tools/regex_extract.py index 5d8f013d0d..786b469404 100644 --- a/api/core/tools/provider/builtin/regex/tools/regex_extract.py +++ b/api/core/tools/provider/builtin/regex/tools/regex_extract.py @@ -6,22 +6,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class RegexExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - content = tool_parameters.get('content', '').strip() + content = tool_parameters.get("content", "").strip() if not content: - return self.create_text_message('Invalid content') - expression = tool_parameters.get('expression', '').strip() + return self.create_text_message("Invalid content") + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = re.findall(expression, content) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to extract result, error: {str(e)}') \ No newline at end of file + return self.create_text_message(f"Failed to extract result, error: {str(e)}") diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.py b/api/core/tools/provider/builtin/searchapi/searchapi.py index 6fa4f05acd..109bba8b2d 100644 --- a/api/core/tools/provider/builtin/searchapi/searchapi.py +++ b/api/core/tools/provider/builtin/searchapi/searchapi.py @@ -13,11 +13,8 @@ class SearchAPIProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearchApi dify", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "SearchApi dify", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index dd780aeadc..17e2978194 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -37,42 +38,45 @@ class SearchAPI: return { "engine": "google", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + if "answer_box" in res and "answer" in res["answer_box"]: toret += res["answer_box"]["answer"] + "\n" - if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + if "answer_box" in res and "snippet" in res["answer_box"]: toret += res["answer_box"]["snippet"] + "\n" - if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): + if "knowledge_graph" in res and "description" in res["knowledge_graph"]: toret += res["knowledge_graph"]["description"] + "\n" - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys(): - if "title" in res["answer_box"]["organic_result"].keys(): - toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n" - elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys(): + if "answer_box" in res and "organic_result" in res["answer_box"]: + if "title" in res["answer_box"]["organic_result"]: + toret = ( + f"[{res['answer_box']['organic_result']['title']}]" + f"({res['answer_box']['organic_result']['link']})\n" + ) + elif "organic_results" in res and "link" in res["organic_results"][0]: toret = "" for item in res["organic_results"]: toret += f"[{item['title']}]({item['link']})\n" - elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys(): + elif "related_questions" in res and "link" in res["related_questions"][0]: toret = "" for item in res["related_questions"]: toret += f"[{item['title']}]({item['link']})\n" - elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys(): + elif "related_searches" in res and "link" in res["related_searches"][0]: toret = "" for item in res["related_searches"]: toret += f"[{item['title']}]({item['link']})\n" @@ -80,25 +84,29 @@ class SearchAPI: toret = "No good search result found" return toret + class GoogleTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 81c67c51a9..c478bc108b 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -37,41 +38,52 @@ class SearchAPI: return { "engine": "google_jobs", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): + if "jobs" in res and "title" in res["jobs"][0]: for item in res["jobs"]: - toret += "title: " + item["title"] + "\n" + "company_name: " + item["company_name"] + "content: " + item["description"] + "\n" + toret += ( + "title: " + + item["title"] + + "\n" + + "company_name: " + + item["company_name"] + + "content: " + + item["description"] + + "\n" + ) if toret == "": toret = "No good search result found" elif type == "link": - if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys(): + if "jobs" in res and "apply_link" in res["jobs"][0]: for item in res["jobs"]: toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n" else: toret = "No good search result found" return toret + class GoogleJobsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] is_remote = tool_parameters.get("is_remote") google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") @@ -80,9 +92,11 @@ class GoogleJobsTool(BuiltinTool): ltype = 1 if is_remote else None - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 5d2657dddd..562bc01964 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -37,56 +38,60 @@ class SearchAPI: return { "engine": "google_news", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" - if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + if "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys(): + if "organic_results" in res and "title" in res["organic_results"][0]: for item in res["organic_results"]: toret += f"[{item['title']}]({item['link']})\n" - elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + elif "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += f"[{item['title']}]({item['link']})\n" else: toret = "No good search result found" return toret + class GoogleNewsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 6345b33801..1867cf7be7 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -36,18 +37,18 @@ class SearchAPI: return { "engine": "youtube_transcripts", "video_id": video_id, - "lang": language if language else "en", - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + "lang": language or "en", + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" - if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys(): + if "transcripts" in res and "text" in res["transcripts"][0]: for item in res["transcripts"]: toret += item["text"] + " " if toret == "": @@ -55,18 +56,20 @@ class SearchAPI: return toret + class YoutubeTranscriptsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - video_id = tool_parameters['video_id'] - language = tool_parameters.get('language', "en") + video_id = tool_parameters["video_id"] + language = tool_parameters.get("language", "en") - api_key = self.runtime.credentials['searchapi_api_key'] + api_key = self.runtime.credentials["searchapi_api_key"] result = SearchAPI(api_key).run(video_id, language=language) return self.create_text_message(text=result) diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index ab354003e6..b7bbcc60b1 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -13,12 +13,8 @@ class SearXNGProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearXNG", - "limit": 1, - "search_type": "general" - }, + user_id="", + tool_parameters={"query": "SearXNG", "limit": 1, "search_type": "general"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index dc835a8e8c..c5e339a108 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -23,18 +23,21 @@ class SearXNGSearchTool(BuiltinTool): ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - host = self.runtime.credentials.get('searxng_base_url') + host = self.runtime.credentials.get("searxng_base_url") if not host: - raise Exception('SearXNG api is required') + raise Exception("SearXNG api is required") - response = requests.get(host, params={ - "q": tool_parameters.get('query'), - "format": "json", - "categories": tool_parameters.get('search_type', 'general') - }) + response = requests.get( + host, + params={ + "q": tool_parameters.get("query"), + "format": "json", + "categories": tool_parameters.get("search_type", "general"), + }, + ) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') + raise Exception(f"Error {response.status_code}: {response.text}") res = response.json().get("results", []) if not res: diff --git a/api/core/tools/provider/builtin/serper/serper.py b/api/core/tools/provider/builtin/serper/serper.py index 2a42109373..cb1d090a9d 100644 --- a/api/core/tools/provider/builtin/serper/serper.py +++ b/api/core/tools/provider/builtin/serper/serper.py @@ -13,11 +13,8 @@ class SerperProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/serper/tools/serper_search.py b/api/core/tools/provider/builtin/serper/tools/serper_search.py index 24facaf4ec..7baebbf958 100644 --- a/api/core/tools/provider/builtin/serper/tools/serper_search.py +++ b/api/core/tools/provider/builtin/serper/tools/serper_search.py @@ -9,7 +9,6 @@ SERPER_API_URL = "https://google.serper.dev/search" class SerperSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledgeGraph" in response: @@ -17,28 +16,19 @@ class SerperSearchTool(BuiltinTool): result["description"] = response["knowledgeGraph"].get("description", "") if "organic" in response: result["organic"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - params = { - "q": tool_parameters['query'], - "gl": "us", - "hl": "en" - } - headers = { - 'X-API-KEY': self.runtime.credentials['serperapi_api_key'], - 'Content-Type': 'application/json' - } - response = requests.get(url=SERPER_API_URL, params=params,headers=headers) + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + params = {"q": tool_parameters["query"], "gl": "us", "hl": "en"} + headers = {"X-API-KEY": self.runtime.credentials["serperapi_api_key"], "Content-Type": "application/json"} + response = requests.get(url=SERPER_API_URL, params=params, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) return self.create_json_message(valuable_res) diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py index 0df78280df..37a0b0755b 100644 --- a/api/core/tools/provider/builtin/siliconflow/siliconflow.py +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py @@ -14,6 +14,4 @@ class SiliconflowProvider(BuiltinToolProviderController): response = requests.get(url, headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError( - "SiliconFlow API key is invalid" - ) + raise ToolProviderCredentialValidationError("SiliconFlow API key is invalid") diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py index ed9f4be574..1b846624bd 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -5,17 +5,13 @@ import requests from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -FLUX_URL = ( - "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" -) +FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" class FluxTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -36,9 +32,5 @@ class FluxTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py index e8134a6565..d6a0b03d1b 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -12,11 +12,9 @@ SDURL = { class StableDiffusionTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -43,9 +41,5 @@ class StableDiffusionTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py index f47557f2ef..85e0de7675 100644 --- a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py +++ b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py @@ -7,25 +7,27 @@ from core.tools.tool.builtin_tool import BuiltinTool class SlackWebhookTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Incoming Webhooks - API Document: https://api.slack.com/messaging/webhooks + Incoming Webhooks + API Document: https://api.slack.com/messaging/webhooks """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - webhook_url = tool_parameters.get('webhook_url', '') + webhook_url = tool_parameters.get("webhook_url", "") - if not webhook_url.startswith('https://hooks.slack.com/'): + if not webhook_url.startswith("https://hooks.slack.com/"): return self.create_text_message( - f'Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL') + f"Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL" + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { @@ -38,6 +40,7 @@ class SlackWebhookTool(BuiltinTool): return self.create_text_message("Text message was sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message through webhook. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message through webhook. {}".format(e)) diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py index cb8e69a59f..e0b1a58a3f 100644 --- a/api/core/tools/provider/builtin/spark/spark.py +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -29,12 +29,8 @@ class SparkProvider(BuiltinToolProviderController): # 0 success, pass else: - raise ToolProviderCredentialValidationError( - "image generate error, code:{}".format(code) - ) + raise ToolProviderCredentialValidationError("image generate error, code:{}".format(code)) except Exception as e: - raise ToolProviderCredentialValidationError( - "APPID APISecret APIKey is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("APPID APISecret APIKey is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index c7b0de014f..81d9e8d941 100644 --- a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -class AssembleHeaderException(Exception): +class AssembleHeaderError(Exception): def __init__(self, msg): self.message = msg class Url: - def __init__(this, host, path, schema): - this.host = host - this.path = path - this.schema = schema + def __init__(self, host, path, schema): + self.host = host + self.path = path + self.schema = schema # calculate sha256 and encode to base64 @@ -41,32 +41,31 @@ def parse_url(request_url): schema = request_url[: stidx + 3] edidx = host.index("/") if edidx <= 0: - raise AssembleHeaderException("invalid request url:" + request_url) + raise AssembleHeaderError("invalid request url:" + request_url) path = host[edidx:] host = host[:edidx] u = Url(host, path, schema) return u + def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): u = parse_url(request_url) host = u.host path = u.path now = datetime.now() date = format_date_time(mktime(now.timetuple())) - signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( - host, date, method, path - ) + signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path) signature_sha = hmac.new( api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256, ).digest() signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") - authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' - - authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( - encoding="utf-8" + authorization_origin = ( + f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' ) + + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") values = {"host": host, "date": date, "authorization": authorization} return request_url + "?" + urlencode(values) @@ -75,9 +74,7 @@ def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): def get_body(appid, text): body = { "header": {"app_id": appid, "uid": "123456789"}, - "parameter": { - "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096} - }, + "parameter": {"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}}, "payload": {"message": {"text": [{"role": "user", "content": text}]}}, } return body @@ -85,13 +82,9 @@ def get_body(appid, text): def spark_response(text, appid, apikey, apisecret): host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti" - url = assemble_ws_auth_url( - host, method="POST", api_key=apikey, api_secret=apisecret - ) + url = assemble_ws_auth_url(host, method="POST", api_key=apikey, api_secret=apisecret) content = get_body(appid, text) - response = requests.post( - url, json=content, headers={"content-type": "application/json"} - ).text + response = requests.post(url, json=content, headers={"content-type": "application/json"}).text return response @@ -105,19 +98,11 @@ class SparkImgGeneratorTool(BuiltinTool): invoke tools """ - if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get( - "APPID" - ): + if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get("APPID"): return self.create_text_message("APPID is required.") - if ( - "APISecret" not in self.runtime.credentials - or not self.runtime.credentials.get("APISecret") - ): + if "APISecret" not in self.runtime.credentials or not self.runtime.credentials.get("APISecret"): return self.create_text_message("APISecret is required.") - if ( - "APIKey" not in self.runtime.credentials - or not self.runtime.credentials.get("APIKey") - ): + if "APIKey" not in self.runtime.credentials or not self.runtime.credentials.get("APIKey"): return self.create_text_message("APIKey is required.") prompt = tool_parameters.get("prompt", "") @@ -130,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool): self.create_blob_message( blob=b64decode(image["base64_image"]), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) return result diff --git a/api/core/tools/provider/builtin/spider/spider.py b/api/core/tools/provider/builtin/spider/spider.py index 5bcc56a724..5959555318 100644 --- a/api/core/tools/provider/builtin/spider/spider.py +++ b/api/core/tools/provider/builtin/spider/spider.py @@ -8,13 +8,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class SpiderProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - app = Spider(api_key=credentials['spider_api_key']) - app.scrape_url(url='https://spider.cloud') + app = Spider(api_key=credentials["spider_api_key"]) + app.scrape_url(url="https://spider.cloud") except AttributeError as e: # Handle cases where NoneType is not iterable, which might indicate API issues - if 'NoneType' in str(e) and 'not iterable' in str(e): - raise ToolProviderCredentialValidationError('API is currently down, try again in 15 minutes', str(e)) + if "NoneType" in str(e) and "not iterable" in str(e): + raise ToolProviderCredentialValidationError("API is currently down, try again in 15 minutes", str(e)) else: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) except Exception as e: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index f0ed64867a..4bc446a1a0 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -65,9 +65,7 @@ class Spider: :return: The JSON response or the raw response stream if stream is True. """ headers = self._prepare_headers(content_type) - response = self._post_request( - f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream - ) + response = self._post_request(f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream) if stream: return response @@ -76,9 +74,7 @@ class Spider: else: self._handle_error(response, f"post to {endpoint}") - def api_get( - self, endpoint: str, stream: bool, content_type: str = "application/json" - ): + def api_get(self, endpoint: str, stream: bool, content_type: str = "application/json"): """ Send a GET request to the specified endpoint. @@ -86,9 +82,7 @@ class Spider: :return: The JSON decoded response. """ headers = self._prepare_headers(content_type) - response = self._get_request( - f"https://api.spider.cloud/v1/{endpoint}", headers, stream - ) + response = self._get_request(f"https://api.spider.cloud/v1/{endpoint}", headers, stream) if response.status_code == 200: return response.json() else: @@ -120,14 +114,12 @@ class Spider: # Add { "return_format": "markdown" } to the params if not already present if "return_format" not in params: - params["return_format"] = "markdown" + params["return_format"] = "markdown" # Set limit to 1 params["limit"] = 1 - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def crawl_url( self, @@ -150,9 +142,7 @@ class Spider: if "return_format" not in params: params["return_format"] = "markdown" - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def links( self, @@ -168,9 +158,7 @@ class Spider: :param params: Optional parameters for the link retrieval request. :return: JSON response containing the links. """ - return self.api_post( - "links", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("links", {"url": url, **(params or {})}, stream, content_type) def extract_contacts( self, @@ -207,9 +195,7 @@ class Spider: :param params: Optional parameters to guide the labeling process. :return: JSON response with labeled data. """ - return self.api_post( - "pipeline/label", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("pipeline/label", {"url": url, **(params or {})}, stream, content_type) def _prepare_headers(self, content_type: str = "application/json"): return { @@ -228,12 +214,8 @@ class Spider: return requests.delete(url, headers=headers, stream=stream) def _handle_error(self, response, action): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") - raise Exception( - f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}" - ) + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception( - f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}" - ) + raise Exception(f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}") diff --git a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py index 40736cd402..20d2daef55 100644 --- a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py +++ b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py @@ -6,41 +6,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # initialize the app object with the api key - app = Spider(api_key=self.runtime.credentials['spider_api_key']) + app = Spider(api_key=self.runtime.credentials["spider_api_key"]) + + url = tool_parameters["url"] + mode = tool_parameters["mode"] - url = tool_parameters['url'] - mode = tool_parameters['mode'] - options = { - 'limit': tool_parameters.get('limit', 0), - 'depth': tool_parameters.get('depth', 0), - 'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [], - 'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [], - 'readability': tool_parameters.get('readability', False), + "limit": tool_parameters.get("limit", 0), + "depth": tool_parameters.get("depth", 0), + "blacklist": tool_parameters.get("blacklist", "").split(",") if tool_parameters.get("blacklist") else [], + "whitelist": tool_parameters.get("whitelist", "").split(",") if tool_parameters.get("whitelist") else [], + "readability": tool_parameters.get("readability", False), } result = "" try: - if mode == 'scrape': + if mode == "scrape": scrape_result = app.scrape_url( - url=url, + url=url, params=options, ) for i in scrape_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" - elif mode == 'crawl': + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" + elif mode == "crawl": crawl_result = app.crawl_url( - url=tool_parameters['url'], + url=tool_parameters["url"], params=options, ) for i in crawl_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" except Exception as e: return self.create_text_message("An error occurred", str(e)) diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py index b31d786178..f09d81ac27 100644 --- a/api/core/tools/provider/builtin/stability/stability.py +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -8,6 +8,7 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz """ This class is responsible for providing the stability tool. """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ This method is responsible for validating the credentials. diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py index a4788fd869..c3b7edbefa 100644 --- a/api/core/tools/provider/builtin/stability/tools/base.py +++ b/api/core/tools/provider/builtin/stability/tools/base.py @@ -9,26 +9,23 @@ class BaseStabilityAuthorization: """ This method is responsible for validating the credentials. """ - api_key = credentials.get('api_key', '') + api_key = credentials.get("api_key", "") if not api_key: - raise ToolProviderCredentialValidationError('API key is required.') - + raise ToolProviderCredentialValidationError("API key is required.") + response = requests.get( - URL('https://api.stability.ai') / 'v1' / 'user' / 'account', + URL("https://api.stability.ai") / "v1" / "user" / "account", headers=self.generate_authorization_headers(credentials), - timeout=(5, 30) + timeout=(5, 30), ) if not response.ok: - raise ToolProviderCredentialValidationError('Invalid API key.') + raise ToolProviderCredentialValidationError("Invalid API key.") return True - + def generate_authorization_headers(self, credentials: dict) -> dict[str, str]: """ This method is responsible for generating the authorization headers. """ - return { - 'Authorization': f'Bearer {credentials.get("api_key", "")}' - } - \ No newline at end of file + return {"Authorization": f'Bearer {credentials.get("api_key", "")}'} diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 41236f7b43..6bcf315484 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -11,10 +11,11 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): """ This class is responsible for providing the stable diffusion tool. """ + model_endpoint_map: dict[str, str] = { - 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core', + "sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "sd3-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "core": "https://api.stability.ai/v2beta/stable-image/generate/core", } def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: @@ -22,39 +23,34 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): Invoke the tool. """ payload = { - 'prompt': tool_parameters.get('prompt', ''), - 'aspect_ratio': tool_parameters.get('aspect_ratio', '16:9') or tool_parameters.get('aspect_radio', '16:9'), - 'mode': 'text-to-image', - 'seed': tool_parameters.get('seed', 0), - 'output_format': 'png', + "prompt": tool_parameters.get("prompt", ""), + "aspect_ratio": tool_parameters.get("aspect_ratio", "16:9") or tool_parameters.get("aspect_radio", "16:9"), + "mode": "text-to-image", + "seed": tool_parameters.get("seed", 0), + "output_format": "png", } - model = tool_parameters.get('model', 'core') + model = tool_parameters.get("model", "core") - if model in ['sd3', 'sd3-turbo']: - payload['model'] = tool_parameters.get('model') + if model in {"sd3", "sd3-turbo"}: + payload["model"] = tool_parameters.get("model") - if not model == 'sd3-turbo': - payload['negative_prompt'] = tool_parameters.get('negative_prompt', '') + if model != "sd3-turbo": + payload["negative_prompt"] = tool_parameters.get("negative_prompt", "") response = post( - self.model_endpoint_map[tool_parameters.get('model', 'core')], + self.model_endpoint_map[tool_parameters.get("model", "core")], headers={ - 'accept': 'image/*', + "accept": "image/*", **self.generate_authorization_headers(self.runtime.credentials), }, - files={ - key: (None, str(value)) for key, value in payload.items() - }, - timeout=(5, 30) + files={key: (None, str(value)) for key, value in payload.items()}, + timeout=(5, 30), ) if not response.status_code == 200: raise Exception(response.text) - + return self.create_blob_message( - blob=response.content, meta={ - 'mime_type': 'image/png' - }, - save_as=self.VARIABLE_KEY.IMAGE.value + blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 317d705f7c..abaa297cf3 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -15,4 +15,3 @@ class StableDiffusionProvider(BuiltinToolProviderController): ).validate_models() except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 4be9207d66..64fdc961b4 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -18,19 +18,17 @@ DRAW_TEXT_OPTIONS = { # Prompts "prompt": "", "negative_prompt": "", - # "styles": [], - # Seeds + # "styles": [], + # Seeds "seed": -1, "subseed": -1, "subseed_strength": 0, "seed_resize_from_h": -1, "seed_resize_from_w": -1, - # Samplers "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", - # Latent Space Options "batch_size": 1, "n_iter": 1, @@ -42,9 +40,9 @@ DRAW_TEXT_OPTIONS = { # "tiling": True, "do_not_save_samples": False, "do_not_save_grid": False, - # "eta": 0, - # "denoising_strength": 0.75, - # "s_min_uncond": 0, + # "eta": 0, + # "denoising_strength": 0.75, + # "s_min_uncond": 0, # "s_churn": 0, # "s_tmax": 0, # "s_tmin": 0, @@ -73,7 +71,6 @@ DRAW_TEXT_OPTIONS = { "hr_negative_prompt": "", # Task Options # "force_task_id": "", - # Script Options # "script_name": "", "script_args": [], @@ -82,152 +79,150 @@ DRAW_TEXT_OPTIONS = { "save_images": False, "alwayson_scripts": {}, # "infotext": "", - } class StableDiffusionTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # base url - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - return self.create_text_message('Please input base_url') + return self.create_text_message("Please input base_url") - if tool_parameters.get('model'): - self.runtime.credentials['model'] = tool_parameters['model'] + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] - model = self.runtime.credentials.get('model', None) + model = self.runtime.credentials.get("model", None) if not model: - return self.create_text_message('Please input model') - + return self.create_text_message("Please input model") + # set model try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'options') - response = post(url, data=json.dumps({ - 'sd_model_checkpoint': model - })) + url = str(URL(base_url) / "sdapi" / "v1" / "options") + response = post(url, data=json.dumps({"sd_model_checkpoint": model})) if response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") except Exception as e: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") # get image id and image variable - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") image_variable = self.get_default_image_variable() # Return text2img if there's no image ID or no image variable if not image_id or not image_variable: - return self.text2img(base_url=base_url,tool_parameters=tool_parameters) + return self.text2img(base_url=base_url, tool_parameters=tool_parameters) # Proceed with image-to-image generation - return self.img2img(base_url=base_url,tool_parameters=tool_parameters) + return self.img2img(base_url=base_url, tool_parameters=tool_parameters) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - validate models + validate models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - raise ToolProviderCredentialValidationError('Please input base_url') - model = self.runtime.credentials.get('model', None) + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) if not model: - raise ToolProviderCredentialValidationError('Please input model') + raise ToolProviderCredentialValidationError("Please input model") - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=10) if response.status_code == 404: # try draw a picture self._invoke( - user_id='test', + user_id="test", tool_parameters={ - 'prompt': 'a cat', - 'width': 1024, - 'height': 1024, - 'steps': 1, - 'lora': '', - } + "prompt": "a cat", + "width": 1024, + "height": 1024, + "steps": 1, + "lora": "", + }, ) elif response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to get models') + raise ToolProviderCredentialValidationError("Failed to get models") else: - models = [d['model_name'] for d in response.json()] + models = [d["model_name"] for d in response.json()] if len([d for d in models if d == model]) > 0: return self.create_text_message(json.dumps(models)) else: - raise ToolProviderCredentialValidationError(f'model {model} does not exist') + raise ToolProviderCredentialValidationError(f"model {model} does not exist") except Exception as e: - raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") def get_sd_models(self) -> list[str]: """ - get sd models + get sd models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=(2, 10)) if response.status_code != 200: return [] else: - return [d['model_name'] for d in response.json()] - except Exception as e: - return [] - - def get_sample_methods(self) -> list[str]: - """ - get sample method - """ - try: - base_url = self.runtime.credentials.get('base_url', None) - if not base_url: - return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') - response = get(url=api_url, timeout=(2, 10)) - if response.status_code != 200: - return [] - else: - return [d['name'] for d in response.json()] + return [d["model_name"] for d in response.json()] except Exception as e: return [] - def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def get_sample_methods(self) -> list[str]: """ - generate image + get sample method + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d["name"] for d in response.json()] + except Exception as e: + return [] + + def img2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image """ # Fetch the binary data of the image image_variable = self.get_default_image_variable() image_binary = self.get_variable_file(image_variable.name) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") # Convert image to RGB and save as PNG try: - with Image.open(io.BytesIO(image_binary)) as image: - with io.BytesIO() as buffer: - image.convert("RGB").save(buffer, format="PNG") - image_binary = buffer.getvalue() + with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer: + image.convert("RGB").save(buffer, format="PNG") + image_binary = buffer.getvalue() except Exception as e: return self.create_text_message(f"Failed to process the image: {str(e)}") # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) # set image options - model = tool_parameters.get('model', '') + model = tool_parameters.get("model", "") draw_options_image = { - "init_images": [b64encode(image_binary).decode('utf-8')], + "init_images": [b64encode(image_binary).decode("utf-8")], "denoising_strength": 0.9, "restore_faces": False, "script_args": [], "override_settings": {"sd_model_checkpoint": model}, - "resize_mode":0, + "resize_mode": 0, "image_cfg_scale": 0, # "mask": None, "mask_blur_x": 4, @@ -247,136 +242,149 @@ class StableDiffusionTool(BuiltinTool): draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt + draw_options["prompt"] = prompt try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') + url = str(URL(base_url) / "sdapi" / "v1" / "img2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") - def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def text2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - generate image + generate image """ # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt - draw_options['override_settings']['sd_model_checkpoint'] = model + draw_options["prompt"] = prompt + draw_options["override_settings"]["sd_model_checkpoint"] = model - try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') + url = str(URL(base_url) / "sdapi" / "v1" / "txt2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") def get_runtime_parameters(self) -> list[ToolParameter]: parameters = [ - ToolParameter(name='prompt', - label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), - human_description=I18nObject( - en_US='Image prompt, you can check the official documentation of Stable Diffusion', - zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.', - required=True), + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image prompt of Stable Diffusion, you should describe the image you want to generate" + " as a list of words as possible as detailed, the prompt must be written in English.", + required=True, + ), ] if len(self.list_default_image_variables()) != 0: parameters.append( - ToolParameter(name='image_id', - label=I18nObject(en_US='image_id', zh_Hans='image_id'), - human_description=I18nObject( - en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', - zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.', - required=True, - options=[ToolParameterOption( - value=i.name, - label=I18nObject(en_US=i.name, zh_Hans=i.name) - ) for i in self.list_default_image_variables()]) + ToolParameter( + name="image_id", + label=I18nObject(en_US="image_id", zh_Hans="image_id"), + human_description=I18nObject( + en_US="Image id of the image you want to generate based on, if you want to generate image based" + " on the default image, you can leave this field empty.", + zh_Hans="您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image id of the original image, you can leave this field empty if you want to" + " generate a new image.", + required=True, + options=[ + ToolParameterOption(value=i.name, label=I18nObject(en_US=i.name, zh_Hans=i.name)) + for i in self.list_default_image_variables() + ], + ) ) - + if self.runtime.credentials: try: models = self.get_sd_models() if len(models) != 0: parameters.append( - ToolParameter(name='model', - label=I18nObject(en_US='Model', zh_Hans='Model'), - human_description=I18nObject( - en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=models[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in models]) + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + zh_Hans="Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + required=True, + default=models[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) ) - + except: pass - + sample_methods = self.get_sample_methods() if len(sample_methods) != 0: parameters.append( - ToolParameter(name='sampler_name', - label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), - human_description=I18nObject( - en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=sample_methods[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in sample_methods]) + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in sample_methods + ], ) + ) return parameters diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py index de64c84997..9680c633cc 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py @@ -11,16 +11,15 @@ class StackExchangeProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "intitle": "Test", - "sort": "relevance", + "sort": "relevance", "order": "desc", "site": "stackoverflow", "accepted": True, - "pagesize": 1 + "pagesize": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py index f8e1710844..5345320095 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py @@ -17,7 +17,9 @@ class FetchAnsByStackExQuesIDInput(BaseModel): class FetchAnsByStackExQuesIDTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = FetchAnsByStackExQuesIDInput(**tool_parameters) params = { @@ -26,7 +28,7 @@ class FetchAnsByStackExQuesIDTool(BuiltinTool): "order": input.order, "sort": input.sort, "pagesize": input.pagesize, - "page": input.page + "page": input.page, } response = requests.get(f"https://api.stackexchange.com/2.3/questions/{input.id}/answers", params=params) @@ -34,4 +36,4 @@ class FetchAnsByStackExQuesIDTool(BuiltinTool): if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py index 8436433c32..4a25a808ad 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py @@ -9,26 +9,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class SearchStackExQuestionsInput(BaseModel): intitle: str = Field(..., description="The search query.") - sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") + sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") order: str = Field(..., description="asc or desc") site: str = Field(..., description="The Stack Exchange site.") tagged: str = Field(None, description="Semicolon-separated tags to include.") nottagged: str = Field(None, description="Semicolon-separated tags to exclude.") - accepted: bool = Field(..., description="true for only accepted answers, false otherwise") + accepted: bool = Field(..., description="true for only accepted answers, false otherwise") pagesize: int = Field(..., description="Number of results per page") class SearchStackExQuestionsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = SearchStackExQuestionsInput(**tool_parameters) params = { "intitle": input.intitle, "sort": input.sort, - "order": input.order, + "order": input.order, "site": input.site, "accepted": input.accepted, - "pagesize": input.pagesize + "pagesize": input.pagesize, } if input.tagged: params["tagged"] = input.tagged @@ -40,4 +42,4 @@ class SearchStackExQuestionsTool(BuiltinTool): if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.py b/api/core/tools/provider/builtin/stepfun/stepfun.py index e809b04546..b24f730c95 100644 --- a/api/core/tools/provider/builtin/stepfun/stepfun.py +++ b/api/core/tools/provider/builtin/stepfun/stepfun.py @@ -13,13 +13,12 @@ class StepfunProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "cute girl, blue eyes, white hair, anime style", "size": "1024x1024", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.py b/api/core/tools/provider/builtin/stepfun/tools/image.py index c571f54675..0b92b122bf 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.py +++ b/api/core/tools/provider/builtin/stepfun/tools/image.py @@ -9,61 +9,67 @@ from core.tools.tool.builtin_tool import BuiltinTool class StepfunTool(BuiltinTool): - """ Stepfun Image Generation Tool """ - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """Stepfun Image Generation Tool""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com') - base_url = str(URL(base_url) / 'v1') + base_url = self.runtime.credentials.get("stepfun_base_url", "https://api.stepfun.com") + base_url = str(URL(base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['stepfun_api_key'], + api_key=self.runtime.credentials["stepfun_api_key"], base_url=base_url, ) extra_body = {} - model = tool_parameters.get('model', 'step-1x-medium') + model = tool_parameters.get("model", "step-1x-medium") if not model: - return self.create_text_message('Please input model name') + return self.create_text_message("Please input model name") # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") - seed = tool_parameters.get('seed', 0) + seed = tool_parameters.get("seed", 0) if seed > 0: - extra_body['seed'] = seed - steps = tool_parameters.get('steps', 0) + extra_body["seed"] = seed + steps = tool_parameters.get("steps", 0) if steps > 0: - extra_body['steps'] = steps - negative_prompt = tool_parameters.get('negative_prompt', '') + extra_body["steps"] = steps + negative_prompt = tool_parameters.get("negative_prompt", "") if negative_prompt: - extra_body['negative_prompt'] = negative_prompt + extra_body["negative_prompt"] = negative_prompt # call openapi stepfun model response = client.images.generate( prompt=prompt, model=model, - size=tool_parameters.get('size', '1024x1024'), - n=tool_parameters.get('n', 1), - extra_body= extra_body + size=tool_parameters.get("size", "1024x1024"), + n=tool_parameters.get("n", 1), + extra_body=extra_body, ) print(response) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index e376d99d6b..a702b0a74e 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -13,7 +13,7 @@ class TavilyProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", "search_depth": "basic", @@ -22,9 +22,8 @@ class TavilyProvider(BuiltinToolProviderController): "include_raw_content": False, "max_results": 5, "include_domains": "", - "exclude_domains": "" + "exclude_domains": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py index 0200df3c8a..ca6d8633e4 100644 --- a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py +++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py @@ -36,15 +36,23 @@ class TavilySearch: """ params["api_key"] = self.api_key - if 'exclude_domains' in params and isinstance(params['exclude_domains'], str) and params['exclude_domains'] != 'None': - params['exclude_domains'] = params['exclude_domains'].split() + if ( + "exclude_domains" in params + and isinstance(params["exclude_domains"], str) + and params["exclude_domains"] != "None" + ): + params["exclude_domains"] = params["exclude_domains"].split() else: - params['exclude_domains'] = [] - if 'include_domains' in params and isinstance(params['include_domains'], str) and params['include_domains'] != 'None': - params['include_domains'] = params['include_domains'].split() + params["exclude_domains"] = [] + if ( + "include_domains" in params + and isinstance(params["include_domains"], str) + and params["include_domains"] != "None" + ): + params["include_domains"] = params["include_domains"].split() else: - params['include_domains'] = [] - + params["include_domains"] = [] + response = requests.post(f"{TAVILY_API_URL}/search", json=params) response.raise_for_status() return response.json() @@ -91,9 +99,7 @@ class TavilySearchTool(BuiltinTool): A tool for searching Tavily using a given query. """ - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Tavily search tool with the given user ID and tool parameters. @@ -115,4 +121,4 @@ class TavilySearchTool(BuiltinTool): if not results: return self.create_text_message(f"No results found for '{query}' in Tavily") else: - return self.create_text_message(text=results) \ No newline at end of file + return self.create_text_message(text=results) diff --git a/api/core/tools/provider/builtin/tianditu/tianditu.py b/api/core/tools/provider/builtin/tianditu/tianditu.py index 1f96be06b0..cb7d7bd8bb 100644 --- a/api/core/tools/provider/builtin/tianditu/tianditu.py +++ b/api/core/tools/provider/builtin/tianditu/tianditu.py @@ -12,10 +12,12 @@ class TiandituProvider(BuiltinToolProviderController): runtime={ "credentials": credentials, } - ).invoke(user_id='', - tool_parameters={ - 'content': '北京', - 'specify': '156110000', - }) + ).invoke( + user_id="", + tool_parameters={ + "content": "北京", + "specify": "156110000", + }, + ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py index 484a3768c8..690a0aed6f 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py +++ b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py @@ -8,26 +8,26 @@ from core.tools.tool.builtin_tool import BuiltinTool class GeocoderTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = 'http://api.tianditu.gov.cn/geocoder' - - keyword = tool_parameters.get('keyword', '') + base_url = "http://api.tianditu.gov.cn/geocoder" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + params = { - 'keyWord': keyword, + "keyWord": keyword, } - - result = requests.get(base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk).json() + + result = requests.get(base_url + "?ds=" + json.dumps(params, ensure_ascii=False) + "&tk=" + tk).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py index 08a5b8ef42..798dd94d33 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py +++ b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py @@ -8,38 +8,51 @@ from core.tools.tool.builtin_tool import BuiltinTool class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/v2/search' - - keyword = tool_parameters.get('keyword', '') + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/v2/search" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - baseAddress = tool_parameters.get('baseAddress', '') + return self.create_text_message("Invalid parameter keyword") + + baseAddress = tool_parameters.get("baseAddress", "") if not baseAddress: - return self.create_text_message('Invalid parameter baseAddress') - - tk = self.runtime.credentials['tianditu_api_key'] - - base_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': baseAddress,}, ensure_ascii=False) + '&tk=' + tk).json() - + return self.create_text_message("Invalid parameter baseAddress") + + tk = self.runtime.credentials["tianditu_api_key"] + + base_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": baseAddress, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + params = { - 'keyWord': keyword, - 'queryRadius': 5000, - 'queryType': 3, - 'pointLonlat': base_coords['location']['lon'] + ',' + base_coords['location']['lat'], - 'start': 0, - 'count': 100, + "keyWord": keyword, + "queryRadius": 5000, + "queryType": 3, + "pointLonlat": base_coords["location"]["lon"] + "," + base_coords["location"]["lat"], + "start": 0, + "count": 100, } - - result = requests.get(base_url + '?postStr=' + json.dumps(params, ensure_ascii=False) + '&type=query&tk=' + tk).json() + + result = requests.get( + base_url + "?postStr=" + json.dumps(params, ensure_ascii=False) + "&type=query&tk=" + tk + ).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py index ecac4404ca..aeaef08805 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py +++ b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py @@ -8,29 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/staticimage' - - keyword = tool_parameters.get('keyword', '') - if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - - keyword_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': keyword,}, ensure_ascii=False) + '&tk=' + tk).json() - coords = keyword_coords['location']['lon'] + ',' + keyword_coords['location']['lat'] - - result = requests.get(base_url + '?center=' + coords + '&markers=' + coords + '&width=400&height=300&zoom=14&tk=' + tk).content - return self.create_blob_message(blob=result, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/staticimage" + + keyword = tool_parameters.get("keyword", "") + if not keyword: + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + + keyword_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": keyword, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + coords = keyword_coords["location"]["lon"] + "," + keyword_coords["location"]["lat"] + + result = requests.get( + base_url + "?center=" + coords + "&markers=" + coords + "&width=400&height=300&zoom=14&tk=" + tk + ).content + + return self.create_blob_message( + blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 833ae194ef..e4df8d616c 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -9,9 +9,8 @@ class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CurrentTimeTool().invoke( - user_id='', + user_id="", tool_parameters={}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 90c01665e6..cc38739c16 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -8,21 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool class CurrentTimeTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get timezone - tz = tool_parameters.get('timezone', 'UTC') - fm = tool_parameters.get('format') or '%Y-%m-%d %H:%M:%S %Z' - if tz == 'UTC': - return self.create_text_message(f'{datetime.now(timezone.utc).strftime(fm)}') - + tz = tool_parameters.get("timezone", "UTC") + fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" + if tz == "UTC": + return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + try: tz = pytz_timezone(tz) except: - return self.create_text_message(f'Invalid timezone: {tz}') - return self.create_text_message(f'{datetime.now(tz).strftime(fm)}') \ No newline at end of file + return self.create_text_message(f"Invalid timezone: {tz}") + return self.create_text_message(f"{datetime.now(tz).strftime(fm)}") diff --git a/api/core/tools/provider/builtin/time/tools/weekday.py b/api/core/tools/provider/builtin/time/tools/weekday.py index 4461cb5a32..b327e54e17 100644 --- a/api/core/tools/provider/builtin/time/tools/weekday.py +++ b/api/core/tools/provider/builtin/time/tools/weekday.py @@ -7,25 +7,26 @@ from core.tools.tool.builtin_tool import BuiltinTool class WeekdayTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Calculate the day of the week for a given date + Calculate the day of the week for a given date """ - year = tool_parameters.get('year') - month = tool_parameters.get('month') - day = tool_parameters.get('day') + year = tool_parameters.get("year") + month = tool_parameters.get("month") + day = tool_parameters.get("day") date_obj = self.convert_datetime(year, month, day) if not date_obj: - return self.create_text_message(f'Invalid date: Year {year}, Month {month}, Day {day}.') + return self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.") weekday_name = calendar.day_name[date_obj.weekday()] month_name = calendar.month_name[month] readable_date = f"{month_name} {date_obj.day}, {date_obj.year}" - return self.create_text_message(f'{readable_date} is {weekday_name}.') + return self.create_text_message(f"{readable_date} is {weekday_name}.") @staticmethod def convert_datetime(year, month, day) -> datetime | None: diff --git a/api/core/tools/provider/builtin/trello/tools/create_board.py b/api/core/tools/provider/builtin/trello/tools/create_board.py index 2655602afa..5a61d22157 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_board.py @@ -22,19 +22,15 @@ class CreateBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_name = tool_parameters.get("name") if not (api_key and token and board_name): return self.create_text_message("Missing required parameters: API key, token, or board name.") url = "https://api.trello.com/1/boards/" - query_params = { - 'name': board_name, - 'key': api_key, - 'token': token - } + query_params = {"name": board_name, "key": api_key, "token": token} try: response = requests.post(url, params=query_params) @@ -43,5 +39,6 @@ class CreateBoardTool(BuiltinTool): return self.create_text_message("Failed to create board") board = response.json() - return self.create_text_message(text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}") - + return self.create_text_message( + text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py index f5b156cb44..b32b0124dd 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py @@ -17,25 +17,22 @@ class CreateListOnBoardTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID and list name. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID and list name. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('id') - list_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("id") + list_name = tool_parameters.get("name") if not (api_key and token and board_id and list_name): return self.create_text_message("Missing required parameters: API key, token, board ID, or list name.") url = f"https://api.trello.com/1/boards/{board_id}/lists" - params = { - 'name': list_name, - 'key': api_key, - 'token': token - } + params = {"name": list_name, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -44,5 +41,6 @@ class CreateListOnBoardTool(BuiltinTool): return self.create_text_message("Failed to create list") new_list = response.json() - return self.create_text_message(text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}.") - + return self.create_text_message( + text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py index 74b73b40e5..e98efb81ca 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py @@ -17,20 +17,21 @@ class CreateNewCardOnBoardTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, including details for the new card. + tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, + including details for the new card. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") # Ensure required parameters are present - if 'name' not in tool_parameters or 'idList' not in tool_parameters: + if "name" not in tool_parameters or "idList" not in tool_parameters: return self.create_text_message("Missing required parameters: name or idList.") url = "https://api.trello.com/1/cards" - params = {**tool_parameters, 'key': api_key, 'token': token} + params = {**tool_parameters, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -39,5 +40,6 @@ class CreateNewCardOnBoardTool(BuiltinTool): except requests.exceptions.RequestException as e: return self.create_text_message("Failed to create card") - return self.create_text_message(text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}.") - + return self.create_text_message( + text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/delete_board.py b/api/core/tools/provider/builtin/trello/tools/delete_board.py index 29df3fda2d..7fc9d1f13c 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_board.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_board.py @@ -17,14 +17,15 @@ class DeleteBoardTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,4 +39,3 @@ class DeleteBoardTool(BuiltinTool): return self.create_text_message("Failed to delete board") return self.create_text_message(text=f"Board with ID {board_id} deleted successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/delete_card.py b/api/core/tools/provider/builtin/trello/tools/delete_card.py index 2ced5f6c14..1de98d639e 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_card.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_card.py @@ -17,14 +17,15 @@ class DeleteCardByIdTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the card ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the card ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") @@ -38,4 +39,3 @@ class DeleteCardByIdTool(BuiltinTool): return self.create_text_message("Failed to delete card") return self.create_text_message(text=f"Card with ID {card_id} has been successfully deleted.") - diff --git a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py index f9d554c6fb..0c5ed9ea85 100644 --- a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py +++ b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py @@ -28,9 +28,7 @@ class FetchAllBoardsTool(BuiltinTool): token = self.runtime.credentials.get("trello_api_token") if not (api_key and token): - return self.create_text_message( - "Missing Trello API key or token in credentials." - ) + return self.create_text_message("Missing Trello API key or token in credentials.") # Including board filter in the request if provided board_filter = tool_parameters.get("boards", "open") @@ -48,7 +46,5 @@ class FetchAllBoardsTool(BuiltinTool): return self.create_text_message("No boards found in Trello.") # Creating a string with both board names and IDs - boards_info = ", ".join( - [f"{board['name']} (ID: {board['id']})" for board in boards] - ) + boards_info = ", ".join([f"{board['name']} (ID: {board['id']})" for board in boards]) return self.create_text_message(text=f"Boards: {boards_info}") diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py index 5678d8f8d7..cabc7ce093 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py @@ -17,14 +17,15 @@ class GetBoardActionsTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,6 +39,7 @@ class GetBoardActionsTool(BuiltinTool): except requests.exceptions.RequestException as e: return self.create_text_message("Failed to retrieve board actions") - actions_summary = "\n".join([f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions]) + actions_summary = "\n".join( + [f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions] + ) return self.create_text_message(text=f"Actions for Board ID {board_id}:\n{actions_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py index ee6cb065e5..fe42cd9c5c 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py @@ -17,14 +17,15 @@ class GetBoardByIdTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -63,4 +64,3 @@ class GetBoardByIdTool(BuiltinTool): f"Background Color: {board['prefs']['backgroundColor']}" ) return details - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py index 1abb688750..ff2b1221e7 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py @@ -17,14 +17,15 @@ class GetBoardCardsTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +41,3 @@ class GetBoardCardsTool(BuiltinTool): cards_summary = "\n".join([f"{card['name']} (ID: {card['id']})" for card in cards]) return self.create_text_message(text=f"Cards for Board ID {board_id}:\n{cards_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py index 375ead5b1d..3d7f9f4ad1 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py @@ -17,15 +17,16 @@ class GetFilteredBoardCardsTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID and filter. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID and filter. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') - filter = tool_parameters.get('filter') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") + filter = tool_parameters.get("filter") if not (api_key and token and board_id and filter): return self.create_text_message("Missing required parameters: API key, token, board ID, or filter.") @@ -40,5 +41,6 @@ class GetFilteredBoardCardsTool(BuiltinTool): return self.create_text_message("Failed to retrieve filtered cards") card_details = "\n".join([f"{card['name']} (ID: {card['id']})" for card in filtered_cards]) - return self.create_text_message(text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}") - + return self.create_text_message( + text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py index 7b9b9cf24b..ccf404068f 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py @@ -17,14 +17,15 @@ class GetListsFromBoardTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +41,3 @@ class GetListsFromBoardTool(BuiltinTool): lists_info = "\n".join([f"{list['name']} (ID: {list['id']})" for list in lists]) return self.create_text_message(text=f"Lists on Board ID {board_id}:\n{lists_info}") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_board.py b/api/core/tools/provider/builtin/trello/tools/update_board.py index 7ad6ac2e64..1e358b00f4 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_board.py +++ b/api/core/tools/provider/builtin/trello/tools/update_board.py @@ -17,14 +17,15 @@ class UpdateBoardByIdTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, including board ID and updates. + tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, + including board ID and updates. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.pop('boardId', None) + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.pop("boardId", None) if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -33,8 +34,8 @@ class UpdateBoardByIdTool(BuiltinTool): # Removing parameters not intended for update action or with None value params = {k: v for k, v in tool_parameters.items() if v is not None} - params['key'] = api_key - params['token'] = token + params["key"] = api_key + params["token"] = token try: response = requests.put(url, params=params) @@ -44,4 +45,3 @@ class UpdateBoardByIdTool(BuiltinTool): updated_board = response.json() return self.create_text_message(text=f"Board '{updated_board['name']}' updated successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_card.py b/api/core/tools/provider/builtin/trello/tools/update_card.py index 417344350c..d25fcbafaa 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_card.py +++ b/api/core/tools/provider/builtin/trello/tools/update_card.py @@ -17,22 +17,23 @@ class UpdateCardByIdTool(BuiltinTool): Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, including the card ID and updates. + tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, + including the card ID and updates. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") # Constructing the URL and the payload for the PUT request url = f"https://api.trello.com/1/cards/{card_id}" - params = {k: v for k, v in tool_parameters.items() if v is not None and k != 'id'} - params.update({'key': api_key, 'token': token}) + params = {k: v for k, v in tool_parameters.items() if v is not None and k != "id"} + params.update({"key": api_key, "token": token}) try: response = requests.put(url, params=params) diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py index 84ecd20803..e0dca50ec9 100644 --- a/api/core/tools/provider/builtin/trello/trello.py +++ b/api/core/tools/provider/builtin/trello/trello.py @@ -9,17 +9,17 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class TrelloProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: """Validate Trello API credentials by making a test API call. - + Args: credentials (dict[str, Any]): The Trello API credentials to validate. - + Raises: ToolProviderCredentialValidationError: If the credentials are invalid. """ api_key = credentials.get("trello_api_key") token = credentials.get("trello_api_token") url = f"https://api.trello.com/1/members/me?key={api_key}&token={token}" - + try: response = requests.get(url) response.raise_for_status() # Raises an HTTPError for bad responses @@ -32,4 +32,3 @@ class TrelloProvider(BuiltinToolProviderController): except requests.exceptions.RequestException as e: # Handle other exceptions, such as connection errors raise ToolProviderCredentialValidationError("Error validating Trello credentials") - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 1c52589956..5ee839baa5 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -32,17 +32,14 @@ class TwilioAPIWrapper(BaseModel): must be empty. """ - @field_validator('client', mode='before') + @field_validator("client", mode="before") @classmethod def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: from twilio.rest import Client except ImportError: - raise ImportError( - "Could not import twilio python package. " - "Please install it with `pip install twilio`." - ) + raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.") account_sid = values.get("account_sid") auth_token = values.get("auth_token") values["from_number"] = values.get("from_number") @@ -75,7 +72,8 @@ class SendMessageTool(BuiltinTool): tool_parameters (Dict[str, Any]): The parameters required for sending the message. Returns: - Union[ToolInvokeMessage, List[ToolInvokeMessage]]: The result of invoking the tool, which includes the status of the message sending operation. + Union[ToolInvokeMessage, List[ToolInvokeMessage]]: The result of invoking the tool, + which includes the status of the message sending operation. """ def _invoke( @@ -91,9 +89,7 @@ class SendMessageTool(BuiltinTool): if to_number.startswith("whatsapp:"): from_number = f"whatsapp: {from_number}" - twilio = TwilioAPIWrapper( - account_sid=account_sid, auth_token=auth_token, from_number=from_number - ) + twilio = TwilioAPIWrapper(account_sid=account_sid, auth_token=auth_token, from_number=from_number) # Sending the message through Twilio result = twilio.run(message, to_number) diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 06f276053a..b1d100aad9 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -14,7 +14,7 @@ class TwilioProvider(BuiltinToolProviderController): account_sid = credentials["account_sid"] auth_token = credentials["auth_token"] from_number = credentials["from_number"] - + # Initialize twilio client client = Client(account_sid, auth_token) @@ -27,4 +27,3 @@ class TwilioProvider(BuiltinToolProviderController): raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index a6efb0f79a..c90d766e48 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -38,7 +38,7 @@ class VannaTool(BuiltinTool): vn = VannaDefault(model=model, api_key=api_key) db_type = tool_parameters.get("db_type", "") - if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: + if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: if not db_name: return self.create_text_message("Please input database name") if not username: diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py index ab1fd71df5..84724e921a 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.py +++ b/api/core/tools/provider/builtin/vanna/vanna.py @@ -13,13 +13,13 @@ class VannaProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "model": "chinook", "db_type": "SQLite", "url": "https://vanna.ai/Chinook.sqlite", - "query": "What are the top 10 customers by sales?" + "query": "What are the top 10 customers by sales?", }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py index 1506ac0c9d..8effa9818a 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py @@ -1 +1 @@ -VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' \ No newline at end of file +VECTORIZER_ICON_PNG = "iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC" # noqa: E501 diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index c6ec198034..4bd601c0bd 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -10,65 +10,60 @@ from core.tools.tool.builtin_tool import BuiltinTool class VectorizerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - api_key_name = self.runtime.credentials.get('api_key_name', None) - api_key_value = self.runtime.credentials.get('api_key_value', None) - mode = tool_parameters.get('mode', 'test') - if mode == 'production': - mode = 'preview' + api_key_name = self.runtime.credentials.get("api_key_name", None) + api_key_value = self.runtime.credentials.get("api_key_value", None) + mode = tool_parameters.get("mode", "test") + if mode == "production": + mode = "preview" if not api_key_name or not api_key_value: - raise ToolProviderCredentialValidationError('Please input api key name and value') + raise ToolProviderCredentialValidationError("Please input api key name and value") - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") if not image_id: - return self.create_text_message('Please input image id') - - if image_id.startswith('__test_'): + return self.create_text_message("Please input image id") + + if image_id.startswith("__test_"): image_binary = b64decode(VECTORIZER_ICON_PNG) else: - image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + image_binary = self.get_variable_file(self.VariableKey.IMAGE) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") response = post( - 'https://vectorizer.ai/api/v1/vectorize', - files={ - 'image': image_binary - }, - data={ - 'mode': mode - } if mode == 'test' else {}, - auth=(api_key_name, api_key_value), - timeout=30 + "https://vectorizer.ai/api/v1/vectorize", + files={"image": image_binary}, + data={"mode": mode} if mode == "test" else {}, + auth=(api_key_name, api_key_value), + timeout=30, ) if response.status_code != 200: raise Exception(response.text) - + return [ - self.create_text_message('the vectorized svg is saved as an image.'), - self.create_blob_message(blob=response.content, - meta={'mime_type': 'image/svg+xml'}) + self.create_text_message("the vectorized svg is saved as an image."), + self.create_blob_message(blob=response.content, meta={"mime_type": "image/svg+xml"}), ] - + def get_runtime_parameters(self) -> list[ToolParameter]: """ override the runtime parameters """ return [ ToolParameter.get_simple_instance( - name='image_id', - llm_description=f'the image id that you want to vectorize, \ + name="image_id", + llm_description=f"the image id that you want to vectorize, \ and the image id should be specified in \ - {[i.name for i in self.list_default_image_variables()]}', + {[i.name for i in self.list_default_image_variables()]}", type=ToolParameter.ToolParameterType.SELECT, required=True, - options=[i.name for i in self.list_default_image_variables()] + options=[i.name for i in self.list_default_image_variables()], ) ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 3f89a83500..3b868572f9 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -13,12 +13,8 @@ class VectorizerProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "mode": "test", - "image_id": "__test_123" - }, + user_id="", + tool_parameters={"mode": "test", "image_id": "__test_123"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py index 3d098e6768..12670b4b8b 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -6,23 +6,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class WebscraperTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: - url = tool_parameters.get('url', '') - user_agent = tool_parameters.get('user_agent', '') + url = tool_parameters.get("url", "") + user_agent = tool_parameters.get("user_agent", "") if not url: - return self.create_text_message('Please input url') + return self.create_text_message("Please input url") # get webpage result = self.get_url(url, user_agent=user_agent) - if tool_parameters.get('generate_summary'): + if tool_parameters.get("generate_summary"): # summarize and return return self.create_text_message(self.summary(user_id=user_id, content=result)) else: diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index 1e60fdb293..3c51393ac6 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -13,12 +13,11 @@ class WebscraperProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - 'url': 'https://www.google.com', - 'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + "url": "https://www.google.com", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/websearch/tools/job_search.py b/api/core/tools/provider/builtin/websearch/tools/job_search.py index 9128305922..293f4f6329 100644 --- a/api/core/tools/provider/builtin/websearch/tools/job_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/job_search.py @@ -50,14 +50,16 @@ class SerplyApi: for job in jobs[:10]: try: string.append( - "\n".join([ - f"Position: {job['position']}", - f"Employer: {job['employer']}", - f"Location: {job['location']}", - f"Link: {job['link']}", - f"""Highest: {", ".join(list(job["highlights"]))}""", - "---", - ]) + "\n".join( + [ + f"Position: {job['position']}", + f"Employer: {job['employer']}", + f"Location: {job['location']}", + f"Link: {job['link']}", + f"""Highest: {", ".join(list(job["highlights"]))}""", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/news_search.py b/api/core/tools/provider/builtin/websearch/tools/news_search.py index e9c0744f05..9b5482fe18 100644 --- a/api/core/tools/provider/builtin/websearch/tools/news_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/news_search.py @@ -53,13 +53,15 @@ class SerplyApi: r = requests.get(entry["link"]) final_link = r.history[-1].headers["Location"] string.append( - "\n".join([ - f"Title: {entry['title']}", - f"Link: {final_link}", - f"Source: {entry['source']['title']}", - f"Published: {entry['published']}", - "---", - ]) + "\n".join( + [ + f"Title: {entry['title']}", + f"Link: {final_link}", + f"Source: {entry['source']['title']}", + f"Published: {entry['published']}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py index 0030a03c06..798d059b51 100644 --- a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py @@ -55,14 +55,16 @@ class SerplyApi: link = article["link"] authors = [author["name"] for author in article["author"]["authors"]] string.append( - "\n".join([ - f"Title: {article['title']}", - f"Link: {link}", - f"Description: {article['description']}", - f"Cite: {article['cite']}", - f"Authors: {', '.join(authors)}", - "---", - ]) + "\n".join( + [ + f"Title: {article['title']}", + f"Link: {link}", + f"Description: {article['description']}", + f"Cite: {article['cite']}", + f"Authors: {', '.join(authors)}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/web_search.py b/api/core/tools/provider/builtin/websearch/tools/web_search.py index 4f57c27caf..fe363ac7a4 100644 --- a/api/core/tools/provider/builtin/websearch/tools/web_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/web_search.py @@ -49,12 +49,14 @@ class SerplyApi: for result in results: try: string.append( - "\n".join([ - f"Title: {result['title']}", - f"Link: {result['link']}", - f"Description: {result['description'].strip()}", - "---", - ]) + "\n".join( + [ + f"Title: {result['title']}", + f"Link: {result['link']}", + f"Description: {result['description'].strip()}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py index fb44b70f4e..545d9f4f8d 100644 --- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py +++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py @@ -8,41 +8,41 @@ from core.tools.utils.uuid_utils import is_valid_uuid class WecomGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - message_type = tool_parameters.get('message_type', 'text') - if message_type == 'markdown': + message_type = tool_parameters.get("message_type", "text") + if message_type == "markdown": payload = { - "msgtype": 'markdown', + "msgtype": "markdown", "markdown": { "content": content, - } + }, } else: payload = { - "msgtype": 'text', + "msgtype": "text", "text": { "content": content, - } + }, } - api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send' + api_url = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'key': hook_key, + "key": hook_key, } try: @@ -51,6 +51,7 @@ class WecomGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 0796cd2392..cb88e9519a 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -27,7 +27,7 @@ class WikipediaAPIWrapper: self.doc_content_chars_max = doc_content_chars_max def run(self, query: str, lang: str = "") -> str: - if lang in wikipedia.languages().keys(): + if lang in wikipedia.languages(): self.lang = lang wikipedia.set_lang(self.lang) @@ -83,7 +83,6 @@ class WikipediaQueryRun: class WikiPediaSearchTool(BuiltinTool): - def _invoke( self, user_id: str, diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index f8038714a5..178bf7b0ce 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -11,11 +11,10 @@ class WikiPediaProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "misaka mikoto", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index 8cb9c10ddf..9dc5bed824 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -8,29 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class WolframAlphaTool(BuiltinTool): - _base_url = 'https://api.wolframalpha.com/v2/query' + _base_url = "https://api.wolframalpha.com/v2/query" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - appid = self.runtime.credentials.get('appid', '') + return self.create_text_message("Please input query") + appid = self.runtime.credentials.get("appid", "") if not appid: - raise ToolProviderCredentialValidationError('Please input appid') - - params = { - 'appid': appid, - 'input': query, - 'includepodid': 'Result', - 'format': 'plaintext', - 'output': 'json' - } + raise ToolProviderCredentialValidationError("Please input appid") + + params = {"appid": appid, "input": query, "includepodid": "Result", "format": "plaintext", "output": "json"} finished = False result = None @@ -45,34 +40,33 @@ class WolframAlphaTool(BuiltinTool): response_data = response.json() except Exception as e: raise ToolInvokeError(str(e)) - - if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True: - query_result = response_data.get('queryresult', {}) - if query_result.get('error'): - if 'msg' in query_result['error']: - if query_result['error']['msg'] == 'Invalid appid': - raise ToolProviderCredentialValidationError('Invalid appid') - raise ToolInvokeError('Failed to invoke tool') - - if 'didyoumeans' in response_data['queryresult']: - # get the most likely interpretation - query = '' - max_score = 0 - for didyoumean in response_data['queryresult']['didyoumeans']: - if float(didyoumean['score']) > max_score: - query = didyoumean['val'] - max_score = float(didyoumean['score']) - params['input'] = query + if "success" not in response_data["queryresult"] or response_data["queryresult"]["success"] != True: + query_result = response_data.get("queryresult", {}) + if query_result.get("error"): + if "msg" in query_result["error"]: + if query_result["error"]["msg"] == "Invalid appid": + raise ToolProviderCredentialValidationError("Invalid appid") + raise ToolInvokeError("Failed to invoke tool") + + if "didyoumeans" in response_data["queryresult"]: + # get the most likely interpretation + query = "" + max_score = 0 + for didyoumean in response_data["queryresult"]["didyoumeans"]: + if float(didyoumean["score"]) > max_score: + query = didyoumean["val"] + max_score = float(didyoumean["score"]) + + params["input"] = query else: finished = True - if 'souces' in response_data['queryresult']: - return self.create_link_message(response_data['queryresult']['sources']['url']) - elif 'pods' in response_data['queryresult']: - result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext'] + if "souces" in response_data["queryresult"]: + return self.create_link_message(response_data["queryresult"]["sources"]["url"]) + elif "pods" in response_data["queryresult"]: + result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"] if not finished or not result: - return self.create_text_message('No result found') + return self.create_text_message("No result found") return self.create_text_message(result) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index ef1aac7ff2..7be288b538 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -13,11 +13,10 @@ class GoogleProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "1+2+....+111", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index cf511ea894..f044fbe540 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -10,27 +10,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class YahooFinanceAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - symbol = tool_parameters.get('symbol', '') + symbol = tool_parameters.get("symbol", "") if not symbol: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") stock_data = download(symbol, start=time_range[0], end=time_range[1]) max_segments = min(15, len(stock_data)) @@ -41,30 +42,29 @@ class YahooFinanceAnalyticsTool(BuiltinTool): end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data) segment_data = stock_data.iloc[start_idx:end_idx] segment_summary = { - 'Start Date': segment_data.index[0], - 'End Date': segment_data.index[-1], - 'Average Close': segment_data['Close'].mean(), - 'Average Volume': segment_data['Volume'].mean(), - 'Average Open': segment_data['Open'].mean(), - 'Average High': segment_data['High'].mean(), - 'Average Low': segment_data['Low'].mean(), - 'Average Adj Close': segment_data['Adj Close'].mean(), - 'Max Close': segment_data['Close'].max(), - 'Min Close': segment_data['Close'].min(), - 'Max Volume': segment_data['Volume'].max(), - 'Min Volume': segment_data['Volume'].min(), - 'Max Open': segment_data['Open'].max(), - 'Min Open': segment_data['Open'].min(), - 'Max High': segment_data['High'].max(), - 'Min High': segment_data['High'].min(), + "Start Date": segment_data.index[0], + "End Date": segment_data.index[-1], + "Average Close": segment_data["Close"].mean(), + "Average Volume": segment_data["Volume"].mean(), + "Average Open": segment_data["Open"].mean(), + "Average High": segment_data["High"].mean(), + "Average Low": segment_data["Low"].mean(), + "Average Adj Close": segment_data["Adj Close"].mean(), + "Max Close": segment_data["Close"].max(), + "Min Close": segment_data["Close"].min(), + "Max Volume": segment_data["Volume"].max(), + "Min Volume": segment_data["Volume"].min(), + "Max Open": segment_data["Open"].max(), + "Min Open": segment_data["Open"].min(), + "Max High": segment_data["High"].max(), + "Min High": segment_data["High"].min(), } - + summary_data.append(segment_summary) summary_df = pd.DataFrame(summary_data) - + try: return self.create_text_message(str(summary_df.to_dict())) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - \ No newline at end of file + return self.create_text_message("There is a internet connection problem. Please try again later.") diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index 4f2922ef3e..ff820430f9 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -8,40 +8,39 @@ from core.tools.tool.builtin_tool import BuiltinTool class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self,user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - ''' - invoke tools - ''' - - query = tool_parameters.get('symbol', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.run(ticker=query, user_id=user_id) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') + return self.create_text_message("There is a internet connection problem. Please try again later.") def run(self, ticker: str, user_id: str) -> ToolInvokeMessage: company = yfinance.Ticker(ticker) try: if company.isin is None: - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") except (HTTPError, ReadTimeout, ConnectionError): - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") links = [] try: - links = [n['link'] for n in company.news if n['type'] == 'STORY'] + links = [n["link"] for n in company.news if n["type"] == "STORY"] except (HTTPError, ReadTimeout, ConnectionError): if not links: - return self.create_text_message(f'There is nothing about {ticker} ticker') + return self.create_text_message(f"There is nothing about {ticker} ticker") if not links: - return self.create_text_message(f'No news found for company that searched with {ticker} ticker.') - - result = '\n\n'.join([ - self.get_url(link) for link in links - ]) + return self.create_text_message(f"No news found for company that searched with {ticker} ticker.") + + result = "\n\n".join([self.get_url(link) for link in links]) return self.create_text_message(self.summary(user_id=user_id, content=result)) diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index 262fff3b25..dfc7e46047 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -8,19 +8,20 @@ from core.tools.tool.builtin_tool import BuiltinTool class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('symbol', '') + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.create_text_message(self.run(ticker=query)) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - + return self.create_text_message("There is a internet connection problem. Please try again later.") + def run(self, ticker: str) -> str: - return str(Ticker(ticker).info) \ No newline at end of file + return str(Ticker(ticker).info) diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py index 96dbc6c3d0..8d82084e76 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.py +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -11,11 +11,10 @@ class YahooFinanceProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "ticker": "MSFT", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 7a9b9fce4a..95dec2eac9 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -8,60 +8,67 @@ from core.tools.tool.builtin_tool import BuiltinTool class YoutubeVideosAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - channel = tool_parameters.get('channel', '') + channel = tool_parameters.get("channel", "") if not channel: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") - if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']: - return self.create_text_message('Please input api key') + if "google_api_key" not in self.runtime.credentials or not self.runtime.credentials["google_api_key"]: + return self.create_text_message("Please input api key") - youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key']) + youtube = build("youtube", "v3", developerKey=self.runtime.credentials["google_api_key"]) # try to get channel id - search_results = youtube.search().list(q=channel, type='channel', order='relevance', part='id').execute() - channel_id = search_results['items'][0]['id']['channelId'] + search_results = youtube.search().list(q=channel, type="channel", order="relevance", part="id").execute() + channel_id = search_results["items"][0]["id"]["channelId"] start_date, end_date = time_range - start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') - end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') + start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") + end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") # get videos - time_range_videos = youtube.search().list( - part='snippet', channelId=channel_id, order='date', type='video', - publishedAfter=start_date, - publishedBefore=end_date - ).execute() + time_range_videos = ( + youtube.search() + .list( + part="snippet", + channelId=channel_id, + order="date", + type="video", + publishedAfter=start_date, + publishedBefore=end_date, + ) + .execute() + ) def extract_video_data(video_list): data = [] - for video in video_list['items']: - video_id = video['id']['videoId'] - video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute() - title = video_info['items'][0]['snippet']['title'] - views = video_info['items'][0]['statistics']['viewCount'] - data.append({'Title': title, 'Views': views}) + for video in video_list["items"]: + video_id = video["id"]["videoId"] + video_info = youtube.videos().list(part="snippet,statistics", id=video_id).execute() + title = video_info["items"][0]["snippet"]["title"] + views = video_info["items"][0]["statistics"]["viewCount"] + data.append({"Title": title, "Views": views}) return data summary = extract_video_data(time_range_videos) - + return self.create_text_message(str(summary)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 83a4fccb32..aad876491c 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -11,7 +11,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "channel": "TOKYO GIRLS COLLECTION", "start_date": "2020-01-01", @@ -20,4 +20,3 @@ class YahooFinanceProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 8dd543b00a..2a966d3999 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -20,22 +20,22 @@ class BuiltinToolProviderController(ToolProviderController): tools: list[BuiltinTool] = Field(default_factory=list) def __init__(self, **data: Any) -> None: - if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: + if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: super().__init__(**data) return - + # load provider yaml - provider = self.__class__.__module__.split('.')[-1] - yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') + provider = self.__class__.__module__.split(".")[-1] + yaml_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, f"{provider}.yaml") try: provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: - raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') + raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") - if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: + if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None: # set credentials name - for credential_name in provider_yaml['credentials_for_provider']: - provider_yaml['credentials_for_provider'][credential_name]['name'] = credential_name + for credential_name in provider_yaml["credentials_for_provider"]: + provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name super().__init__(**{ 'identity': provider_yaml['identity'], @@ -44,13 +44,13 @@ class BuiltinToolProviderController(ToolProviderController): def _get_builtin_tools(self) -> list[BuiltinTool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ if self.tools: return self.tools - + provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") # get all the yaml files in the tool path @@ -63,10 +63,12 @@ class BuiltinToolProviderController(ToolProviderController): # get tool class, import the module assistant_tool_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'builtin', provider, 'tools', f'{tool_name}.py'), - parent_type=BuiltinTool) + module_name=f"core.tools.provider.builtin.{provider}.tools.{tool_name}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "builtin", provider, "tools", f"{tool_name}.py" + ), + parent_type=BuiltinTool, + ) tool["identity"]["provider"] = provider tools.append(assistant_tool_class(**tool)) @@ -75,70 +77,69 @@ class BuiltinToolProviderController(ToolProviderController): def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ if not self.credentials_schema: return {} - + return self.credentials_schema.copy() def get_tools(self) -> list[BuiltinTool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ return self._get_builtin_tools() def get_tool(self, tool_name: str) -> BuiltinTool | None: """ - returns the tool that the provider can provide + returns the tool that the provider can provide """ return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) @property def need_credentials(self) -> bool: """ - returns whether the provider needs credentials + returns whether the provider needs credentials - :return: whether the provider needs credentials + :return: whether the provider needs credentials """ - return self.credentials_schema is not None and \ - len(self.credentials_schema) != 0 + return self.credentials_schema is not None and len(self.credentials_schema) != 0 @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN @property def tool_labels(self) -> list[str]: """ - returns the labels of the provider + returns the labels of the provider - :return: labels of the provider + :return: labels of the provider """ label_enums = self._get_tool_labels() return [default_tool_label_dict[label].name for label in label_enums] def _get_tool_labels(self) -> list[ToolLabelEnum]: """ - returns the labels of the provider + returns the labels of the provider """ return self.identity.tags or [] def validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ # validate credentials format self.validate_credentials_format(credentials) @@ -149,9 +150,9 @@ class BuiltinToolProviderController(ToolProviderController): @abstractmethod def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ pass diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index 057f3060ed..bfdc161af6 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -21,81 +21,85 @@ class ToolProviderController(BaseModel, ABC): def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ return self.credentials_schema.copy() - + @abstractmethod def get_tool(self, tool_name: str) -> Tool: """ - returns a tool that the provider can provide + returns a tool that the provider can provide - :return: tool + :return: tool """ pass @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ - validate the format of the credentials of the provider and set the default value if needed + validate the format of the credentials of the provider and set the default value if needed - :param credentials: the credentials of the tool + :param credentials: the credentials of the tool """ credentials_schema = self.credentials_schema if credentials_schema is None: return - + credentials_need_to_validate: dict[str, ProviderConfig] = {} for credential_name in credentials_schema: credentials_need_to_validate[credential_name] = credentials_schema[credential_name] for credential_name in credentials: if credential_name not in credentials_need_to_validate: - raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}') - + raise ToolProviderCredentialValidationError( + f"credential {credential_name} not found in provider {self.identity.name}" + ) + # check type credential_schema = credentials_need_to_validate[credential_name] - if credential_schema == ProviderConfig.Type.SECRET_INPUT or \ - credential_schema == ProviderConfig.Type.TEXT_INPUT: + if credential_schema in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + elif credential_schema.type == ProviderConfig.Type.SELECT: if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + options = credential_schema.options if not isinstance(options, list): - raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") + if credentials[credential_name] not in [x.value for x in options]: - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}') - + raise ToolProviderCredentialValidationError( + f"credential {credential_name} should be one of {options}" + ) + credentials_need_to_validate.pop(credential_name) for credential_name in credentials_need_to_validate: credential_schema = credentials_need_to_validate[credential_name] if credential_schema.required: - raise ToolProviderCredentialValidationError(f'credential {credential_name} is required') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") + # the credential is not set currently, set the default value if needed if credential_schema.default is not None: default_value = credential_schema.default # parse default value into the correct type - if credential_schema.type == ProviderConfig.Type.SECRET_INPUT or \ - credential_schema.type == ProviderConfig.Type.TEXT_INPUT or \ - credential_schema.type == ProviderConfig.Type.SELECT: + if credential_schema.type in { + ProviderConfig.Type.SECRET_INPUT, + ProviderConfig.Type.TEXT_INPUT, + ProviderConfig.Type.SELECT, + }: default_value = str(default_value) credentials[credential_name] = default_value - \ No newline at end of file diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index 19bf3d9a86..6f80767bd5 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -34,29 +34,25 @@ class WorkflowToolProviderController(ToolProviderController): tools: list[WorkflowTool] = Field(default_factory=list) @classmethod - def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController': + def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": app = db_provider.app if not app: - raise ValueError('app not found') + raise ValueError("app not found") - controller = WorkflowToolProviderController(**{ - 'identity': { - 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', - 'name': db_provider.label, - 'label': { - 'en_US': db_provider.label, - 'zh_Hans': db_provider.label + controller = WorkflowToolProviderController( + **{ + "identity": { + "author": db_provider.user.name if db_provider.user_id and db_provider.user else "", + "name": db_provider.label, + "label": {"en_US": db_provider.label, "zh_Hans": db_provider.label}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description - }, - 'icon': db_provider.icon, - }, - 'credentials_schema': {}, - 'provider_id': db_provider.id or '', - }) + "credentials_schema": {}, + "provider_id": db_provider.id or "", + } + ) # init tools @@ -70,10 +66,10 @@ class WorkflowToolProviderController(ToolProviderController): def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ - get db provider tool - :param db_provider: the db provider - :param app: the app - :return: the tool + get db provider tool + :param db_provider: the db provider + :param app: the app + :return: the tool """ workflow: Workflow | None = db.session.query(Workflow).filter( Workflow.app_id == db_provider.app_id, @@ -81,7 +77,7 @@ class WorkflowToolProviderController(ToolProviderController): ).first() if not workflow: - raise ValueError('workflow not found') + raise ValueError("workflow not found") # fetch start node graph: Mapping = workflow.graph_dict @@ -106,51 +102,34 @@ class WorkflowToolProviderController(ToolProviderController): parameter_type = None options = [] if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: - raise ValueError(f'unsupported variable type {variable.type}') + raise ValueError(f"unsupported variable type {variable.type}") parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] if variable.type == VariableEntityType.SELECT and variable.options: options = [ - ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in variable.options + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in variable.options ] workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=variable.label, - zh_Hans=variable.label - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), + label=I18nObject(en_US=variable.label, zh_Hans=variable.label), + human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), type=parameter_type, form=parameter.form, llm_description=parameter.description, required=variable.required, options=options, - default=variable.default + default=variable.default, ) ) elif features.file_upload: workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=parameter.name, - zh_Hans=parameter.name - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), + label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), + human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), type=ToolParameter.ToolParameterType.FILE, llm_description=parameter.description, required=False, @@ -158,45 +137,39 @@ class WorkflowToolProviderController(ToolProviderController): ) ) else: - raise ValueError('variable not found') + raise ValueError("variable not found") return WorkflowTool( identity=ToolIdentity( - author=user.name if user else '', + author=user.name if user else "", name=db_provider.name, - label=I18nObject( - en_US=db_provider.label, - zh_Hans=db_provider.label - ), + label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), provider=self.provider_id, icon=db_provider.icon, ), description=ToolDescription( - human=I18nObject( - en_US=db_provider.description, - zh_Hans=db_provider.description - ), + human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), llm=db_provider.description, ), parameters=workflow_tool_parameters, is_team_authorization=True, workflow_app_id=app.id, workflow_entities={ - 'app': app, - 'workflow': workflow, + "app": app, + "workflow": workflow, }, version=db_provider.version, workflow_call_depth=0, - label=db_provider.label + label=db_provider.label, ) def get_tools(self, tenant_id: str) -> list[WorkflowTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools @@ -219,10 +192,10 @@ class WorkflowToolProviderController(ToolProviderController): def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 2c7064f97d..87f2514ce2 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -6,15 +6,15 @@ from urllib.parse import urlencode import httpx -import core.helper.ssrf_proxy as ssrf_proxy +from core.helper import ssrf_proxy from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool API_TOOL_DEFAULT_TIMEOUT = ( - int(getenv('API_TOOL_DEFAULT_CONNECT_TIMEOUT', '10')), - int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60')) + int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), + int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")), ) @@ -25,31 +25,32 @@ class ApiTool(Tool): Api tool """ - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=self.identity.model_copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.model_copy() if self.description else None, api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, - runtime=Tool.Runtime(**runtime) + runtime=Tool.Runtime(**runtime), ) - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], - format_only: bool = False) -> str: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str: """ - validate the credentials for Api tool + validate the credentials for Api tool """ - # assemble validate request and request parameters + # assemble validate request and request parameters headers = self.assembling_request(parameters) if format_only: - return '' + return "" response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) # validate response @@ -62,30 +63,30 @@ class ApiTool(Tool): headers = {} credentials = self.runtime.credentials or {} - if 'auth_type' not in credentials: - raise ToolProviderCredentialValidationError('Missing auth_type') + if "auth_type" not in credentials: + raise ToolProviderCredentialValidationError("Missing auth_type") - if credentials['auth_type'] == 'api_key': - api_key_header = 'api_key' + if credentials["auth_type"] == "api_key": + api_key_header = "api_key" - if 'api_key_header' in credentials: - api_key_header = credentials['api_key_header'] + if "api_key_header" in credentials: + api_key_header = credentials["api_key_header"] - if 'api_key_value' not in credentials: - raise ToolProviderCredentialValidationError('Missing api_key_value') - elif not isinstance(credentials['api_key_value'], str): - raise ToolProviderCredentialValidationError('api_key_value must be a string') + if "api_key_value" not in credentials: + raise ToolProviderCredentialValidationError("Missing api_key_value") + elif not isinstance(credentials["api_key_value"], str): + raise ToolProviderCredentialValidationError("api_key_value must be a string") - if 'api_key_header_prefix' in credentials: - api_key_header_prefix = credentials['api_key_header_prefix'] - if api_key_header_prefix == 'basic' and credentials['api_key_value']: - credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}' - elif api_key_header_prefix == 'bearer' and credentials['api_key_value']: - credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}' - elif api_key_header_prefix == 'custom': + if "api_key_header_prefix" in credentials: + api_key_header_prefix = credentials["api_key_header_prefix"] + if api_key_header_prefix == "basic" and credentials["api_key_value"]: + credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}' + elif api_key_header_prefix == "bearer" and credentials["api_key_value"]: + credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}' + elif api_key_header_prefix == "custom": pass - headers[api_key_header] = credentials['api_key_value'] + headers[api_key_header] = credentials["api_key_value"] needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] for parameter in needed_parameters: @@ -99,13 +100,13 @@ class ApiTool(Tool): def validate_and_parse_response(self, response: httpx.Response) -> str: """ - validate the response + validate the response """ if isinstance(response, httpx.Response): if response.status_code >= 400: raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") if not response.content: - return 'Empty response from the tool, please check your parameters and try again.' + return "Empty response from the tool, please check your parameters and try again." try: response = response.json() try: @@ -115,21 +116,22 @@ class ApiTool(Tool): except Exception as e: return response.text else: - raise ValueError(f'Invalid response type {type(response)}') + raise ValueError(f"Invalid response type {type(response)}") @staticmethod def get_parameter_value(parameter, parameters): - if parameter['name'] in parameters: - return parameters[parameter['name']] - elif parameter.get('required', False): + if parameter["name"] in parameters: + return parameters[parameter["name"]] + elif parameter.get("required", False): raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}") else: - return (parameter.get('schema', {}) or {}).get('default', '') + return (parameter.get("schema", {}) or {}).get("default", "") - def do_http_request(self, url: str, method: str, headers: dict[str, Any], - parameters: dict[str, Any]) -> httpx.Response: + def do_http_request( + self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] + ) -> httpx.Response: """ - do http request depending on api bundle + do http request depending on api bundle """ method = method.lower() @@ -139,29 +141,30 @@ class ApiTool(Tool): cookies = {} # check parameters - for parameter in self.api_bundle.openapi.get('parameters', []): + for parameter in self.api_bundle.openapi.get("parameters", []): value = self.get_parameter_value(parameter, parameters) - if parameter['in'] == 'path': - path_params[parameter['name']] = value + if parameter["in"] == "path": + path_params[parameter["name"]] = value - elif parameter['in'] == 'query': - if value !='': params[parameter['name']] = value + elif parameter["in"] == "query": + if value != "": + params[parameter["name"]] = value - elif parameter['in'] == 'cookie': - cookies[parameter['name']] = value + elif parameter["in"] == "cookie": + cookies[parameter["name"]] = value - elif parameter['in'] == 'header': - headers[parameter['name']] = value + elif parameter["in"] == "header": + headers[parameter["name"]] = value # check if there is a request body and handle it - if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None: + if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None: # handle json request body - if 'content' in self.api_bundle.openapi['requestBody']: - for content_type in self.api_bundle.openapi['requestBody']['content']: - headers['Content-Type'] = content_type - body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema'] - required = body_schema.get('required', []) - properties = body_schema.get('properties', {}) + if "content" in self.api_bundle.openapi["requestBody"]: + for content_type in self.api_bundle.openapi["requestBody"]["content"]: + headers["Content-Type"] = content_type + body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) for name, property in properties.items(): if name in parameters: # convert type @@ -170,63 +173,71 @@ class ApiTool(Tool): raise ToolParameterValidationError( f"Missing required parameter {name} in operation {self.api_bundle.operation_id}" ) - elif 'default' in property: - body[name] = property['default'] + elif "default" in property: + body[name] = property["default"] else: body[name] = None break # replace path parameters for name, value in path_params.items(): - url = url.replace(f'{{{name}}}', f'{value}') + url = url.replace(f"{{{name}}}", f"{value}") # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored - if 'Content-Type' in headers: - if headers['Content-Type'] == 'application/json': + if "Content-Type" in headers: + if headers["Content-Type"] == "application/json": body = json.dumps(body) - elif headers['Content-Type'] == 'application/x-www-form-urlencoded': + elif headers["Content-Type"] == "application/x-www-form-urlencoded": body = urlencode(body) else: body = body - if method in ('get', 'head', 'post', 'put', 'delete', 'patch'): - response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body, - timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) + if method in {"get", "head", "post", "put", "delete", "patch"}: + response = getattr(ssrf_proxy, method)( + url, + params=params, + headers=headers, + cookies=cookies, + data=body, + timeout=API_TOOL_DEFAULT_TIMEOUT, + follow_redirects=True, + ) return response else: - raise ValueError(f'Invalid http method {self.method}') + raise ValueError(f"Invalid http method {self.method}") - def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], - max_recursive=10) -> Any: + def _convert_body_property_any_of( + self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 + ) -> Any: if max_recursive <= 0: raise Exception("Max recursion depth reached") for option in any_of or []: try: - if 'type' in option: + if "type" in option: # Attempt to convert the value based on the type. - if option['type'] == 'integer' or option['type'] == 'int': + if option["type"] == "integer" or option["type"] == "int": return int(value) - elif option['type'] == 'number': - if '.' in str(value): + elif option["type"] == "number": + if "." in str(value): return float(value) else: return int(value) - elif option['type'] == 'string': + elif option["type"] == "string": return str(value) - elif option['type'] == 'boolean': - if str(value).lower() in ['true', '1']: + elif option["type"] == "boolean": + if str(value).lower() in {"true", "1"}: return True - elif str(value).lower() in ['false', '0']: + elif str(value).lower() in {"false", "0"}: return False else: continue # Not a boolean, try next option - elif option['type'] == 'null' and not value: + elif option["type"] == "null" and not value: return None else: continue # Unsupported type, try next option - elif 'anyOf' in option and isinstance(option['anyOf'], list): + elif "anyOf" in option and isinstance(option["anyOf"], list): # Recursive call to handle nested anyOf - return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1) + return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1) except ValueError: continue # Conversion failed, try next option # If no option succeeded, you might want to return the value as is or raise an error @@ -234,23 +245,23 @@ class ApiTool(Tool): def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any: try: - if 'type' in property: - if property['type'] == 'integer' or property['type'] == 'int': + if "type" in property: + if property["type"] == "integer" or property["type"] == "int": return int(value) - elif property['type'] == 'number': + elif property["type"] == "number": # check if it is a float - if '.' in str(value): + if "." in str(value): return float(value) else: return int(value) - elif property['type'] == 'string': + elif property["type"] == "string": return str(value) - elif property['type'] == 'boolean': + elif property["type"] == "boolean": return bool(value) - elif property['type'] == 'null': + elif property["type"] == "null": if value is None: return None - elif property['type'] == 'object' or property['type'] == 'array': + elif property["type"] == "object" or property["type"] == "array": if isinstance(value, str): try: # an array str like '[1,2]' also can convert to list [1,2] through json.loads @@ -265,8 +276,8 @@ class ApiTool(Tool): return value else: raise ValueError(f"Invalid type {property['type']} for property {property}") - elif 'anyOf' in property and isinstance(property['anyOf'], list): - return self._convert_body_property_any_of(property, value, property['anyOf']) + elif "anyOf" in property and isinstance(property["anyOf"], list): + return self._convert_body_property_any_of(property, value, property["anyOf"]) except ValueError as e: return value diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index ad7a88838b..8edaf7c0e6 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -1,4 +1,3 @@ - from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.entities.tool_entities import ToolProviderType @@ -16,40 +15,38 @@ Please summarize the text you got. class BuiltinTool(Tool): """ - Builtin tool + Builtin tool - :param meta: the meta data of a tool call processing + :param meta: the meta data of a tool call processing """ - def invoke_model( - self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str] - ) -> LLMResult: + def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult: """ - invoke model + invoke model - :param model_config: the model config - :param prompt_messages: the prompt messages - :param stop: the stop words - :return: the model result + :param model_config: the model config + :param prompt_messages: the prompt messages + :param stop: the stop words + :return: the model result """ # invoke model return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id, - tool_type='builtin', + tool_type="builtin", tool_name=self.identity.name, prompt_messages=prompt_messages, ) - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.BUILT_IN - + def get_max_tokens(self) -> int: """ - get max tokens + get max tokens - :param model_config: the model config - :return: the max tokens + :param model_config: the model config + :return: the max tokens """ return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id, @@ -57,39 +54,34 @@ class BuiltinTool(Tool): def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: """ - get prompt tokens + get prompt tokens - :param prompt_messages: the prompt messages - :return: the tokens + :param prompt_messages: the prompt messages + :return: the tokens """ - return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id, - prompt_messages=prompt_messages - ) + return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages) def summary(self, user_id: str, content: str) -> str: max_tokens = self.get_max_tokens() - if self.get_prompt_tokens(prompt_messages=[ - UserPromptMessage(content=content) - ]) < max_tokens * 0.6: + if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6: return content - + def get_prompt_tokens(content: str) -> int: - return self.get_prompt_tokens(prompt_messages=[ - SystemPromptMessage(content=_SUMMARY_PROMPT), - UserPromptMessage(content=content) - ]) - + return self.get_prompt_tokens( + prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)] + ) + def summarize(content: str) -> str: - summary = self.invoke_model(user_id=user_id, prompt_messages=[ - SystemPromptMessage(content=_SUMMARY_PROMPT), - UserPromptMessage(content=content) - ], stop=[]) + summary = self.invoke_model( + user_id=user_id, + prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)], + stop=[], + ) return summary.message.content - lines = content.split('\n') + lines = content.split("\n") new_lines = [] # split long line into multiple lines for i in range(len(lines)): @@ -100,8 +92,8 @@ class BuiltinTool(Tool): new_lines.append(line) elif get_prompt_tokens(line) > max_tokens * 0.7: while get_prompt_tokens(line) > max_tokens * 0.7: - new_lines.append(line[:int(max_tokens * 0.5)]) - line = line[int(max_tokens * 0.5):] + new_lines.append(line[: int(max_tokens * 0.5)]) + line = line[int(max_tokens * 0.5) :] new_lines.append(line) else: new_lines.append(line) @@ -125,17 +117,15 @@ class BuiltinTool(Tool): summary = summarize(message) summaries.append(summary) - result = '\n'.join(summaries) + result = "\n".join(summaries) - if self.get_prompt_tokens(prompt_messages=[ - UserPromptMessage(content=result) - ]) > max_tokens * 0.7: + if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7: return self.summary(user_id=user_id, content=result) - + return result - + def get_url(self, url: str, user_agent: str = None) -> str: """ - get url + get url """ - return get_url(url, user_agent=user_agent) \ No newline at end of file + return get_url(url, user_agent=user_agent) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index d6ecc9257b..6073b8e92e 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -14,14 +14,11 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -31,6 +28,7 @@ class DatasetMultiRetrieverToolInput(BaseModel): class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying multi dataset.""" + name: str = "dataset_" args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput description: str = "dataset multi retriever and rerank. " @@ -38,27 +36,26 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): reranking_provider_name: str reranking_model_name: str - @classmethod def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): return cls( - name=f"dataset_{tenant_id.replace('-', '_')}", - tenant_id=tenant_id, - dataset_ids=dataset_ids, - **kwargs + name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs ) def _run(self, query: str) -> str: threads = [] all_documents = [] for dataset_id in self.dataset_ids: - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'all_documents': all_documents, - 'hit_callbacks': self.hit_callbacks - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "all_documents": all_documents, + "hit_callbacks": self.hit_callbacks, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -69,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): tenant_id=self.tenant_id, provider=self.reranking_provider_name, model_type=ModelType.RERANK, - model=self.reranking_model_name + model=self.reranking_model_name, ) rerank_runner = RerankModelRunner(rerank_model_instance) @@ -80,62 +77,61 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): document_score_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 @@ -144,13 +140,18 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): return str("\n".join(document_context_list)) - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, - hit_callbacks: list[DatasetIndexToolCallbackHandler]): + def _retriever( + self, + flask_app: Flask, + dataset_id: str, + query: str, + all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler], + ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() + ) if not dataset: return [] @@ -159,31 +160,31 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): hit_callback.on_query(query, dataset.id) # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) if documents: all_documents.extend(documents) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) - all_documents.extend(documents) \ No newline at end of file + all_documents.extend(documents) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py index 62e97a0230..dad8c77357 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -9,6 +9,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa class DatasetRetrieverBaseTool(BaseModel, ABC): """Tool for querying a Dataset.""" + name: str = "dataset" description: str = "use this to retrieve a dataset. " tenant_id: str diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 220e4baa85..8dc60408c9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService @@ -8,15 +7,12 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'reranking_mode': 'reranking_model', - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "reranking_mode": "reranking_model", + "top_k": 2, + "score_threshold_enabled": False, } @@ -26,64 +22,63 @@ class DatasetRetrieverToolInput(BaseModel): class DatasetRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying a Dataset.""" + name: str = "dataset" args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " dataset_id: str - @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") return cls( name=f"dataset_{dataset.id.replace('-', '_')}", tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, - **kwargs + **kwargs, ) def _run(self, query: str) -> str: - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == self.dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + ) if not dataset: - return '' + return "" for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) return str("\n".join([document.page_content for document in documents])) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) else: documents = [] @@ -92,25 +87,26 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in documents] - segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) - ).all() + index_node_ids = [document.metadata["doc_id"] for document in documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: @@ -118,36 +114,36 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): resource_number = 1 for segment in sorted_segments: context = {} - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) - + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join(document_context_list)) \ No newline at end of file + return str("\n".join(document_context_list)) diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index e1f53c0338..9f41b5d5eb 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -21,13 +21,14 @@ class DatasetRetrieverTool(Tool): retrieval_tool: DatasetRetrieverBaseTool @staticmethod - def get_dataset_tools(tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler - ) -> list['DatasetRetrieverTool']: + def get_dataset_tools( + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> list["DatasetRetrieverTool"]: """ get dataset tool """ @@ -49,7 +50,7 @@ class DatasetRetrieverTool(Tool): retrieve_config=retrieve_config, return_resource=return_resource, invoke_from=invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) if retrieval_tools is None or len(retrieval_tools) == 0: return [] @@ -62,13 +63,13 @@ class DatasetRetrieverTool(Tool): for retrieval_tool in retrieval_tools: tool = DatasetRetrieverTool( retrieval_tool=retrieval_tool, - identity=ToolIdentity(provider='', author='', name=retrieval_tool.name, label=I18nObject(en_US='', zh_Hans='')), + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), parameters=[], is_team_authorization=True, - description=ToolDescription( - human=I18nObject(en_US='', zh_Hans=''), - llm=retrieval_tool.description), - runtime=DatasetRetrieverTool.Runtime() + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), + runtime=DatasetRetrieverTool.Runtime(), ) tools.append(tool) @@ -77,16 +78,18 @@ class DatasetRetrieverTool(Tool): def get_runtime_parameters(self) -> list[ToolParameter]: return [ - ToolParameter(name='query', - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Query for the dataset to be used to retrieve the dataset.', - required=True, - default=''), + ToolParameter( + name="query", + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Query for the dataset to be used to retrieve the dataset.", + required=True, + default="", + ), ] - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.DATASET_RETRIEVAL @@ -94,7 +97,7 @@ class DatasetRetrieverTool(Tool): """ invoke dataset retriever tool """ - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: yield self.create_text_message(text='please input query') else: diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 6f21afdb35..49f9bf68ea 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -35,7 +35,7 @@ class Tool(BaseModel, ABC): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - @field_validator('parameters', mode='before') + @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: return v or [] @@ -63,10 +63,10 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - class VARIABLE_KEY(Enum): - IMAGE = 'image' + class VariableKey(Enum): + IMAGE = "image" - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ fork a new tool with meta data @@ -142,7 +142,7 @@ class Tool(BaseModel, ABC): if not self.variables: return None - return self.get_variable(self.VARIABLE_KEY.IMAGE) + return self.get_variable(self.VariableKey.IMAGE) def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ @@ -189,7 +189,7 @@ class Tool(BaseModel, ABC): result = [] for variable in self.variables.pool: - if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): + if variable.name.startswith(self.VariableKey.IMAGE.value): result.append(variable) return result @@ -207,12 +207,16 @@ class Tool(BaseModel, ABC): ) if isinstance(result, ToolInvokeMessage): + def single_generator(): yield result + return single_generator() elif isinstance(result, list): + def generator(): yield from result + return generator() else: return result @@ -232,7 +236,9 @@ class Tool(BaseModel, ABC): return result @abstractmethod - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: pass def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: @@ -288,37 +294,34 @@ class Tool(BaseModel, ABC): return parameters - def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: """ create an image message :param image: the url of the image :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=ToolInvokeMessage.TextMessage(text=image), - save_as=save_as) + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as + ) def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, - message=None, - meta={ - 'file_var': file_var - }, - save_as='') + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.FILE_VAR, message=None, meta={"file_var": file_var}, save_as="" + ) - def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: + def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: """ create a link message :param link: the url of the link :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text=link), - save_as=save_as) + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as + ) - def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: + def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage: """ create a text message @@ -326,12 +329,10 @@ class Tool(BaseModel, ABC): :return: the text message """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=ToolInvokeMessage.TextMessage(text=text), - save_as=save_as + type=ToolInvokeMessage.MessageType.TEXT, message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as ) - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = '') -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage: """ create a blob message @@ -339,10 +340,10 @@ class Tool(BaseModel, ABC): :return: the blob message """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=ToolInvokeMessage.BlobMessage(blob=blob), - meta=meta, - save_as=save_as + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.BlobMessage(blob=blob), + meta=meta, + save_as=save_as, ) def create_json_message(self, object: dict) -> ToolInvokeMessage: @@ -350,6 +351,5 @@ class Tool(BaseModel, ABC): create a json message """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, - message=ToolInvokeMessage.JsonMessage(json_object=object) + type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object) ) diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index c2c204b508..42ceffb834 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -14,6 +14,7 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) + class WorkflowTool(Tool): workflow_app_id: str version: str @@ -26,19 +27,18 @@ class WorkflowTool(Tool): """ Workflow tool. """ + def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ return ToolProviderType.WORKFLOW - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Generator[ToolInvokeMessage, None, None]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: """ - invoke the tool + invoke the tool """ app = self._get_app(app_id=self.workflow_app_id) workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) @@ -47,40 +47,42 @@ class WorkflowTool(Tool): tool_parameters, files = self._transform_args(tool_parameters) from core.app.apps.workflow.app_generator import WorkflowAppGenerator + generator = WorkflowAppGenerator() - assert self.runtime and self.runtime.invoke_from + assert self.runtime + assert self.runtime.invoke_from result = generator.generate( - app_model=app, - workflow=workflow, - user=self._get_user(user_id), - args={ - 'inputs': tool_parameters, - 'files': files - }, + app_model=app, + workflow=workflow, + user=self._get_user(user_id), + args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, - workflow_thread_pool_id=self.thread_pool_id + workflow_thread_pool_id=self.thread_pool_id, ) - data = result.get('data', {}) + data = result.get("data", {}) - if data.get('error'): - raise Exception(data.get('error')) - - outputs = data.get('outputs', {}) + if data.get("error"): + raise Exception(data.get("error")) + + if data.get("error"): + raise Exception(data.get("error")) + + outputs = data.get("outputs", {}) outputs, files = self._extract_files(outputs) for file in files: yield self.create_file_var_message(file) - + yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ - get the user by user id + get the user by user id """ user = db.session.query(EndUser).filter(EndUser.id == user_id).first() @@ -88,16 +90,16 @@ class WorkflowTool(Tool): user = db.session.query(Account).filter(Account.id == user_id).first() if not user: - raise ValueError('user not found') + raise ValueError("user not found") return user - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=deepcopy(self.identity), @@ -108,45 +110,44 @@ class WorkflowTool(Tool): workflow_entities=self.workflow_entities, workflow_call_depth=self.workflow_call_depth, version=self.version, - label=self.label + label=self.label, ) - + def _get_workflow(self, app_id: str, version: str) -> Workflow: """ - get the workflow by app id and version + get the workflow by app id and version """ if not version: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version != 'draft' - ).order_by(Workflow.created_at.desc()).first() + workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == app_id, Workflow.version != "draft") + .order_by(Workflow.created_at.desc()) + .first() + ) else: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version == version - ).first() + workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() if not workflow: - raise ValueError('workflow not found or not published') + raise ValueError("workflow not found or not published") return workflow - + def _get_app(self, app_id: str) -> App: """ - get the app by app id + get the app by app id """ app = db.session.query(App).filter(App.id == app_id).first() if not app: - raise ValueError('app not found') + raise ValueError("app not found") return app - + def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ - transform the tool parameters + transform the tool parameters - :param tool_parameters: the tool parameters - :return: tool_parameters, files + :param tool_parameters: the tool parameters + :return: tool_parameters, files """ parameter_rules = self.get_all_runtime_parameters() parameters_result = {} @@ -159,15 +160,15 @@ class WorkflowTool(Tool): file_var_list = [FileVar(**f) for f in file] for file_var in file_var_list: file_dict: dict[str, Any] = { - 'transfer_method': file_var.transfer_method.value, - 'type': file_var.type.value, + "transfer_method": file_var.transfer_method.value, + "type": file_var.type.value, } if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict['tool_file_id'] = file_var.related_id + file_dict["tool_file_id"] = file_var.related_id elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict['upload_file_id'] = file_var.related_id + file_dict["upload_file_id"] = file_var.related_id elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict['url'] = file_var.preview_url + file_dict["url"] = file_var.preview_url files.append(file_dict) except Exception as e: @@ -176,13 +177,13 @@ class WorkflowTool(Tool): parameters_result[parameter.name] = tool_parameters.get(parameter.name) return parameters_result, files - + def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: """ - extract files from the result + extract files from the result - :param result: the result - :return: the result, files + :param result: the result + :return: the result, files """ files = [] result = {} @@ -190,7 +191,7 @@ class WorkflowTool(Tool): if isinstance(value, list): has_file = False for item in value: - if isinstance(item, dict) and item.get('__variant') == 'FileVar': + if isinstance(item, dict) and item.get("__variant") == "FileVar": try: files.append(FileVar(**item)) has_file = True @@ -201,4 +202,4 @@ class WorkflowTool(Tool): result[key] = value - return result, files \ No newline at end of file + return result, files diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index c5c50db8d0..8a4be51d28 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -34,12 +34,17 @@ class ToolEngine: """ Tool runtime engine take care of the tool executions. """ + @staticmethod def agent_invoke( - tool: Tool, tool_parameters: Union[str, dict], - user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, + tool: Tool, + tool_parameters: Union[str, dict], + user_id: str, + tenant_id: str, + message: Message, + invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. @@ -48,31 +53,29 @@ class ToolEngine: if isinstance(tool_parameters, str): # check if this tool has only one parameter parameters = [ - parameter for parameter in tool.get_runtime_parameters() or [] + parameter + for parameter in tool.get_runtime_parameters() or [] if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: - tool_parameters = { - parameters[0].name: tool_parameters - } + tool_parameters = {parameters[0].name: tool_parameters} else: raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool try: # hit the callback handler - agent_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) messages = ToolEngine._invoke(tool, tool_parameters, user_id) invocation_meta_dict: dict[str, ToolInvokeMeta] = {} - def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]): + def message_callback( + invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None] + ): for message in messages: if isinstance(message, ToolInvokeMeta): - invocation_meta_dict['meta'] = message + invocation_meta_dict["meta"] = message else: yield message @@ -80,22 +83,19 @@ class ToolEngine: messages=message_callback(invocation_meta_dict, messages), user_id=user_id, tenant_id=tenant_id, - conversation_id=message.conversation_id + conversation_id=message.conversation_id, ) # extract binary data from tool invoke message binary_files = ToolEngine._extract_tool_response_binary(messages) # create message file message_files = ToolEngine._create_message_files( - tool_messages=binary_files, - agent_message=message, - invoke_from=invoke_from, - user_id=user_id + tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id ) plain_text = ToolEngine._convert_tool_response_to_str(messages) - meta = invocation_meta_dict['meta'] + meta = invocation_meta_dict["meta"] # hit the callback handler agent_tool_callback.on_tool_end( @@ -103,7 +103,7 @@ class ToolEngine: tool_inputs=tool_parameters, tool_outputs=plain_text, message_id=message.id, - trace_manager=trace_manager + trace_manager=trace_manager, ) # transform tool invoke message to get LLM friendly message @@ -111,14 +111,10 @@ class ToolEngine: except ToolProviderCredentialValidationError as e: error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) - except ( - ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError - ) as e: + except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: error_response = f"there is not a tool named {tool.identity.name}" agent_tool_callback.on_tool_error(e) - except ( - ToolParameterValidationError - ) as e: + except ToolParameterValidationError as e: error_response = f"tool parameters validation error: {e}, please check your tool parameters" agent_tool_callback.on_tool_error(e) except ToolInvokeError as e: @@ -136,21 +132,20 @@ class ToolEngine: return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod - def workflow_invoke(tool: Tool, tool_parameters: dict[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - workflow_call_depth: int, - thread_pool_id: Optional[str] = None - ) -> Generator[ToolInvokeMessage, None, None]: + def workflow_invoke( + tool: Tool, + tool_parameters: dict[str, Any], + user_id: str, + workflow_tool_callback: DifyWorkflowCallbackHandler, + workflow_call_depth: int, + thread_pool_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: """ Workflow invokes the tool with the given arguments. """ try: # hit the callback handler - workflow_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 @@ -172,20 +167,17 @@ class ToolEngine: except Exception as e: workflow_tool_callback.on_tool_error(e) raise e - + @staticmethod - def plugin_invoke(tool: Tool, tool_parameters: dict, user_id: str, - callback: DifyPluginCallbackHandler - ) -> Generator[ToolInvokeMessage, None, None]: + def plugin_invoke( + tool: Tool, tool_parameters: dict, user_id: str, callback: DifyPluginCallbackHandler + ) -> Generator[ToolInvokeMessage, None, None]: """ Plugin invokes the tool with the given arguments. """ try: # hit the callback handler - callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) response = tool.invoke(user_id, tool_parameters) @@ -200,10 +192,11 @@ class ToolEngine: except Exception as e: callback.on_tool_error(e) raise e - + @staticmethod - def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ - -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: + def _invoke( + tool: Tool, tool_parameters: dict, user_id: str + ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: """ Invoke the tool with the given arguments. """ @@ -211,13 +204,17 @@ class ToolEngine: raise ValueError("missing runtime in tool") started_at = datetime.now(timezone.utc) - meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={ - 'tool_name': tool.identity.name, - 'tool_provider': tool.identity.provider, - 'tool_provider_type': tool.tool_provider_type().value, - 'tool_parameters': deepcopy(tool.runtime.runtime_parameters), - 'tool_icon': tool.identity.icon - }) + meta = ToolInvokeMeta( + time_cost=0.0, + error=None, + tool_config={ + "tool_name": tool.identity.name, + "tool_provider": tool.identity.provider, + "tool_provider_type": tool.tool_provider_type().value, + "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_icon": tool.identity.icon, + }, + ) try: yield from tool.invoke(user_id, tool_parameters) except Exception as e: @@ -233,77 +230,82 @@ class ToolEngine: """ Handle tool response """ - result = '' + result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: result += cast(ToolInvokeMessage.TextMessage, response.message).text elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}. please tell user to check it." - elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: - result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." + result += ( + f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}." + + " please tell user to check it." + ) + elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + result += ( + "image has been created and sent to user already, " + + "you do not need to create it, just tell the user to check it now." + ) elif response.type == ToolInvokeMessage.MessageType.JSON: - result += f"tool response: {json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)}." + result += f"tool response: { + json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False) + }." else: result += f"tool response: {response.message}." return result - + @staticmethod - def _extract_tool_response_binary(tool_response: Generator[ToolInvokeMessage, None, None]) -> Generator[ToolInvokeMessageBinary, None, None]: + def _extract_tool_response_binary( + tool_response: Generator[ToolInvokeMessage, None, None], + ) -> Generator[ToolInvokeMessageBinary, None, None]: """ Extract tool response binary """ for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: mimetype = None if not response.meta: raise ValueError("missing meta data") - if response.meta.get('mime_type'): - mimetype = response.meta.get('mime_type') + if response.meta.get("mime_type"): + mimetype = response.meta.get("mime_type") else: try: url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) extension = url.suffix - guess_type_result, _ = guess_type(f'a{extension}') + guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: mimetype = guess_type_result except Exception: pass - + if not mimetype: - mimetype = 'image/jpeg' - + mimetype = "image/jpeg" + yield ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'image/jpeg'), + mimetype=response.meta.get("mime_type", "image/jpeg"), url=cast(ToolInvokeMessage.TextMessage, response.message).text, save_as=response.save_as, ) elif response.type == ToolInvokeMessage.MessageType.BLOB: if not response.meta: raise ValueError("missing meta data") - + yield ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), + mimetype=response.meta.get("mime_type", "octet/stream"), url=cast(ToolInvokeMessage.TextMessage, response.message).text, save_as=response.save_as, ) elif response.type == ToolInvokeMessage.MessageType.LINK: # check if there is a mime type in meta - if response.meta and 'mime_type' in response.meta: + if response.meta and "mime_type" in response.meta: yield ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', + mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream", url=cast(ToolInvokeMessage.TextMessage, response.message).text, save_as=response.save_as, ) - + @staticmethod def _create_message_files( - tool_messages: Iterable[ToolInvokeMessageBinary], - agent_message: Message, - invoke_from: InvokeFrom, - user_id: str + tool_messages: Iterable[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str ) -> list[tuple[MessageFile, str]]: """ Create message file @@ -314,29 +316,29 @@ class ToolEngine: result = [] for message in tool_messages: - file_type = 'bin' - if 'image' in message.mimetype: - file_type = 'image' - elif 'video' in message.mimetype: - file_type = 'video' - elif 'audio' in message.mimetype: - file_type = 'audio' - elif 'text' in message.mimetype: - file_type = 'text' - elif 'pdf' in message.mimetype: - file_type = 'pdf' - elif 'zip' in message.mimetype: - file_type = 'archive' + file_type = "bin" + if "image" in message.mimetype: + file_type = "image" + elif "video" in message.mimetype: + file_type = "video" + elif "audio" in message.mimetype: + file_type = "audio" + elif "text" in message.mimetype: + file_type = "text" + elif "pdf" in message.mimetype: + file_type = "pdf" + elif "zip" in message.mimetype: + file_type = "archive" # ... message_file = MessageFile( message_id=agent_message.id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE.value, - belongs_to='assistant', + belongs_to="assistant", url=message.url, upload_file_id=None, - created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), + created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), created_by=user_id, ) @@ -344,11 +346,8 @@ class ToolEngine: db.session.commit() db.session.refresh(message_file) - result.append(( - message_file.id, - message.save_as - )) + result.append((message_file.id, message.save_as)) db.session.close() - return result \ No newline at end of file + return result diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 078c58c662..f123e69c19 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -27,24 +27,24 @@ class ToolFileManager: sign file to get a temporary url """ base_url = dify_config.FILES_URL - file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}' + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @staticmethod def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature """ - data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -62,9 +62,9 @@ class ToolFileManager: """ create file """ - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, file_binary) tool_file = ToolFile( @@ -90,10 +90,10 @@ class ToolFileManager: response = get(file_url) response.raise_for_status() blob = response.content - mimetype = guess_type(file_url)[0] or 'octet/stream' - extension = guess_extension(mimetype) or '.bin' + mimetype = guess_type(file_url)[0] or "octet/stream" + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, blob) tool_file = ToolFile( @@ -166,13 +166,12 @@ class ToolFileManager: # Check if message_file is not None if message_file is not None: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] else: tool_file_id = None - tool_file: ToolFile | None = ( db.session.query(ToolFile) .filter( @@ -216,4 +215,4 @@ class ToolFileManager: # init tool_file_parser from core.file.tool_file_parser import tool_file_manager -tool_file_manager['manager'] = ToolFileManager +tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 97788a7a07..2a5a2944ef 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -15,7 +15,7 @@ class ToolLabelManager: """ tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] return list(set(tool_labels)) - + @classmethod def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): """ @@ -26,20 +26,20 @@ class ToolLabelManager: if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id == provider_id - ).delete() + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() # insert new labels for label in labels: - db.session.add(ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type.value, - label_name=label, - )) + db.session.add( + ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + label_name=label, + ) + ) db.session.commit() @@ -53,12 +53,16 @@ class ToolLabelManager: elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding.label_name) + .filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ) + .all() + ) return [label.label_name for label in labels] @@ -75,22 +79,20 @@ class ToolLabelManager: """ if not tool_providers: return {} - + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - raise ValueError('Unsupported tool type') - + raise ValueError("Unsupported tool type") + provider_ids = [controller.provider_id for controller in tool_providers] - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id.in_(provider_ids) - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + ) - tool_labels = { - label.tool_id: [] for label in labels - } + tool_labels = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) - return tool_labels \ No newline at end of file + return tool_labels diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 34a6bd0dd8..108b74018d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -42,29 +42,29 @@ class ToolManager: @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: """ - get the builtin provider + get the builtin provider - :param provider: the name of the provider - :return: the provider + :param provider: the name of the provider + :return: the provider """ if len(cls._builtin_providers) == 0: # init the builtin providers cls.load_builtin_providers_cache() if provider not in cls._builtin_providers: - raise ToolProviderNotFoundError(f'builtin provider {provider} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") return cls._builtin_providers[provider] @classmethod def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None: """ - get the builtin tool + get the builtin tool - :param provider: the name of the provider - :param tool_name: the name of the tool + :param provider: the name of the provider + :param tool_name: the name of the tool - :return: the provider, the tool + :return: the provider, the tool """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) @@ -72,21 +72,23 @@ class ToolManager: return tool @classmethod - def get_tool_runtime(cls, provider_type: ToolProviderType, - provider_id: str, - tool_name: str, - tenant_id: str, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool, WorkflowTool]: + def get_tool_runtime( + cls, + provider_type: ToolProviderType, + provider_id: str, + tool_name: str, + tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + ) -> Union[BuiltinTool, ApiTool, WorkflowTool]: """ - get the tool runtime + get the tool runtime - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ if provider_type == ToolProviderType.BUILT_IN: builtin_tool = cls.get_builtin_tool(provider_id, tool_name) @@ -96,91 +98,114 @@ class ToolManager: # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - })) + return cast( + BuiltinTool, + builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ), + ) # get credentials - builtin_provider: BuiltinToolProvider | None = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_id, - ).first() + builtin_provider: BuiltinToolProvider | None = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_id, + ) + .first() + ) if builtin_provider is None: - raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") # decrypt the credentials credentials = builtin_provider.credentials controller = cls.get_builtin_provider(provider_id) tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, + tenant_id=tenant_id, config=controller.get_credentials_schema(), provider_type=controller.provider_type.value, - provider_identity=controller.identity.name + provider_identity=controller.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) - return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'runtime_parameters': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - })) + return cast( + BuiltinTool, + builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "runtime_parameters": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ), + ) elif provider_type == ToolProviderType.API: if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) # decrypt the credentials tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, + tenant_id=tenant_id, config=api_provider.get_credentials_schema(), provider_type=api_provider.provider_type.value, - provider_identity=api_provider.identity.name + provider_identity=api_provider.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) - return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - })) + return cast( + ApiTool, + api_provider.get_tool(tool_name).fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ), + ) elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() - - if workflow_provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') - - controller = ToolTransformService.workflow_provider_to_controller( - db_provider=workflow_provider + workflow_provider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() ) - return cast(WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - })) + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + + return cast( + WorkflowTool, + controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ), + ) elif provider_type == ToolProviderType.APP: - raise NotImplementedError('app provider not implemented') + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type.value} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @classmethod def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: """ - init runtime parameter + init runtime parameter """ parameter_value = parameters.get(parameter_rule.name) if not parameter_value and parameter_value != 0: @@ -194,14 +219,17 @@ class ToolManager: options = [x.value for x in parameter_rule.options] if parameter_value is not None and parameter_value not in options: raise ValueError( - f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" + ) return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_agent_tool_runtime( + cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER + ) -> Tool: """ - get the agent tool runtime + get the agent tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, @@ -209,7 +237,7 @@ class ToolManager: tool_name=agent_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT + tool_invoke_from=ToolInvokeFrom.AGENT, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -229,20 +257,27 @@ class ToolManager: tool_runtime=tool_entity, provider_name=agent_tool.provider_id, provider_type=agent_tool.provider_type, - identity_id=f'AGENT.{app_id}' + identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - + if not tool_entity.runtime: raise Exception("tool missing runtime") - + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime( + cls, + tenant_id: str, + app_id: str, + node_id: str, + workflow_tool: "ToolEntity", + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + ) -> Tool: """ - get the workflow tool runtime + get the workflow tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, @@ -250,7 +285,7 @@ class ToolManager: tool_name=workflow_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW + tool_invoke_from=ToolInvokeFrom.WORKFLOW, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -267,7 +302,7 @@ class ToolManager: tool_runtime=tool_entity, provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, - identity_id=f'WORKFLOW.{app_id}.{node_id}' + identity_id=f"WORKFLOW.{app_id}.{node_id}", ) if runtime_parameters: @@ -275,31 +310,37 @@ class ToolManager: if not tool_entity.runtime: raise Exception("tool missing runtime") - + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @classmethod def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ - get the absolute path of the icon of the builtin provider + get the absolute path of the icon of the builtin provider - :param provider: the name of the provider + :param provider: the name of the provider - :return: the absolute path of the icon, the mime type of the icon + :return: the absolute path of the icon, the mime type of the icon """ # get provider provider_controller = cls.get_builtin_provider(provider) - absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', - provider_controller.identity.icon) + absolute_path = path.join( + path.dirname(path.realpath(__file__)), + "provider", + "builtin", + provider, + "_assets", + provider_controller.identity.icon, + ) # check if the icon exists if not path.exists(absolute_path): - raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") # get the mime type mime_type, _ = mimetypes.guess_type(absolute_path) - mime_type = mime_type or 'application/octet-stream' + mime_type = mime_type or "application/octet-stream" return absolute_path, mime_type @@ -320,23 +361,29 @@ class ToolManager: @classmethod def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ - list all the builtin providers + list all the builtin providers """ - for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): - if provider_path.startswith('__'): + for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin")): + if provider_path.startswith("__"): continue - if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider_path)): - if provider_path.startswith('__'): + if path.isdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin", provider_path)): + if provider_path.startswith("__"): continue # init provider try: provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider_path}.{provider_path}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider_path, f'{provider_path}.py'), - parent_type=BuiltinToolProviderController) + module_name=f"core.tools.provider.builtin.{provider_path}.{provider_path}", + script_path=path.join( + path.dirname(path.realpath(__file__)), + "provider", + "builtin", + provider_path, + f"{provider_path}.py", + ), + parent_type=BuiltinToolProviderController, + ) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider for tool in provider.get_tools(): @@ -344,7 +391,7 @@ class ToolManager: yield provider except Exception as e: - logger.error(f'load builtin provider {provider} error: {e}') + logger.error(f"load builtin provider {provider} error: {e}") continue # set builtin providers loaded cls._builtin_providers_loaded = True @@ -362,11 +409,11 @@ class ToolManager: @classmethod def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ - get the tool label + get the tool label - :param tool_name: the name of the tool + :param tool_name: the name of the tool - :return: the label of the tool + :return: the label of the tool """ if len(cls._builtin_tools_labels) == 0: # init the builtin providers @@ -378,75 +425,78 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] @classmethod - def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]: + def user_list_providers( + cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral + ) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} filters = [] if not typ: - filters.extend(['builtin', 'api', 'workflow']) + filters.extend(["builtin", "api", "workflow"]) else: filters.append(typ) - if 'builtin' in filters: - + if "builtin" in filters: # get builtin providers builtin_providers = cls.list_builtin_providers() # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ - filter(BuiltinToolProvider.tenant_id == tenant_id).all() + db_builtin_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + ) find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), - None + (x for x in db_builtin_providers if x.provider == provider), None ) # append builtin providers for provider in builtin_providers: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore - data=provider, - name_func=lambda x: x.identity.name + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + data=provider, + name_func=lambda x: x.identity.name, ): continue user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), - decrypt_credentials=False + decrypt_credentials=False, ) result_providers[provider.identity.name] = user_provider # get db api providers - if 'api' in filters: - db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ - filter(ApiToolProvider.tenant_id == tenant_id).all() + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + ) - api_provider_controllers = [{ - 'provider': provider, - 'controller': ToolTransformService.api_provider_to_controller(provider) - } for provider in db_api_providers] + api_provider_controllers = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] # get labels - labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers]) + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) for api_provider_controller in api_provider_controllers: user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller['controller'], - db_provider=api_provider_controller['provider'], + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], decrypt_credentials=False, - labels=labels.get(api_provider_controller['controller'].provider_id, []) + labels=labels.get(api_provider_controller["controller"].provider_id, []), ) - result_providers[f'api_provider.{user_provider.name}'] = user_provider + result_providers[f"api_provider.{user_provider.name}"] = user_provider - if 'workflow' in filters: + if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \ - filter(WorkflowToolProvider.tenant_id == tenant_id).all() + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) workflow_provider_controllers = [] for provider in workflow_providers: @@ -465,32 +515,36 @@ class ToolManager: provider_controller=provider_controller, labels=labels.get(provider_controller.provider_id, []), ) - result_providers[f'workflow_provider.{user_provider.name}'] = user_provider + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod - def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + def get_api_provider_controller( + cls, tenant_id: str, provider_id: str + ) -> tuple[ApiToolProviderController, dict[str, Any]]: """ - get the api provider + get the api provider - :param provider_name: the name of the provider + :param provider_name: the name of the provider - :return: the provider controller, the credentials + :return: the provider controller, the credentials """ - provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id, - ApiToolProvider.tenant_id == tenant_id, - ).first() + provider: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.id == provider_id, + ApiToolProvider.tenant_id == tenant_id, + ) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'api provider {provider_id} not found') + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") controller = ApiToolProviderController.from_db( provider, - ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) controller.load_bundled_tools(provider.tools) @@ -499,18 +553,22 @@ class ToolManager: @classmethod def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: """ - get api provider + get api provider """ """ get tool provider """ - provider_obj: ApiToolProvider| None = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider_obj: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider_obj is None: - raise ValueError(f'you have not added provider {provider}') + raise ValueError(f"you have not added provider {provider}") try: credentials = json.loads(provider_obj.credentials_str) or {} @@ -519,14 +577,15 @@ class ToolManager: # package tool provider controller controller = ApiToolProviderController.from_db( - provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + provider_obj, + ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, config=controller.get_credentials_schema(), provider_type=controller.provider_type.value, - provider_identity=controller.identity.name + provider_identity=controller.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) @@ -535,66 +594,66 @@ class ToolManager: try: icon = json.loads(provider_obj.icon) except: - icon = { - "background": "#252525", - "content": "\ud83d\ude01" - } + icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder({ - 'schema_type': provider_obj.schema_type, - 'schema': provider_obj.schema, - 'tools': provider_obj.tools, - 'icon': icon, - 'description': provider_obj.description, - 'credentials': masked_credentials, - 'privacy_policy': provider_obj.privacy_policy, - 'custom_disclaimer': provider_obj.custom_disclaimer, - 'labels': labels, - }) + return jsonable_encoder( + { + "schema_type": provider_obj.schema_type, + "schema": provider_obj.schema, + "tools": provider_obj.tools, + "icon": icon, + "description": provider_obj.description, + "credentials": masked_credentials, + "privacy_policy": provider_obj.privacy_policy, + "custom_disclaimer": provider_obj.custom_disclaimer, + "labels": labels, + } + ) @classmethod def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]: """ - get the tool icon + get the tool icon - :param tenant_id: the id of the tenant - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :return: + :param tenant_id: the id of the tenant + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :return: """ provider_type = provider_type provider_id = provider_id if provider_type == ToolProviderType.BUILT_IN: - return (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/" - + provider_id - + "/icon") + return ( + dify_config.CONSOLE_API_URL + + "/console/api/workspaces/current/tool-provider/builtin/" + + provider_id + + "/icon" + ) elif provider_type == ToolProviderType.API: try: - api_provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.id == provider_id - ).first() + api_provider: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .first() + ) if not api_provider: raise ValueError("api tool not found") - + return json.loads(api_provider.icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } + return {"background": "#252525", "content": "\ud83d\ude01"} elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() + workflow_provider: WorkflowToolProvider | None = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) if workflow_provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") return json.loads(workflow_provider.icon) else: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 2f4414c114..f3fce03f8c 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -104,10 +104,12 @@ class ProviderConfigEncrypter(BaseModel): ) cache.delete() + class ToolParameterConfigurationManager(BaseModel): """ Tool parameter configuration manager """ + tenant_id: str tool_runtime: Tool provider_name: str @@ -155,15 +157,19 @@ class ToolParameterConfigurationManager(BaseModel): current_parameters = self._merge_parameters() for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: if len(parameters[parameter.name]) > 6: - parameters[parameter.name] = \ - parameters[parameter.name][:2] + \ - '*' * (len(parameters[parameter.name]) - 4) + \ - parameters[parameter.name][-2:] + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) else: - parameters[parameter.name] = '*' * len(parameters[parameter.name]) + parameters[parameter.name] = "*" * len(parameters[parameter.name]) return parameters @@ -179,7 +185,10 @@ class ToolParameterConfigurationManager(BaseModel): parameters = self._deep_copy(parameters) for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypted @@ -197,7 +206,7 @@ class ToolParameterConfigurationManager(BaseModel): provider=f'{self.provider_type.value}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cached_parameters = cache.get() if cached_parameters: @@ -208,7 +217,10 @@ class ToolParameterConfigurationManager(BaseModel): has_secret_input = False for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: try: has_secret_input = True @@ -227,6 +239,6 @@ class ToolParameterConfigurationManager(BaseModel): provider=f'{self.provider_type.value}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cache.delete() diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index e6b288868f..44803d7d65 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -17,8 +17,9 @@ class FeishuRequest: redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) return res.get("tenant_access_token") - def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, - params: dict = None): + def _send_request( + self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None + ): headers = { "Content-Type": "application/json", "user-agent": "Dify", @@ -42,10 +43,7 @@ class FeishuRequest: } """ url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token" - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} res = self._send_request(url, require_token=False, payload=payload) return res @@ -76,11 +74,7 @@ class FeishuRequest: def write_document(self, document_id: str, content: str, position: str = "start") -> dict: url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" - payload = { - "document_id": document_id, - "content": content, - "position": position - } + payload = {"document_id": document_id, "content": content, "position": position} res = self._send_request(url, payload=payload) return res.get("data") @@ -95,7 +89,7 @@ class FeishuRequest: "content": "云文档\n多人实时协同,插入一切元素。不仅是在线文档,更是强大的创作和互动工具\n云文档:专为协作而生\n" } } - """ + """ # noqa: E501 params = { "document_id": document_id, } diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index f75b545cbc..73c7ef44b3 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -12,23 +12,24 @@ logger = logging.getLogger(__name__) class ToolFileMessageTransformer: @classmethod - def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None], - user_id: str, - tenant_id: str, - conversation_id: Optional[str] = None) -> Generator[ToolInvokeMessage, None, None]: + def transform_tool_invoke_messages( + cls, + messages: Generator[ToolInvokeMessage, None, None], + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: """ Transform tool message and handle file download """ for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT: - yield message - elif message.type == ToolInvokeMessage.MessageType.LINK: + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: yield message elif message.type == ToolInvokeMessage.MessageType.IMAGE: # try to download image try: if not conversation_id: - raise + raise assert isinstance(message.message, ToolInvokeMessage.TextMessage) @@ -61,24 +62,25 @@ class ToolFileMessageTransformer: # get mime type and save blob to storage assert message.meta - mimetype = message.meta.get('mime_type', 'octet/stream') + mimetype = message.meta.get("mime_type", "octet/stream") # if message is str, encode it to bytes if not isinstance(message.message, ToolInvokeMessage.BlobMessage): raise ValueError("unexpected message type") file = ToolFileManager.create_file_by_raw( - user_id=user_id, tenant_id=tenant_id, + user_id=user_id, + tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message.blob, - mimetype=mimetype + mimetype=mimetype, ) extension = guess_extension(file.mimetype) or ".bin" url = cls.get_tool_file_url(file.id, extension) # check if file is image - if 'image' in mimetype: + if "image" in mimetype: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), @@ -95,10 +97,11 @@ class ToolFileMessageTransformer: elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: assert message.meta - file_var: FileVar | None = message.meta.get('file_var') + file_var: FileVar | None = message.meta.get("file_var") if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - assert file_var.related_id and file_var.extension + assert file_var.related_id + assert file_var.extension url = cls.get_tool_file_url(file_var.related_id, file_var.extension) if file_var.type == FileType.IMAGE: diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e8ef47823..4e226810d6 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -1,7 +1,7 @@ """ - For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. +For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. - Therefore, a model manager is needed to list/invoke/validate models. +Therefore, a model manager is needed to list/invoke/validate models. """ import json @@ -27,52 +27,49 @@ from models.tools import ToolModelInvoke class InvokeModelError(Exception): pass + class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, ) -> int: """ - get max llm context tokens of the model + get max llm context tokens of the model """ model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) if not schema: - raise InvokeModelError('No model schema found') + raise InvokeModelError("No model schema found") max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 - + return max_tokens @staticmethod - def calculate_tokens( - tenant_id: str, - prompt_messages: list[PromptMessage] - ) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: """ - calculate tokens from prompt messages and model parameters + calculate tokens from prompt messages and model parameters """ # get model instance model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + # get tokens tokens = model_instance.get_llm_num_tokens(prompt_messages) @@ -80,9 +77,7 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, - tool_type: str, tool_name: str, - prompt_messages: list[PromptMessage] + user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] ) -> LLMResult: """ invoke model with parameters in user's own context @@ -103,15 +98,16 @@ class ModelInvocationUtils: model_manager = ModelManager() # get model instance model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) # get prompt tokens prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) model_parameters = { - 'temperature': 0.8, - 'top_p': 0.8, + "temperature": 0.8, + "top_p": 0.8, } # create tool model invoke @@ -123,14 +119,14 @@ class ModelInvocationUtils: tool_name=tool_name, model_parameters=json.dumps(model_parameters), prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), - model_response='', + model_response="", prompt_tokens=prompt_tokens, answer_tokens=0, answer_unit_price=0, answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", ) db.session.add(tool_model_invoke) @@ -140,20 +136,24 @@ class ModelInvocationUtils: response: LLMResult = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=[], stop=[], stream=False, user=user_id, callbacks=[] + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], ) except InvokeRateLimitError as e: - raise InvokeModelError(f'Invoke rate limit error: {e}') + raise InvokeModelError(f"Invoke rate limit error: {e}") except InvokeBadRequestError as e: - raise InvokeModelError(f'Invoke bad request error: {e}') + raise InvokeModelError(f"Invoke bad request error: {e}") except InvokeConnectionError as e: - raise InvokeModelError(f'Invoke connection error: {e}') + raise InvokeModelError(f"Invoke connection error: {e}") except InvokeAuthorizationError as e: - raise InvokeModelError('Invoke authorization error') + raise InvokeModelError("Invoke authorization error") except InvokeServerUnavailableError as e: - raise InvokeModelError(f'Invoke server unavailable error: {e}') + raise InvokeModelError(f"Invoke server unavailable error: {e}") except Exception as e: - raise InvokeModelError(f'Invoke error: {e}') + raise InvokeModelError(f"Invoke error: {e}") # update tool model invoke tool_model_invoke.model_response = response.message.content diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 882e276afe..b02a4f75d0 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,3 @@ - import re import uuid from json import dumps as json_dumps @@ -16,54 +15,56 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro class ApiBasedToolSchemaParser: @staticmethod - def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]: + def parse_openapi_to_tool_bundle( + openapi: dict, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} # set description to extra_info - extra_info['description'] = openapi['info'].get('description', '') + extra_info["description"] = openapi["info"].get("description", "") - if len(openapi['servers']) == 0: - raise ToolProviderNotFoundError('No server found in the openapi yaml.') + if len(openapi["servers"]) == 0: + raise ToolProviderNotFoundError("No server found in the openapi yaml.") - server_url = openapi['servers'][0]['url'] + server_url = openapi["servers"][0]["url"] # list all interfaces interfaces = [] - for path, path_item in openapi['paths'].items(): - methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace'] + for path, path_item in openapi["paths"].items(): + methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] for method in methods: if method in path_item: - interfaces.append({ - 'path': path, - 'method': method, - 'operation': path_item[method], - }) + interfaces.append( + { + "path": path, + "method": method, + "operation": path_item[method], + } + ) # get all parameters bundles = [] for interface in interfaces: # convert parameters parameters = [] - if 'parameters' in interface['operation']: - for parameter in interface['operation']['parameters']: + if "parameters" in interface["operation"]: + for parameter in interface["operation"]["parameters"]: tool_parameter = ToolParameter( - name=parameter['name'], - label=I18nObject( - en_US=parameter['name'], - zh_Hans=parameter['name'] - ), + name=parameter["name"], + label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), human_description=I18nObject( - en_US=parameter.get('description', ''), - zh_Hans=parameter.get('description', '') + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, - required=parameter.get('required', False), + required=parameter.get("required", False), form=ToolParameter.ToolParameterForm.LLM, - llm_description=parameter.get('description'), - default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None, + llm_description=parameter.get("description"), + default=parameter["schema"]["default"] + if "schema" in parameter and "default" in parameter["schema"] + else None, ) - + # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) if typ: @@ -72,44 +73,40 @@ class ApiBasedToolSchemaParser: parameters.append(tool_parameter) # create tool bundle # check if there is a request body - if 'requestBody' in interface['operation']: - request_body = interface['operation']['requestBody'] - if 'content' in request_body: - for content_type, content in request_body['content'].items(): + if "requestBody" in interface["operation"]: + request_body = interface["operation"]["requestBody"] + if "content" in request_body: + for content_type, content in request_body["content"].items(): # if there is a reference, get the reference and overwrite the content - if 'schema' not in content: + if "schema" not in content: continue - if '$ref' in content['schema']: + if "$ref" in content["schema"]: # get the reference root = openapi - reference = content['schema']['$ref'].split('/')[1:] + reference = content["schema"]["$ref"].split("/")[1:] for ref in reference: root = root[ref] # overwrite the content - interface['operation']['requestBody']['content'][content_type]['schema'] = root + interface["operation"]["requestBody"]["content"][content_type]["schema"] = root # parse body parameters - if 'schema' in interface['operation']['requestBody']['content'][content_type]: - body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] - required = body_schema.get('required', []) - properties = body_schema.get('properties', {}) + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) for name, property in properties.items(): tool = ToolParameter( name=name, - label=I18nObject( - en_US=name, - zh_Hans=name - ), + label=I18nObject(en_US=name, zh_Hans=name), human_description=I18nObject( - en_US=property.get('description', ''), - zh_Hans=property.get('description', '') + en_US=property.get("description", ""), zh_Hans=property.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, required=name in required, form=ToolParameter.ToolParameterForm.LLM, - llm_description=property.get('description', ''), - default=property.get('default', None), + llm_description=property.get("description", ""), + default=property.get("default", None), ) # check if there is a type @@ -127,173 +124,177 @@ class ApiBasedToolSchemaParser: parameters_count[parameter.name] += 1 for name, count in parameters_count.items(): if count > 1: - warning['duplicated_parameter'] = f'Parameter {name} is duplicated.' + warning["duplicated_parameter"] = f"Parameter {name} is duplicated." # check if there is a operation id, use $path_$method as operation id if not - if 'operationId' not in interface['operation']: + if "operationId" not in interface["operation"]: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = interface['path'] - if interface['path'].startswith('/'): - path = interface['path'][1:] + path = interface["path"] + if interface["path"].startswith("/"): + path = interface["path"][1:] # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = re.sub(r'[^a-zA-Z0-9_-]', '', path) + path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: path = str(uuid.uuid4()) - - interface['operation']['operationId'] = f'{path}_{interface["method"]}' - bundles.append(ApiToolBundle( - server_url=server_url + interface['path'], - method=interface['method'], - summary=interface['operation']['description'] if 'description' in interface['operation'] else - interface['operation'].get('summary', None), - operation_id=interface['operation']['operationId'], - parameters=parameters, - author='', - icon=None, - openapi=interface['operation'], - )) + interface["operation"]["operationId"] = f'{path}_{interface["method"]}' + + bundles.append( + ApiToolBundle( + server_url=server_url + interface["path"], + method=interface["method"], + summary=interface["operation"]["description"] + if "description" in interface["operation"] + else interface["operation"].get("summary", None), + operation_id=interface["operation"]["operationId"], + parameters=parameters, + author="", + icon=None, + openapi=interface["operation"], + ) + ) return bundles - + @staticmethod def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: parameter = parameter or {} typ = None - if 'type' in parameter: - typ = parameter['type'] - elif 'schema' in parameter and 'type' in parameter['schema']: - typ = parameter['schema']['type'] - - if typ == 'integer' or typ == 'number': + if "type" in parameter: + typ = parameter["type"] + elif "schema" in parameter and "type" in parameter["schema"]: + typ = parameter["schema"]["type"] + + if typ in {"integer", "number"}: return ToolParameter.ToolParameterType.NUMBER - elif typ == 'boolean': + elif typ == "boolean": return ToolParameter.ToolParameterType.BOOLEAN - elif typ == 'string': + elif typ == "string": return ToolParameter.ToolParameterType.STRING @staticmethod - def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]: + def parse_openapi_yaml_to_tool_bundle( + yaml: str, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: """ - parse openapi yaml to tool bundle + parse openapi yaml to tool bundle - :param yaml: the yaml string - :return: the tool bundle + :param yaml: the yaml string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} openapi: dict = safe_load(yaml) if openapi is None: - raise ToolApiSchemaError('Invalid openapi yaml.') + raise ToolApiSchemaError("Invalid openapi yaml.") return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) - + @staticmethod def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict: warning = warning or {} """ - parse swagger to openapi + parse swagger to openapi - :param swagger: the swagger dict - :return: the openapi dict + :param swagger: the swagger dict + :return: the openapi dict """ # convert swagger to openapi - info = swagger.get('info', { - 'title': 'Swagger', - 'description': 'Swagger', - 'version': '1.0.0' - }) + info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) - servers = swagger.get('servers', []) + servers = swagger.get("servers", []) if len(servers) == 0: - raise ToolApiSchemaError('No server found in the swagger yaml.') + raise ToolApiSchemaError("No server found in the swagger yaml.") openapi = { - 'openapi': '3.0.0', - 'info': { - 'title': info.get('title', 'Swagger'), - 'description': info.get('description', 'Swagger'), - 'version': info.get('version', '1.0.0') + "openapi": "3.0.0", + "info": { + "title": info.get("title", "Swagger"), + "description": info.get("description", "Swagger"), + "version": info.get("version", "1.0.0"), }, - 'servers': swagger['servers'], - 'paths': {}, - 'components': { - 'schemas': {} - } + "servers": swagger["servers"], + "paths": {}, + "components": {"schemas": {}}, } # check paths - if 'paths' not in swagger or len(swagger['paths']) == 0: - raise ToolApiSchemaError('No paths found in the swagger yaml.') + if "paths" not in swagger or len(swagger["paths"]) == 0: + raise ToolApiSchemaError("No paths found in the swagger yaml.") # convert paths - for path, path_item in swagger['paths'].items(): - openapi['paths'][path] = {} + for path, path_item in swagger["paths"].items(): + openapi["paths"][path] = {} for method, operation in path_item.items(): - if 'operationId' not in operation: - raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.') - - if ('summary' not in operation or len(operation['summary']) == 0) and \ - ('description' not in operation or len(operation['description']) == 0): - warning['missing_summary'] = f'No summary or description found in operation {method} {path}.' - - openapi['paths'][path][method] = { - 'operationId': operation['operationId'], - 'summary': operation.get('summary', ''), - 'description': operation.get('description', ''), - 'parameters': operation.get('parameters', []), - 'responses': operation.get('responses', {}), + if "operationId" not in operation: + raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") + + if ("summary" not in operation or len(operation["summary"]) == 0) and ( + "description" not in operation or len(operation["description"]) == 0 + ): + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + + openapi["paths"][path][method] = { + "operationId": operation["operationId"], + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": operation.get("parameters", []), + "responses": operation.get("responses", {}), } - if 'requestBody' in operation: - openapi['paths'][path][method]['requestBody'] = operation['requestBody'] + if "requestBody" in operation: + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # convert definitions - for name, definition in swagger['definitions'].items(): - openapi['components']['schemas'][name] = definition + for name, definition in swagger["definitions"].items(): + openapi["components"]["schemas"][name] = definition return openapi @staticmethod - def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]: + def parse_openai_plugin_json_to_tool_bundle( + json: str, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: """ - parse openapi plugin yaml to tool bundle + parse openapi plugin yaml to tool bundle - :param json: the json string - :return: the tool bundle + :param json: the json string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} try: openai_plugin = json_loads(json) - api = openai_plugin['api'] - api_url = api['url'] - api_type = api['type'] + api = openai_plugin["api"] + api_url = api["url"] + api_type = api["type"] except: - raise ToolProviderNotFoundError('Invalid openai plugin json.') - - if api_type != 'openapi': - raise ToolNotSupportedError('Only openapi is supported now.') - + raise ToolProviderNotFoundError("Invalid openai plugin json.") + + if api_type != "openapi": + raise ToolNotSupportedError("Only openapi is supported now.") + # get openapi yaml - response = get(api_url, headers={ - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ' - }, timeout=5) + response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) if response.status_code != 200: - raise ToolProviderNotFoundError('cannot get openapi yaml from url.') - - return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) - - @staticmethod - def auto_parse_to_tool_bundle(content: str, extra_info: dict | None = None, warning: dict | None = None) -> tuple[list[ApiToolBundle], str]: - """ - auto parse to tool bundle + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") - :param content: the content - :return: tools bundle, schema_type + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + + @staticmethod + def auto_parse_to_tool_bundle( + content: str, extra_info: dict | None = None, warning: dict | None = None + ) -> tuple[list[ApiToolBundle], str]: + """ + auto parse to tool bundle + + :param content: the content + :return: tools bundle, schema_type """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -302,7 +303,7 @@ class ApiBasedToolSchemaParser: loaded_content = None json_error = None yaml_error = None - + try: loaded_content = json_loads(content) except JSONDecodeError as e: @@ -314,34 +315,48 @@ class ApiBasedToolSchemaParser: except YAMLError as e: yaml_error = e if loaded_content is None: - raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}," + f" yaml error: {str(yaml_error)}" + ) swagger_error = None openapi_error = None openapi_plugin_error = None schema_type = None - + try: - openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning) + openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.OPENAPI.value return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e - + # openai parse error, fallback to swagger try: - converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning) + converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.SWAGGER.value - return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + converted_swagger, extra_info=extra_info, warning=warning + ), schema_type except ToolApiSchemaError as e: swagger_error = e - + # swagger parse error, fallback to openai plugin try: - openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning) + openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + json_dumps(loaded_content), extra_info=extra_info, warning=warning + ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e - raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}," + f" openapi plugin error: {str(openapi_plugin_error)}" + ) diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py index 6f88eeaa0a..6f7610651c 100644 --- a/api/core/tools/utils/tool_parameter_converter.py +++ b/api/core/tools/utils/tool_parameter_converter.py @@ -7,16 +7,18 @@ class ToolParameterConverter: @staticmethod def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: - return 'string' + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + return "string" case ToolParameter.ToolParameterType.BOOLEAN: - return 'boolean' + return "boolean" case ToolParameter.ToolParameterType.NUMBER: - return 'number' + return "number" case _: raise ValueError(f"Unsupported parameter type {parameter_type}") @@ -26,11 +28,13 @@ class ToolParameterConverter: # convert tool parameter config to correct type try: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): if value is None: - return '' + return "" else: return value if isinstance(value, str) else str(value) @@ -41,9 +45,9 @@ class ToolParameterConverter: # Allowed YAML boolean value strings: https://yaml.org/type/bool.html # and also '0' for False and '1' for True match value.lower(): - case 'true' | 'yes' | 'y' | '1': + case "true" | "yes" | "y" | "1": return True - case 'false' | 'no' | 'n' | '0': + case "false" | "no" | "n" | "0": return False case _: return bool(value) @@ -53,8 +57,8 @@ class ToolParameterConverter: case ToolParameter.ToolParameterType.NUMBER: if isinstance(value, int) | isinstance(value, float): return value - elif isinstance(value, str) and value != '': - if '.' in value: + elif isinstance(value, str) and value != "": + if "." in value: return float(value) else: return int(value) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 150941924d..1ced7d0488 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -8,6 +8,7 @@ import subprocess import tempfile import unicodedata from contextlib import contextmanager +from pathlib import Path from urllib.parse import unquote import chardet @@ -32,13 +33,14 @@ TEXT: def page_result(text: str, cursor: int, max_length: int) -> str: """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" - return text[cursor: cursor + max_length] + return text[cursor : cursor + max_length] def get_url(url: str, user_agent: str = None) -> str: """Fetch URL and return the contents as a string.""" headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" } if user_agent: headers["User-Agent"] = user_agent @@ -49,15 +51,15 @@ def get_url(url: str, user_agent: str = None) -> str: if response.status_code == 200: # check content-type - content_type = response.headers.get('Content-Type') + content_type = response.headers.get("Content-Type") if content_type: - main_content_type = response.headers.get('Content-Type').split(';')[0].strip() + main_content_type = response.headers.get("Content-Type").split(";")[0].strip() else: - content_disposition = response.headers.get('Content-Disposition', '') + content_disposition = response.headers.get("Content-Disposition", "") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - extension = re.search(r'\.(\w+)$', filename) + extension = re.search(r"\.(\w+)$", filename) if extension: main_content_type = mimetypes.guess_type(filename)[0] @@ -78,7 +80,7 @@ def get_url(url: str, user_agent: str = None) -> str: # Detect encoding using chardet detected_encoding = chardet.detect(response.content) - encoding = detected_encoding['encoding'] + encoding = detected_encoding["encoding"] if encoding: try: content = response.content.decode(encoding) @@ -89,35 +91,34 @@ def get_url(url: str, user_agent: str = None) -> str: a = extract_using_readabilipy(content) - if not a['plain_text'] or not a['plain_text'].strip(): - return '' + if not a["plain_text"] or not a["plain_text"].strip(): + return "" res = FULL_TEMPLATE.format( - title=a['title'], - authors=a['byline'], - publish_date=a['date'], + title=a["title"], + authors=a["byline"], + publish_date=a["date"], top_image="", - text=a['plain_text'] if a['plain_text'] else "", + text=a["plain_text"] or "", ) return res def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: f_html.write(html) f_html.close() html_path = f_html.name # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') + jsdir = os.path.join(find_module_path("readabilipy"), "javascript") with chdir(jsdir): subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) # Read output of call to Readability.parse() from JSON file and return as Python dictionary - with open(article_json_path, encoding="utf-8") as json_file: - input_json = json.loads(json_file.read()) + input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) # Deleting files after processing os.unlink(article_json_path) @@ -129,7 +130,7 @@ def extract_using_readabilipy(html): "date": None, "content": None, "plain_content": None, - "plain_text": None + "plain_text": None, } # Populate article fields from readability fields where present if input_json: @@ -145,7 +146,7 @@ def extract_using_readabilipy(html): article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) if input_json.get("textContent"): article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) + article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) return article_json @@ -158,6 +159,7 @@ def find_module_path(module_name): return None + @contextmanager def chdir(path): """Change directory in context and return to original on exit""" @@ -172,12 +174,14 @@ def chdir(path): def extract_text_blocks_as_plain_text(paragraph_html): # Load article as DOM - soup = BeautifulSoup(paragraph_html, 'html.parser') + soup = BeautifulSoup(paragraph_html, "html.parser") # Select all lists - list_elements = soup.find_all(['ul', 'ol']) + list_elements = soup.find_all(["ul", "ol"]) # Prefix text in all list items with "* " and make lists paragraphs for list_element in list_elements: - plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) + plain_items = "".join( + list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) + ) list_element.string = plain_items list_element.name = "p" # Select all text blocks @@ -204,7 +208,7 @@ def plain_text_leaf_node(element): def plain_content(readability_content, content_digests, node_indexes): # Load article as DOM - soup = BeautifulSoup(readability_content, 'html.parser') + soup = BeautifulSoup(readability_content, "html.parser") # Make all elements plain elements = plain_elements(soup.contents, content_digests, node_indexes) if node_indexes: @@ -217,8 +221,7 @@ def plain_content(readability_content, content_digests, node_indexes): def plain_elements(elements, content_digests, node_indexes): # Get plain content versions of all elements - elements = [plain_element(element, content_digests, node_indexes) - for element in elements] + elements = [plain_element(element, content_digests, node_indexes) for element in elements] if content_digests: # Add content digest attribute to nodes elements = [add_content_digest(element) for element in elements] @@ -258,11 +261,9 @@ def add_node_indexes(element, node_index="0"): # Add index to current element element["data-node-index"] = node_index # Add index to child elements - for local_idx, child in enumerate( - [c for c in element.contents if not is_text(c)], start=1): + for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format( - stem=node_index, local=local_idx) + child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) add_node_indexes(child, node_index=child_index) return element @@ -284,11 +285,16 @@ def strip_control_characters(text): # [Cn]: Other, Not Assigned # [Co]: Other, Private Use # [Cs]: Other, Surrogate - control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'} - retained_chars = ['\t', '\n', '\r', '\f'] + control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} + retained_chars = ["\t", "\n", "\r", "\f"] # Remove non-printing control characters - return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) + return "".join( + [ + "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char + for char in text + ] + ) def normalize_unicode(text): @@ -305,8 +311,9 @@ def normalize_whitespace(text): text = text.strip() return text + def is_leaf(element): - return (element.name in ['p', 'li']) + return element.name in {"p", "li"} def is_text(element): @@ -330,7 +337,7 @@ def content_digest(element): if trimmed_string == "": digest = "" else: - digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() + digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() else: contents = element.contents num_contents = len(contents) @@ -343,9 +350,8 @@ def content_digest(element): else: # Build content digest from the "non-empty" digests of child nodes digest = hashlib.sha256() - child_digests = list( - filter(lambda x: x != "", [content_digest(content) for content in contents])) + child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) for child in child_digests: - digest.update(child.encode('utf-8')) + digest.update(child.encode("utf-8")) digest = digest.hexdigest() return digest diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index b8237fd043..7f20605024 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -12,27 +12,25 @@ class WorkflowToolConfigurationUtils: """ for configuration in configurations: if not WorkflowToolParameterConfiguration(**configuration): - raise ValueError('invalid parameter configuration') + raise ValueError("invalid parameter configuration") @classmethod def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]: """ get workflow graph variables """ - nodes = graph.get('nodes', []) - start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None) + nodes = graph.get("nodes", []) + start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) if not start_node: return [] - return [ - VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', []) - ] - + return [VariableEntity(**variable) for variable in start_node.get("data", {}).get("variables", [])] + @classmethod - def check_is_synced(cls, - variables: list[VariableEntity], - tool_configurations: list[WorkflowToolParameterConfiguration]) -> None: + def check_is_synced( + cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] + ) -> None: """ check is synced @@ -41,8 +39,8 @@ class WorkflowToolConfigurationUtils: variable_names = [variable.variable for variable in variables] if len(tool_configurations) != len(variables): - raise ValueError('parameter configuration mismatch, please republish the tool to update') - + raise ValueError("parameter configuration mismatch, please republish the tool to update") + for parameter in tool_configurations: if parameter.name not in variable_names: raise ValueError('parameter configuration mismatch, please republish the tool to update') diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index f751c43096..99b9f80499 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -18,12 +18,12 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any :return: an object of the YAML content """ try: - with open(file_path, encoding='utf-8') as yaml_file: + with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) - return yaml_content if yaml_content else default_value + return yaml_content or default_value except Exception as e: - raise YAMLError(f'Failed to load YAML file {file_path}: {e}') + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") except Exception as e: if ignore_error: return default_value diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 9015eea85c..83086d1afc 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -5,10 +5,7 @@ from core.workflow.graph_engine.entities.event import GraphEngineEvent class WorkflowCallback(ABC): @abstractmethod - def on_event( - self, - event: GraphEngineEvent - ) -> None: + def on_event(self, event: GraphEngineEvent) -> None: """ Published event """ diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index e7e6710cbd..2a864dd7a8 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -8,9 +8,11 @@ class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + class BaseIterationNodeData(BaseNodeData): start_node_id: Optional[str] = None + class BaseIterationState(BaseModel): iteration_node_id: str index: int @@ -19,4 +21,4 @@ class BaseIterationState(BaseModel): class MetaData(BaseModel): pass - metadata: MetaData \ No newline at end of file + metadata: MetaData diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 5e2a5cb466..5353b99ed3 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -12,28 +12,28 @@ class NodeType(Enum): Node Types. """ - START = 'start' - END = 'end' - ANSWER = 'answer' - LLM = 'llm' - KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' - IF_ELSE = 'if-else' - CODE = 'code' - TEMPLATE_TRANSFORM = 'template-transform' - QUESTION_CLASSIFIER = 'question-classifier' - HTTP_REQUEST = 'http-request' - TOOL = 'tool' - VARIABLE_AGGREGATOR = 'variable-aggregator' + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" # TODO: merge this into VARIABLE_AGGREGATOR - VARIABLE_ASSIGNER = 'variable-assigner' - LOOP = 'loop' - ITERATION = 'iteration' - ITERATION_START = 'iteration-start' # fake start node for iteration - PARAMETER_EXTRACTOR = 'parameter-extractor' - CONVERSATION_VARIABLE_ASSIGNER = 'assigner' + VARIABLE_ASSIGNER = "variable-assigner" + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # fake start node for iteration + PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" @classmethod - def value_of(cls, value: str) -> 'NodeType': + def value_of(cls, value: str) -> "NodeType": """ Get value of given node type. @@ -43,7 +43,7 @@ class NodeType(Enum): for node_type in cls: if node_type.value == value: return node_type - raise ValueError(f'invalid node type value {value}') + raise ValueError(f"invalid node type value {value}") class NodeRunMetadataKey(Enum): @@ -51,16 +51,16 @@ class NodeRunMetadataKey(Enum): Node Run Metadata Key. """ - TOTAL_TOKENS = 'total_tokens' - TOTAL_PRICE = 'total_price' - CURRENCY = 'currency' - TOOL_INFO = 'tool_info' - ITERATION_ID = 'iteration_id' - ITERATION_INDEX = 'iteration_index' - PARALLEL_ID = 'parallel_id' - PARALLEL_START_NODE_ID = 'parallel_start_node_id' - PARENT_PARALLEL_ID = 'parent_parallel_id' - PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id' + TOTAL_TOKENS = "total_tokens" + TOTAL_PRICE = "total_price" + CURRENCY = "currency" + TOOL_INFO = "tool_info" + ITERATION_ID = "iteration_id" + ITERATION_INDEX = "iteration_index" + PARALLEL_ID = "parallel_id" + PARALLEL_START_NODE_ID = "parallel_start_node_id" + PARENT_PARALLEL_ID = "parent_parallel_id" + PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" class NodeRunResult(BaseModel): @@ -85,6 +85,7 @@ class UserFrom(Enum): """ User from """ + ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 19d9af2a61..1dfb1852f8 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -5,5 +5,6 @@ class VariableSelector(BaseModel): """ Variable Selector. """ + variable: str value_selector: list[str] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 48a20d25ae..b94b7f7198 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -23,23 +23,19 @@ class VariablePool(BaseModel): # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: dict[str, dict[int, Segment]] = Field( - description='Variables mapping', - default=defaultdict(dict) + description="Variables mapping", default=defaultdict(dict) ) # TODO: This user inputs is not used for pool. user_inputs: Mapping[str, Any] = Field( - description='User inputs', + description="User inputs", ) system_variables: Mapping[SystemVariableKey, Any] = Field( - description='System variables', + description="System variables", ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list - ) + environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list) conversation_variables: Sequence[Variable] | None = None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 4bf4e454bb..0a1eb57de4 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -46,13 +46,16 @@ class WorkflowRunState: current_iteration_state: Optional[BaseIterationState] - def __init__(self, workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int): + def __init__( + self, + workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + workflow_call_depth: int, + ): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py index 4099def4e2..697392b2a3 100644 --- a/api/core/workflow/graph_engine/condition_handlers/base_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -8,19 +8,13 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta class RunConditionHandler(ABC): - def __init__(self, - init_params: GraphInitParams, - graph: Graph, - condition: RunCondition): + def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition): self.init_params = init_params self.graph = graph self.condition = condition @abstractmethod - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState - ) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py index 705eb908b1..af695df7d8 100644 --- a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -4,10 +4,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta class BranchIdentifyRunConditionHandler(RunConditionHandler): - - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index 1edaf92da7..eda5fe079c 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -5,10 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor class ConditionRunConditionHandlerHandler(RunConditionHandler): - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState - ) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed @@ -22,8 +19,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): # process condition condition_processor = ConditionProcessor() input_conditions, group_result = condition_processor.process_conditions( - variable_pool=graph_runtime_state.variable_pool, - conditions=self.condition.conditions + variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions ) # Apply the logical operator for the current case diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py index 2eb2e58bfc..1c9237d82f 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -9,9 +9,7 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition class ConditionManager: @staticmethod def get_condition_handler( - init_params: GraphInitParams, - graph: Graph, - run_condition: RunCondition + init_params: GraphInitParams, graph: Graph, run_condition: RunCondition ) -> RunConditionHandler: """ Get condition handler @@ -22,14 +20,6 @@ class ConditionManager: :return: condition handler """ if run_condition.type == "branch_identify": - return BranchIdentifyRunConditionHandler( - init_params=init_params, - graph=graph, - condition=run_condition - ) + return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition) else: - return ConditionRunConditionHandlerHandler( - init_params=init_params, - graph=graph, - condition=run_condition - ) + return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 49007b870d..1d7e9158d8 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -34,38 +34,25 @@ class Graph(BaseModel): root_node_id: str = Field(..., description="root node id of the graph") node_ids: list[str] = Field(default_factory=list, description="graph node ids") node_id_config_mapping: dict[str, dict] = Field( - default_factory=list, - description="node configs mapping (node id: node config)" + default_factory=list, description="node configs mapping (node id: node config)" ) edge_mapping: dict[str, list[GraphEdge]] = Field( - default_factory=dict, - description="graph edge mapping (source node id: edges)" + default_factory=dict, description="graph edge mapping (source node id: edges)" ) reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( - default_factory=dict, - description="reverse graph edge mapping (target node id: edges)" + default_factory=dict, description="reverse graph edge mapping (target node id: edges)" ) parallel_mapping: dict[str, GraphParallel] = Field( - default_factory=dict, - description="graph parallel mapping (parallel id: parallel)" + default_factory=dict, description="graph parallel mapping (parallel id: parallel)" ) node_parallel_mapping: dict[str, str] = Field( - default_factory=dict, - description="graph node parallel mapping (node id: parallel id)" - ) - answer_stream_generate_routes: AnswerStreamGenerateRoute = Field( - ..., - description="answer stream generate routes" - ) - end_stream_param: EndStreamParam = Field( - ..., - description="end stream param" + default_factory=dict, description="graph node parallel mapping (node id: parallel id)" ) + answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes") + end_stream_param: EndStreamParam = Field(..., description="end stream param") @classmethod - def init(cls, - graph_config: Mapping[str, Any], - root_node_id: Optional[str] = None) -> "Graph": + def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": """ Init graph @@ -74,7 +61,7 @@ class Graph(BaseModel): :return: graph """ # edge configs - edge_configs = graph_config.get('edges') + edge_configs = graph_config.get("edges") if edge_configs is None: edge_configs = [] @@ -85,45 +72,36 @@ class Graph(BaseModel): reverse_edge_mapping: dict[str, list[GraphEdge]] = {} target_edge_ids = set() for edge_config in edge_configs: - source_node_id = edge_config.get('source') + source_node_id = edge_config.get("source") if not source_node_id: continue if source_node_id not in edge_mapping: edge_mapping[source_node_id] = [] - target_node_id = edge_config.get('target') + target_node_id = edge_config.get("target") if not target_node_id: continue if target_node_id not in reverse_edge_mapping: reverse_edge_mapping[target_node_id] = [] - # is target node id in source node id edge mapping - if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]): - continue - target_edge_ids.add(target_node_id) # parse run condition run_condition = None - if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source': - run_condition = RunCondition( - type='branch_identify', - branch_identify=edge_config.get('sourceHandle') - ) + if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": + run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) graph_edge = GraphEdge( - source_node_id=source_node_id, - target_node_id=target_node_id, - run_condition=run_condition + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) edge_mapping[source_node_id].append(graph_edge) reverse_edge_mapping[target_node_id].append(graph_edge) # node configs - node_configs = graph_config.get('nodes') + node_configs = graph_config.get("nodes") if not node_configs: raise ValueError("Graph must have at least one node") @@ -133,7 +111,7 @@ class Graph(BaseModel): root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} for node_config in node_configs: - node_id = node_config.get('id') + node_id = node_config.get("id") if not node_id: continue @@ -142,30 +120,29 @@ class Graph(BaseModel): all_node_id_config_mapping[node_id] = node_config - root_node_ids = [node_config.get('id') for node_config in root_node_configs] + root_node_ids = [node_config.get("id") for node_config in root_node_configs] # fetch root node if not root_node_id: # if no root node id, use the START type node as root node - root_node_id = next((node_config.get("id") for node_config in root_node_configs - if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) + root_node_id = next( + ( + node_config.get("id") + for node_config in root_node_configs + if node_config.get("data", {}).get("type", "") == NodeType.START.value + ), + None, + ) if not root_node_id or root_node_id not in root_node_ids: raise ValueError(f"Root node id {root_node_id} not found in the graph") - + # Check whether it is connected to the previous node - cls._check_connected_to_previous_node( - route=[root_node_id], - edge_mapping=edge_mapping - ) + cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping) # fetch all node ids from root node node_ids = [root_node_id] - cls._recursively_add_node_ids( - node_ids=node_ids, - edge_mapping=edge_mapping, - node_id=root_node_id - ) + cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id) node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} @@ -177,29 +154,26 @@ class Graph(BaseModel): reverse_edge_mapping=reverse_edge_mapping, start_node_id=root_node_id, parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, ) # Check if it exceeds N layers of parallel for parallel in parallel_mapping.values(): if parallel.parent_parallel_id: cls._check_exceed_parallel_limit( - parallel_mapping=parallel_mapping, - level_limit=3, - parent_parallel_id=parallel.parent_parallel_id + parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id ) # init answer stream generate routes answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping + node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping ) # init end stream param end_stream_param = EndStreamGeneratorRouter.init( node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, ) # init graph @@ -212,14 +186,14 @@ class Graph(BaseModel): parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping, answer_stream_generate_routes=answer_stream_generate_routes, - end_stream_param=end_stream_param + end_stream_param=end_stream_param, ) return graph - def add_extra_edge(self, source_node_id: str, - target_node_id: str, - run_condition: Optional[RunCondition] = None) -> None: + def add_extra_edge( + self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None + ) -> None: """ Add extra edge to the graph @@ -237,9 +211,7 @@ class Graph(BaseModel): return graph_edge = GraphEdge( - source_node_id=source_node_id, - target_node_id=target_node_id, - run_condition=run_condition + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) self.edge_mapping[source_node_id].append(graph_edge) @@ -252,19 +224,18 @@ class Graph(BaseModel): """ leaf_node_ids = [] for node_id in self.node_ids: - if node_id not in self.edge_mapping: - leaf_node_ids.append(node_id) - elif (len(self.edge_mapping[node_id]) == 1 - and self.edge_mapping[node_id][0].target_node_id == self.root_node_id): + if node_id not in self.edge_mapping or ( + len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id + ): leaf_node_ids.append(node_id) return leaf_node_ids @classmethod - def _recursively_add_node_ids(cls, - node_ids: list[str], - edge_mapping: dict[str, list[GraphEdge]], - node_id: str) -> None: + def _recursively_add_node_ids( + cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str + ) -> None: """ Recursively add node ids @@ -278,17 +249,11 @@ class Graph(BaseModel): node_ids.append(graph_edge.target_node_id) cls._recursively_add_node_ids( - node_ids=node_ids, - edge_mapping=edge_mapping, - node_id=graph_edge.target_node_id + node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id ) @classmethod - def _check_connected_to_previous_node( - cls, - route: list[str], - edge_mapping: dict[str, list[GraphEdge]] - ) -> None: + def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None: """ Check whether it is connected to the previous node """ @@ -299,9 +264,11 @@ class Graph(BaseModel): continue if graph_edge.target_node_id in route: - raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.") + raise ValueError( + f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph." + ) - new_route = route[:] + new_route = route.copy() new_route.append(graph_edge.target_node_id) cls._check_connected_to_previous_node( route=new_route, @@ -316,7 +283,7 @@ class Graph(BaseModel): start_node_id: str, parallel_mapping: dict[str, GraphParallel], node_parallel_mapping: dict[str, str], - parent_parallel: Optional[GraphParallel] = None + parent_parallel: Optional[GraphParallel] = None, ) -> None: """ Recursively add parallel ids @@ -331,119 +298,198 @@ class Graph(BaseModel): parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_branch_node_ids = [] + parallel_branch_node_ids = {} condition_edge_mappings = {} for graph_edge in target_node_edges: if graph_edge.run_condition is None: - parallel_branch_node_ids.append(graph_edge.target_node_id) + if "default" not in parallel_branch_node_ids: + parallel_branch_node_ids["default"] = [] + + parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash - if not condition_hash in condition_edge_mappings: + if condition_hash not in condition_edge_mappings: condition_edge_mappings[condition_hash] = [] condition_edge_mappings[condition_hash].append(graph_edge) - for _, graph_edges in condition_edge_mappings.items(): + for condition_hash, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: + if condition_hash not in parallel_branch_node_ids: + parallel_branch_node_ids[condition_hash] = [] + for graph_edge in graph_edges: - parallel_branch_node_ids.append(graph_edge.target_node_id) + parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) - # any target node id in node_parallel_mapping - if parallel_branch_node_ids: - parent_parallel_id = parent_parallel.id if parent_parallel else None + condition_parallels = {} + for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items(): + # any target node id in node_parallel_mapping + parallel = None + if condition_parallel_branch_node_ids: + parent_parallel_id = parent_parallel.id if parent_parallel else None - parallel = GraphParallel( - start_from_node_id=start_node_id, - parent_parallel_id=parent_parallel.id if parent_parallel else None, - parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None - ) - parallel_mapping[parallel.id] = parallel + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel.id if parent_parallel else None, + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, + ) + parallel_mapping[parallel.id] = parallel + condition_parallels[condition_hash] = parallel - in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( - edge_mapping=edge_mapping, - reverse_edge_mapping=reverse_edge_mapping, - parallel_branch_node_ids=parallel_branch_node_ids - ) + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_branch_node_ids=condition_parallel_branch_node_ids, + ) - # collect all branches node ids - parallel_node_ids = [] - for _, node_ids in in_branch_node_ids.items(): - for node_id in node_ids: - in_parent_parallel = True - if parent_parallel_id: - in_parent_parallel = False - for parallel_node_id, parallel_id in node_parallel_mapping.items(): - if parallel_id == parent_parallel_id and parallel_node_id == node_id: - in_parent_parallel = True - break + # collect all branches node ids + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): + for node_id in node_ids: + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break - if in_parent_parallel: - parallel_node_ids.append(node_id) - node_parallel_mapping[node_id] = parallel.id + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id - outside_parallel_target_node_ids = set() - for node_id in parallel_node_ids: - if node_id == parallel.start_from_node_id: - continue - - node_edges = edge_mapping.get(node_id) - if not node_edges: - continue - - if len(node_edges) > 1: - continue - - target_node_id = node_edges[0].target_node_id - if target_node_id in parallel_node_ids: - continue - - if parent_parallel_id: - parent_parallel = parallel_mapping.get(parent_parallel_id) - if not parent_parallel: + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: continue - if ( - (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id) - or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id) - or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) - ): - outside_parallel_target_node_ids.add(target_node_id) + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue - if len(outside_parallel_target_node_ids) == 1: - if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id: - parallel.end_to_node_id = None - else: - parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + if len(node_edges) > 1: + continue + + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue + + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue - for graph_edge in target_node_edges: - current_parallel = None - if parallel: - current_parallel = parallel - elif parent_parallel: - if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id): - current_parallel = parent_parallel - else: - # fetch parent parallel's parent parallel - parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id - if parent_parallel_parent_parallel_id: - parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) if ( - parent_parallel_parent_parallel - and ( - not parent_parallel_parent_parallel.end_to_node_id - or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id) + ( + node_parallel_mapping.get(target_node_id) + and node_parallel_mapping.get(target_node_id) == parent_parallel_id ) + or ( + parent_parallel + and parent_parallel.end_to_node_id + and target_node_id == parent_parallel.end_to_node_id + ) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) ): - current_parallel = parent_parallel_parent_parallel + outside_parallel_target_node_ids.add(target_node_id) - cls._recursively_add_parallels( - edge_mapping=edge_mapping, - reverse_edge_mapping=reverse_edge_mapping, - start_node_id=graph_edge.target_node_id, - parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping, - parent_parallel=current_parallel - ) + if len(outside_parallel_target_node_ids) == 1: + if ( + parent_parallel + and parent_parallel.end_to_node_id + and parallel.end_to_node_id == parent_parallel.end_to_node_id + ): + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + + if condition_edge_mappings: + for condition_hash, graph_edges in condition_edge_mappings.items(): + for graph_edge in graph_edges: + current_parallel: GraphParallel | None = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=condition_parallels.get(condition_hash), + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + + @classmethod + def _get_current_parallel( + cls, + parallel_mapping: dict[str, GraphParallel], + graph_edge: GraphEdge, + parallel: Optional[GraphParallel] = None, + parent_parallel: Optional[GraphParallel] = None, + ) -> Optional[GraphParallel]: + """ + Get current parallel + """ + current_parallel = None + if parallel: + current_parallel = parallel + elif parent_parallel: + if not parent_parallel.end_to_node_id or ( + parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id + ): + current_parallel = parent_parallel + else: + # fetch parent parallel's parent parallel + parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id + if parent_parallel_parent_parallel_id: + parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) + if parent_parallel_parent_parallel and ( + not parent_parallel_parent_parallel.end_to_node_id + or ( + parent_parallel_parent_parallel.end_to_node_id + and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id + ) + ): + current_parallel = parent_parallel_parent_parallel + + return current_parallel @classmethod def _check_exceed_parallel_limit( @@ -451,7 +497,7 @@ class Graph(BaseModel): parallel_mapping: dict[str, GraphParallel], level_limit: int, parent_parallel_id: str, - current_level: int = 1 + current_level: int = 1, ) -> None: """ Check if it exceeds N layers of parallel @@ -459,25 +505,27 @@ class Graph(BaseModel): parent_parallel = parallel_mapping.get(parent_parallel_id) if not parent_parallel: return - + current_level += 1 if current_level > level_limit: raise ValueError(f"Exceeds {level_limit} layers of parallel") - + if parent_parallel.parent_parallel_id: cls._check_exceed_parallel_limit( parallel_mapping=parallel_mapping, level_limit=level_limit, parent_parallel_id=parent_parallel.parent_parallel_id, - current_level=current_level + current_level=current_level, ) @classmethod - def _recursively_add_parallel_node_ids(cls, - branch_node_ids: list[str], - edge_mapping: dict[str, list[GraphEdge]], - merge_node_id: str, - start_node_id: str) -> None: + def _recursively_add_parallel_node_ids( + cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str, + ) -> None: """ Recursively add node ids @@ -487,21 +535,22 @@ class Graph(BaseModel): :param start_node_id: start node id """ for graph_edge in edge_mapping.get(start_node_id, []): - if (graph_edge.target_node_id != merge_node_id - and graph_edge.target_node_id not in branch_node_ids): + if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids: branch_node_ids.append(graph_edge.target_node_id) cls._recursively_add_parallel_node_ids( branch_node_ids=branch_node_ids, edge_mapping=edge_mapping, merge_node_id=merge_node_id, - start_node_id=graph_edge.target_node_id + start_node_id=graph_edge.target_node_id, ) @classmethod - def _fetch_all_node_ids_in_parallels(cls, - edge_mapping: dict[str, list[GraphEdge]], - reverse_edge_mapping: dict[str, list[GraphEdge]], - parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: + def _fetch_all_node_ids_in_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + parallel_branch_node_ids: list[str], + ) -> dict[str, list[str]]: """ Fetch all node ids in parallels """ @@ -513,7 +562,7 @@ class Graph(BaseModel): cls._recursively_fetch_routes( edge_mapping=edge_mapping, start_node_id=parallel_branch_node_id, - routes_node_ids=routes_node_ids[parallel_branch_node_id] + routes_node_ids=routes_node_ids[parallel_branch_node_id], ) # fetch leaf node ids from routes node ids @@ -529,13 +578,13 @@ class Graph(BaseModel): for branch_node_id2, inner_route2 in routes_node_ids.items(): if ( - branch_node_id != branch_node_id2 + branch_node_id != branch_node_id2 and node_id in inner_route2 and len(reverse_edge_mapping.get(node_id, [])) > 1 and cls._is_node_in_routes( reverse_edge_mapping=reverse_edge_mapping, start_node_id=node_id, - routes_node_ids=routes_node_ids + routes_node_ids=routes_node_ids, ) ): if node_id not in merge_branch_node_ids: @@ -551,23 +600,18 @@ class Graph(BaseModel): for node_id, branch_node_ids in merge_branch_node_ids.items(): for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): - if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids: + if (node_id, node_id2) not in duplicate_end_node_ids and ( + node_id2, + node_id, + ) not in duplicate_end_node_ids: duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids - + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): # check which node is after - if cls._is_node2_after_node1( - node1_id=node_id, - node2_id=node_id2, - edge_mapping=edge_mapping - ): + if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): if node_id in merge_branch_node_ids: del merge_branch_node_ids[node_id2] - elif cls._is_node2_after_node1( - node1_id=node_id2, - node2_id=node_id, - edge_mapping=edge_mapping - ): + elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): if node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id] @@ -599,16 +643,15 @@ class Graph(BaseModel): branch_node_ids=in_branch_node_ids[branch_node_id], edge_mapping=edge_mapping, merge_node_id=merge_node_id, - start_node_id=branch_node_id + start_node_id=branch_node_id, ) return in_branch_node_ids @classmethod - def _recursively_fetch_routes(cls, - edge_mapping: dict[str, list[GraphEdge]], - start_node_id: str, - routes_node_ids: list[str]) -> None: + def _recursively_fetch_routes( + cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str] + ) -> None: """ Recursively fetch route """ @@ -621,28 +664,24 @@ class Graph(BaseModel): routes_node_ids.append(graph_edge.target_node_id) cls._recursively_fetch_routes( - edge_mapping=edge_mapping, - start_node_id=graph_edge.target_node_id, - routes_node_ids=routes_node_ids + edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids ) @classmethod - def _is_node_in_routes(cls, - reverse_edge_mapping: dict[str, list[GraphEdge]], - start_node_id: str, - routes_node_ids: dict[str, list[str]]) -> bool: + def _is_node_in_routes( + cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]] + ) -> bool: """ Recursively check if the node is in the routes """ if start_node_id not in reverse_edge_mapping: return False - + all_routes_node_ids = set() parallel_start_node_ids: dict[str, list[str]] = {} for branch_node_id, node_ids in routes_node_ids.items(): - for node_id in node_ids: - all_routes_node_ids.add(node_id) - + all_routes_node_ids.update(node_ids) + if branch_node_id in reverse_edge_mapping: for graph_edge in reverse_edge_mapping[branch_node_id]: if graph_edge.source_node_id not in parallel_start_node_ids: @@ -655,38 +694,34 @@ class Graph(BaseModel): if set(branch_node_ids) == set(routes_node_ids.keys()): parallel_start_node_id = p_start_node_id return True - + if not parallel_start_node_id: raise Exception("Parallel start node id not found") - + for graph_edge in reverse_edge_mapping[start_node_id]: - if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id: + if ( + graph_edge.source_node_id not in all_routes_node_ids + or graph_edge.source_node_id != parallel_start_node_id + ): return False - + return True @classmethod - def _is_node2_after_node1( - cls, - node1_id: str, - node2_id: str, - edge_mapping: dict[str, list[GraphEdge]] - ) -> bool: + def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: """ is node2 after node1 """ if node1_id not in edge_mapping: return False - + for graph_edge in edge_mapping[node1_id]: if graph_edge.target_node_id == node2_id: return True - + if cls._is_node2_after_node1( - node1_id=graph_edge.target_node_id, - node2_id=node2_id, - edge_mapping=edge_mapping + node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping ): return True - - return False \ No newline at end of file + + return False diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index c7d484ddf5..afc09bfac5 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -10,7 +10,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRoute class GraphRuntimeState(BaseModel): variable_pool: VariablePool = Field(..., description="variable pool") """variable pool""" - + start_at: float = Field(..., description="start time") """start time""" total_tokens: int = 0 diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py index 0362343568..eedce8842b 100644 --- a/api/core/workflow/graph_engine/entities/run_condition.py +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -18,4 +18,4 @@ class RunCondition(BaseModel): @property def hash(self) -> str: - return hashlib.sha256(self.model_dump_json().encode()).hexdigest() \ No newline at end of file + return hashlib.sha256(self.model_dump_json().encode()).hexdigest() diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index b5d6e4c09d..bb24b51112 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -51,7 +51,7 @@ class RouteNodeState(BaseModel): :param run_result: run result """ - if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: raise Exception(f"Route state {self.id} already finished") if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: @@ -68,13 +68,11 @@ class RouteNodeState(BaseModel): class RuntimeRouteState(BaseModel): routes: dict[str, list[str]] = Field( - default_factory=dict, - description="graph state routes (source_node_state_id: target_node_state_id)" + default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)" ) node_state_mapping: dict[str, RouteNodeState] = Field( - default_factory=dict, - description="node state mapping (route_node_state_id: route_node_state)" + default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)" ) def create_node_state(self, node_id: str) -> RouteNodeState: @@ -99,13 +97,13 @@ class RuntimeRouteState(BaseModel): self.routes[source_node_state_id].append(target_node_state_id) - def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \ - -> list[RouteNodeState]: + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]: """ Get routes with node state by source node id :param source_node_state_id: source node state id :return: routes with node state """ - return [self.node_state_mapping[target_state_id] - for target_state_id in self.routes.get(source_node_state_id, [])] + return [ + self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, []) + ] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 65d9ab8446..1db9b690ab 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -8,7 +8,7 @@ from typing import Any, Optional from flask import Flask, current_app -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import ( NodeRunMetadataKey, @@ -48,8 +48,9 @@ logger = logging.getLogger(__name__) class GraphEngineThreadPool(ThreadPoolExecutor): - def __init__(self, max_workers=None, thread_name_prefix='', - initializer=None, initargs=(), max_submit_count=100) -> None: + def __init__( + self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100 + ) -> None: super().__init__(max_workers, thread_name_prefix, initializer, initargs) self.max_submit_count = max_submit_count self.submit_count = 0 @@ -57,9 +58,9 @@ class GraphEngineThreadPool(ThreadPoolExecutor): def submit(self, fn, *args, **kwargs): self.submit_count += 1 self.check_is_full() - + return super().submit(fn, *args, **kwargs) - + def check_is_full(self) -> None: print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") if self.submit_count > self.max_submit_count: @@ -70,35 +71,37 @@ class GraphEngine: workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} def __init__( - self, - tenant_id: str, - app_id: str, - workflow_type: WorkflowType, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, - graph: Graph, - graph_config: Mapping[str, Any], - variable_pool: VariablePool, - max_execution_steps: int, - max_execution_time: int, - thread_pool_id: Optional[str] = None + self, + tenant_id: str, + app_id: str, + workflow_type: WorkflowType, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, Any], + variable_pool: VariablePool, + max_execution_steps: int, + max_execution_time: int, + thread_pool_id: Optional[str] = None, ) -> None: thread_pool_max_submit_count = 100 thread_pool_max_workers = 10 - ## init thread pool + # init thread pool if thread_pool_id: - if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + if thread_pool_id not in GraphEngine.workflow_thread_pool_mapping: raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.") - + self.thread_pool_id = thread_pool_id self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] self.is_main_thread_pool = False else: - self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count) + self.thread_pool = GraphEngineThreadPool( + max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count + ) self.thread_pool_id = str(uuid.uuid4()) self.is_main_thread_pool = True GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool @@ -113,13 +116,10 @@ class GraphEngine: user_id=user_id, user_from=user_from, invoke_from=invoke_from, - call_depth=call_depth + call_depth=call_depth, ) - self.graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter() - ) + self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.max_execution_steps = max_execution_steps self.max_execution_time = max_execution_time @@ -136,37 +136,40 @@ class GraphEngine: stream_processor_cls = EndStreamProcessor stream_processor = stream_processor_cls( - graph=self.graph, - variable_pool=self.graph_runtime_state.variable_pool + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool ) # run graph - generator = stream_processor.process( - self._run(start_node_id=self.graph.root_node_id) - ) + generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) for item in generator: try: yield item if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.') + yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") return elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: - self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs - if item.route_node_state.node_run_result - and item.route_node_state.node_run_result.outputs - else {}) + self.graph_runtime_state.outputs = ( + item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {} + ) elif item.node_type == NodeType.ANSWER: if "answer" not in self.graph_runtime_state.outputs: self.graph_runtime_state.outputs["answer"] = "" - self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "") - if item.route_node_state.node_run_result - and item.route_node_state.node_run_result.outputs - else "") - - self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip() + self.graph_runtime_state.outputs["answer"] += "\n" + ( + item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "" + ) + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs[ + "answer" + ].strip() except Exception as e: logger.exception(f"Graph run failed: {str(e)}") yield GraphRunFailedEvent(error=str(e)) @@ -186,12 +189,12 @@ class GraphEngine: del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] def _run( - self, - start_node_id: str, - in_parallel_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None - ) -> Generator[GraphEngineEvent, None, None]: + self, + start_node_id: str, + in_parallel_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None if in_parallel_id: parallel_start_node_id = start_node_id @@ -201,31 +204,28 @@ class GraphEngine: while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: - raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps)) + raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) # or max execution time reached if self._is_timed_out( - start_at=self.graph_runtime_state.start_at, - max_execution_time=self.max_execution_time + start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time ): - raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) + raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) # init route node state - route_node_state = self.graph_runtime_state.node_run_state.create_node_state( - node_id=next_node_id - ) + route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) # get node config node_id = route_node_state.node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} config not found.') + raise GraphRunFailedError(f"Node {node_id} config not found.") # convert to specific node - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) if not node_cls: - raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') + raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.") previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None @@ -237,7 +237,7 @@ class GraphEngine: graph=self.graph, graph_runtime_state=self.graph_runtime_state, previous_node_id=previous_node_id, - thread_pool_id=self.thread_pool_id + thread_pool_id=self.thread_pool_id, ) try: @@ -248,7 +248,7 @@ class GraphEngine: parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) for item in generator: @@ -263,8 +263,7 @@ class GraphEngine: # append route if previous_route_node_state: self.graph_runtime_state.node_run_state.add_route( - source_node_state_id=previous_route_node_state.id, - target_node_state_id=route_node_state.id + source_node_state_id=previous_route_node_state.id, target_node_state_id=route_node_state.id ) except Exception as e: route_node_state.status = RouteNodeState.Status.FAILED @@ -279,13 +278,15 @@ class GraphEngine: parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) raise e # It may not be necessary, but it is necessary. :) - if (self.graph.node_id_config_mapping[next_node_id] - .get("data", {}).get("type", "").lower() == NodeType.END.value): + if ( + self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() + == NodeType.END.value + ): break previous_route_node_state = route_node_state @@ -305,7 +306,7 @@ class GraphEngine: run_condition=edge.run_condition, ).check( graph_runtime_state=self.graph_runtime_state, - previous_route_node_state=previous_route_node_state + previous_route_node_state=previous_route_node_state, ) if not result: @@ -343,14 +344,14 @@ class GraphEngine: if not result: continue - + if len(sub_edge_mappings) == 1: final_node_id = edge.target_node_id else: parallel_generator = self._run_parallel_branches( edge_mappings=sub_edge_mappings, in_parallel_id=in_parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, ) for item in parallel_generator: @@ -369,7 +370,7 @@ class GraphEngine: parallel_generator = self._run_parallel_branches( edge_mappings=edge_mappings, in_parallel_id=in_parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, ) for item in parallel_generator: @@ -383,14 +384,14 @@ class GraphEngine: next_node_id = final_node_id - if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, "") != in_parallel_id: break def _run_parallel_branches( - self, - edge_mappings: list[GraphEdge], - in_parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, + self, + edge_mappings: list[GraphEdge], + in_parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) @@ -398,14 +399,18 @@ class GraphEngine: node_id = edge_mappings[0].target_node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.') + raise GraphRunFailedError( + f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches." + ) - node_title = node_config.get('data', {}).get('title') - raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.') + node_title = node_config.get("data", {}).get("title") + raise GraphRunFailedError( + f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches." + ) parallel = self.graph.parallel_mapping.get(parallel_id) if not parallel: - raise GraphRunFailedError(f'Parallel {parallel_id} not found.') + raise GraphRunFailedError(f"Parallel {parallel_id} not found.") # run parallel nodes, run in new thread and use queue to get results q: queue.Queue = queue.Queue() @@ -417,19 +422,22 @@ class GraphEngine: for edge in edge_mappings: if ( edge.target_node_id not in self.graph.node_parallel_mapping - or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id + or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id ): continue futures.append( - self.thread_pool.submit(self._run_parallel_node, **{ - 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] - 'q': q, - 'parallel_id': parallel_id, - 'parallel_start_node_id': edge.target_node_id, - 'parent_parallel_id': in_parallel_id, - 'parent_parallel_start_node_id': parallel_start_node_id, - }) + self.thread_pool.submit( + self._run_parallel_node, + **{ + "flask_app": current_app._get_current_object(), # type: ignore[attr-defined] + "q": q, + "parallel_id": parallel_id, + "parallel_start_node_id": edge.target_node_id, + "parent_parallel_id": in_parallel_id, + "parent_parallel_start_node_id": parallel_start_node_id, + }, + ) ) succeeded_count = 0 @@ -451,7 +459,7 @@ class GraphEngine: raise GraphRunFailedError(event.error) except queue.Empty: continue - + # wait all threads wait(futures) @@ -461,72 +469,80 @@ class GraphEngine: yield final_node_id def _run_parallel_node( - self, - flask_app: Flask, - q: queue.Queue, - parallel_id: str, - parallel_start_node_id: str, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + self, + flask_app: Flask, + q: queue.Queue, + parallel_id: str, + parallel_start_node_id: str, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, ) -> None: """ Run parallel nodes """ with flask_app.app_context(): try: - q.put(ParallelBranchRunStartedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id - )) + q.put( + ParallelBranchRunStartedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) # run node generator = self._run( start_node_id=parallel_start_node_id, in_parallel_id=parallel_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) for item in generator: q.put(item) # trigger graph run success event - q.put(ParallelBranchRunSucceededEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id - )) + q.put( + ParallelBranchRunSucceededEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) except GraphRunFailedError as e: - q.put(ParallelBranchRunFailedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - error=e.error - )) + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=e.error, + ) + ) except Exception as e: logger.exception("Unknown Error when generating in parallel") - q.put(ParallelBranchRunFailedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - error=str(e) - )) + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=str(e), + ) + ) finally: db.session.remove() def _run_node( - self, - node_instance: BaseNode, - route_node_state: RouteNodeState, - parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, ) -> Generator[GraphEngineEvent, None, None]: """ Run node @@ -542,7 +558,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) db.session.close() @@ -567,7 +583,7 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: yield NodeRunFailedEvent( - error=route_node_state.failed_reason or 'Unknown error.', + error=route_node_state.failed_reason or "Unknown error.", id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, @@ -576,7 +592,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): @@ -596,7 +612,7 @@ class GraphEngine: self._append_variables_recursively( node_id=node_instance.node_id, variable_key_list=[variable_key], - variable_value=variable_value + variable_value=variable_value, ) # add parallel info to run result metadata @@ -608,7 +624,9 @@ class GraphEngine: run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id if parent_parallel_id and parent_parallel_start_node_id: run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) yield NodeRunSucceededEvent( id=node_instance.id, @@ -619,7 +637,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) break @@ -635,7 +653,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) elif isinstance(item, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( @@ -649,9 +667,9 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: # trigger node run failed event route_node_state.status = RouteNodeState.Status.FAILED route_node_state.failed_reason = "Workflow stopped." @@ -665,7 +683,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) return except Exception as e: @@ -674,10 +692,7 @@ class GraphEngine: finally: db.session.close() - def _append_variables_recursively(self, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): + def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ Append variables recursively :param node_id: node id @@ -685,10 +700,7 @@ class GraphEngine: :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add( - [node_id] + variable_key_list, - variable_value - ) + self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) # if variable_value is a dict, then recursively append variables if isinstance(variable_value, dict): @@ -696,9 +708,7 @@ class GraphEngine: # construct new key list new_key_list = variable_key_list + [key] self._append_variables_recursively( - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value + node_id=node_id, variable_key_list=new_key_list, variable_value=value ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 8cf01727ec..deacbbbbb0 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -29,14 +29,12 @@ class AnswerNode(BaseNode): # generate routes generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) - answer = '' + answer = "" for part in generate_routes: if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = self.graph_runtime_state.variable_pool.get( - value_selector - ) + value = self.graph_runtime_state.variable_pool.get(value_selector) if value: answer += value.markdown @@ -44,19 +42,11 @@ class AnswerNode(BaseNode): part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "answer": answer - } - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer}) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -73,6 +63,6 @@ class AnswerNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 6cb80091c9..5e6de8fb15 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -1,4 +1,3 @@ - from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.answer.entities import ( @@ -12,12 +11,12 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser class AnswerStreamGeneratorRouter: - @classmethod - def init(cls, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] - ) -> AnswerStreamGenerateRoute: + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + ) -> AnswerStreamGenerateRoute: """ Get stream generate routes. :return: @@ -25,7 +24,7 @@ class AnswerStreamGeneratorRouter: # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} for answer_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value: + if node_config.get("data", {}).get("type") != NodeType.ANSWER.value: continue # get generate route for stream output @@ -37,12 +36,11 @@ class AnswerStreamGeneratorRouter: answer_dependencies = cls._fetch_answers_dependencies( answer_node_ids=answer_node_ids, reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping + node_id_config_mapping=node_id_config_mapping, ) return AnswerStreamGenerateRoute( - answer_generate_route=answer_generate_route, - answer_dependencies=answer_dependencies + answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies ) @classmethod @@ -56,8 +54,7 @@ class AnswerStreamGeneratorRouter: variable_selectors = variable_template_parser.extract_variable_selectors() value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors + variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors } variable_keys = list(value_selector_mapping.keys()) @@ -71,21 +68,17 @@ class AnswerStreamGeneratorRouter: template = node_data.answer for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω") generate_routes: list[GenerateRouteChunk] = [] - for part in template.split('Ω'): + for part in template.split("Ω"): if part: if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "") value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) + generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector)) else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) + generate_routes.append(TextGenerateRouteChunk(text=part)) return generate_routes @@ -101,15 +94,16 @@ class AnswerStreamGeneratorRouter: @classmethod def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys + cleaned_part = part.replace("{{", "").replace("}}", "") + return part.startswith("{{") and cleaned_part in variable_keys @classmethod - def _fetch_answers_dependencies(cls, - answer_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict] - ) -> dict[str, list[str]]: + def _fetch_answers_dependencies( + cls, + answer_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: """ Fetch answer dependencies :param answer_node_ids: answer node ids @@ -127,19 +121,20 @@ class AnswerStreamGeneratorRouter: answer_node_id=answer_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies + answer_dependencies=answer_dependencies, ) return answer_dependencies @classmethod - def _recursive_fetch_answer_dependencies(cls, - current_node_id: str, - answer_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - answer_dependencies: dict[str, list[str]] - ) -> None: + def _recursive_fetch_answer_dependencies( + cls, + current_node_id: str, + answer_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + answer_dependencies: dict[str, list[str]], + ) -> None: """ Recursive fetch answer dependencies :param current_node_id: current node id @@ -152,12 +147,12 @@ class AnswerStreamGeneratorRouter: reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id - source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') - if source_node_type in ( - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER, - ): + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") + if source_node_type in { + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER.value, + }: answer_dependencies[answer_node_id].append(source_node_id) else: cls._recursive_fetch_answer_dependencies( @@ -165,5 +160,5 @@ class AnswerStreamGeneratorRouter: answer_node_id=answer_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies + answer_dependencies=answer_dependencies, ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index c2a5dd5163..32dbf436ec 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -18,7 +18,6 @@ logger = logging.getLogger(__name__) class AnswerStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: super().__init__(graph, variable_pool) self.generate_routes = graph.answer_stream_generate_routes @@ -27,9 +26,7 @@ class AnswerStreamProcessor(StreamProcessor): self.route_position[answer_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: if isinstance(event, NodeRunStartedEvent): if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: @@ -47,9 +44,9 @@ class AnswerStreamProcessor(StreamProcessor): ] else: stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event) - self.current_stream_chunk_generating_node_ids[ - event.route_node_state.node_id - ] = stream_out_answer_node_ids + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_answer_node_ids + ) for _ in stream_out_answer_node_ids: yield event @@ -77,9 +74,9 @@ class AnswerStreamProcessor(StreamProcessor): self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} - def _generate_stream_outputs_when_node_finished(self, - event: NodeRunSucceededEvent - ) -> Generator[GraphEngineEvent, None, None]: + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: """ Generate stream outputs. :param event: node run succeeded event @@ -87,10 +84,13 @@ class AnswerStreamProcessor(StreamProcessor): """ for answer_node_id, position in self.route_position.items(): # all depends on answer node id not in rest node ids - if (event.route_node_state.node_id != answer_node_id - and (answer_node_id not in self.rest_node_ids - or not all(dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))): + if event.route_node_state.node_id != answer_node_id and ( + answer_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ) + ): continue route_position = self.route_position[answer_node_id] @@ -108,6 +108,7 @@ class AnswerStreamProcessor(StreamProcessor): route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + from_variable_selector=[answer_node_id, "answer"], ) else: route_chunk = cast(VarGenerateRouteChunk, route_chunk) @@ -115,9 +116,7 @@ class AnswerStreamProcessor(StreamProcessor): if not value_selector: break - value = self.variable_pool.get( - value_selector - ) + value = self.variable_pool.get(value_selector) if value is None: break @@ -158,8 +157,9 @@ class AnswerStreamProcessor(StreamProcessor): continue # all depends on answer node id not in rest node ids - if all(dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id]): + if all( + dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ): if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): continue @@ -213,7 +213,7 @@ class AnswerStreamProcessor(StreamProcessor): return None if isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: + if "__variant" in value and value["__variant"] == FileVar.__name__: return value elif isinstance(value, FileVar): return value.to_dict() diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index cbabbca37d..36c3fe180a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -7,16 +7,13 @@ from core.workflow.graph_engine.entities.graph import Graph class StreamProcessor(ABC): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.graph = graph self.variable_pool = variable_pool self.rest_node_ids = graph.node_ids.copy() @abstractmethod - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: raise NotImplementedError def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: @@ -35,9 +32,11 @@ class StreamProcessor(ABC): reachable_node_ids = [] unreachable_first_node_ids = [] for edge in self.graph.edge_mapping[finished_node_id]: - if (edge.run_condition - and edge.run_condition.branch_identify - and run_result.edge_source_handle == edge.run_condition.branch_identify): + if ( + edge.run_condition + and edge.run_condition.branch_identify + and run_result.edge_source_handle == edge.run_condition.branch_identify + ): reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) continue else: diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 620c2c426b..e356e7fd70 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -9,6 +9,7 @@ class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ + answer: str = Field(..., description="answer template string") @@ -28,6 +29,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR """generate route chunk type""" value_selector: list[str] = Field(..., description="value selector") @@ -37,6 +39,7 @@ class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. """ + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT """generate route chunk type""" text: str = Field(..., description="text") @@ -52,11 +55,10 @@ class AnswerStreamGenerateRoute(BaseModel): """ AnswerStreamGenerateRoute entity """ + answer_dependencies: dict[str, list[str]] = Field( - ..., - description="answer dependencies (answer node id -> dependent answer node ids)" + ..., description="answer dependencies (answer node id -> dependent answer node ids)" ) answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., - description="answer generate route (answer node id -> generate route chunks)" + ..., description="answer generate route (answer node id -> generate route chunks)" ) diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index b9912314f1..7bfe45a13c 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -15,14 +15,16 @@ class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType - def __init__(self, - id: str, - config: Mapping[str, Any], - graph_init_params: GraphInitParams, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None) -> None: + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: GraphInitParams, + graph: Graph, + graph_runtime_state: GraphRuntimeState, + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + ) -> None: self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id @@ -46,8 +48,7 @@ class BaseNode(ABC): self.node_data = self._node_data_cls(**config.get("data", {})) @abstractmethod - def _run(self) \ - -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: """ Run node :return: @@ -62,14 +63,14 @@ class BaseNode(ABC): result = self._run() if isinstance(result, NodeRunResult): - yield RunCompletedEvent( - run_result=result - ) + yield RunCompletedEvent(run_result=result) else: yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]: + def extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], config: dict + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping :param graph_config: graph config @@ -82,17 +83,12 @@ class BaseNode(ABC): node_data = cls._node_data_cls(**config.get("data", {})) return cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - node_id=node_id, - node_data=node_data + graph_config=graph_config, node_id=node_id, node_data=node_data ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: BaseNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 955afdfa1d..9da7ad99f3 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, Union, cast from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider @@ -25,11 +25,10 @@ class CodeNode(BaseNode): """ code_language = CodeLanguage.PYTHON3 if filters: - code_language = (filters.get("code_language", CodeLanguage.PYTHON3)) + code_language = filters.get("code_language", CodeLanguage.PYTHON3) providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers - if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) return code_provider.get_default_config() @@ -62,18 +61,10 @@ class CodeNode(BaseNode): # Transform result result = self._transform_result(result, node_data.outputs) - except (CodeExecutionException, ValueError) as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + except (CodeExecutionError, ValueError) as e: + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs=result - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) def _check_string(self, value: str, variable: str) -> str: """ @@ -83,16 +74,18 @@ class CodeNode(BaseNode): :return: """ if not isinstance(value, str): - if isinstance(value, type(None)): + if value is None: return None else: raise ValueError(f"Output variable `{variable}` must be a string") - - if len(value) > dify_config.CODE_MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be' - f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters') - return value.replace('\x00', '') + if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + raise ValueError( + f"The length of output variable `{variable}` must be" + f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + ) + + return value.replace("\x00", "") def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ @@ -102,26 +95,30 @@ class CodeNode(BaseNode): :return: """ if not isinstance(value, int | float): - if isinstance(value, type(None)): + if value is None: return None else: raise ValueError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range,' - f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.') + raise ValueError( + f"Output variable `{variable}` is out of range," + f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + ) if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision,' - f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.') + if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + raise ValueError( + f"Output variable `{variable}` has too high precision," + f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." + ) return value - def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], - prefix: str = '', - depth: int = 1) -> dict: + def _transform_result( + self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 + ) -> dict: """ Transform result :param result: result @@ -139,183 +136,187 @@ class CodeNode(BaseNode): self._transform_result( result=output_value, output_schema=None, - prefix=f'{prefix}.{output_name}' if prefix else output_name, - depth=depth + 1 + prefix=f"{prefix}.{output_name}" if prefix else output_name, + depth=depth + 1, ) elif isinstance(output_value, int | float): self._check_number( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, str): self._check_string( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, list): first_element = output_value[0] if len(output_value) > 0 else None if first_element is not None: - if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value): + if isinstance(first_element, int | float) and all( + value is None or isinstance(value, int | float) for value in output_value + ): for i, value in enumerate(output_value): self._check_number( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value): + elif isinstance(first_element, str) and all( + value is None or isinstance(value, str) for value in output_value + ): for i, value in enumerate(output_value): self._check_string( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value): + elif isinstance(first_element, dict) and all( + value is None or isinstance(value, dict) for value in output_value + ): for i, value in enumerate(output_value): if value is not None: self._transform_result( result=value, output_schema=None, - prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", + depth=depth + 1, ) else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') - elif isinstance(output_value, type(None)): + raise ValueError( + f"Output {prefix}.{output_name} is not a valid array." + f" make sure all elements are of the same type." + ) + elif output_value is None: pass else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') - + raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") + return result parameters_validated = {} for output_name, output_config in output_schema.items(): - dot = '.' if prefix else '' + dot = "." if prefix else "" if output_name not in result: - raise ValueError(f'Output {prefix}{dot}{output_name} is missing.') - - if output_config.type == 'object': + raise ValueError(f"Output {prefix}{dot}{output_name} is missing.") + + if output_config.type == "object": # check if output is object if not isinstance(result.get(output_name), dict): if isinstance(result.get(output_name), type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an object," + f" got {type(result.get(output_name))} instead." ) else: transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, - prefix=f'{prefix}.{output_name}', - depth=depth + 1 + prefix=f"{prefix}.{output_name}", + depth=depth + 1, ) - elif output_config.type == 'number': + elif output_config.type == "number": # check if number available transformed_result[output_name] = self._check_number( - value=result[output_name], - variable=f'{prefix}{dot}{output_name}' + value=result[output_name], variable=f"{prefix}{dot}{output_name}" ) - elif output_config.type == 'string': + elif output_config.type == "string": # check if string available transformed_result[output_name] = self._check_string( value=result[output_name], - variable=f'{prefix}{dot}{output_name}', + variable=f"{prefix}{dot}{output_name}", ) - elif output_config.type == 'array[number]': + elif output_config.type == "array[number]": # check if array of number available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array," + f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_number( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[string]': + elif output_config.type == "array[string]": # check if array of string available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array," + f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_string( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[object]': + elif output_config.type == "array[object]": # check if array of object available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array," + f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." ) - + for i, value in enumerate(result[output_name]): if not isinstance(value, dict): - if isinstance(value, type(None)): + if value is None: pass else: raise ValueError( - f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.' + f"Output {prefix}{dot}{output_name}[{i}] is not an object," + f" got {type(value)} instead at index {i}." ) transformed_result[output_name] = [ - None if value is None else self._transform_result( + None + if value is None + else self._transform_result( result=value, output_schema=output_config.children, - prefix=f'{prefix}{dot}{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}{dot}{output_name}[{i}]", + depth=depth + 1, ) for i, value in enumerate(result[output_name]) ] else: - raise ValueError(f'Output type {output_config.type} is not supported.') - + raise ValueError(f"Output type {output_config.type} is not supported.") + parameters_validated[output_name] = True # check if all output parameters are validated if len(parameters_validated) != len(result): - raise ValueError('Not all output parameters are validated.') + raise ValueError("Not all output parameters are validated.") return transformed_result @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -325,5 +326,6 @@ class CodeNode(BaseNode): :return: """ return { - node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index c0701ecccd..5eb0e0f63f 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -11,9 +11,10 @@ class CodeNodeData(BaseNodeData): """ Code Node Data. """ + class Output(BaseModel): - type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] - children: Optional[dict[str, 'Output']] = None + type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + children: Optional[dict[str, "Output"]] = None class Dependency(BaseModel): name: str @@ -23,4 +24,4 @@ class CodeNodeData(BaseNodeData): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None \ No newline at end of file + dependencies: Optional[list[Dependency]] = None diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 552914b308..7b78d67be8 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -25,18 +25,11 @@ class EndNode(BaseNode): value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) outputs[variable_selector.variable] = value - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: EndNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 8390f6d81b..9a7d2ecde3 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -3,13 +3,13 @@ from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam class EndStreamGeneratorRouter: - @classmethod - def init(cls, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_parallel_mapping: dict[str, str] - ) -> EndStreamParam: + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str], + ) -> EndStreamParam: """ Get stream generate routes. :return: @@ -17,7 +17,7 @@ class EndStreamGeneratorRouter: # parse stream output node value selector of end nodes end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} for end_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get('data', {}).get('type') == NodeType.END.value: + if node_config.get("data", {}).get("type") != NodeType.END.value: continue # skip end node in parallel @@ -33,18 +33,18 @@ class EndStreamGeneratorRouter: end_dependencies = cls._fetch_ends_dependencies( end_node_ids=end_node_ids, reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping + node_id_config_mapping=node_id_config_mapping, ) return EndStreamParam( end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) @classmethod - def extract_stream_variable_selector_from_node_data(cls, - node_id_config_mapping: dict[str, dict], - node_data: EndNodeData) -> list[list[str]]: + def extract_stream_variable_selector_from_node_data( + cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData + ) -> list[list[str]]: """ Extract stream variable selector from node data :param node_id_config_mapping: node id config mapping @@ -59,21 +59,22 @@ class EndStreamGeneratorRouter: continue node_id = variable_selector.value_selector[0] - if node_id != 'sys' and node_id in node_id_config_mapping: + if node_id != "sys" and node_id in node_id_config_mapping: node = node_id_config_mapping[node_id] - node_type = node.get('data', {}).get('type') + node_type = node.get("data", {}).get("type") if ( variable_selector.value_selector not in value_selectors - and node_type == NodeType.LLM.value - and variable_selector.value_selector[1] == 'text' + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == "text" ): value_selectors.append(variable_selector.value_selector) return value_selectors @classmethod - def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \ - -> list[list[str]]: + def _extract_stream_variable_selector( + cls, node_id_config_mapping: dict[str, dict], config: dict + ) -> list[list[str]]: """ Extract stream variable selector from node config :param node_id_config_mapping: node id config mapping @@ -84,11 +85,12 @@ class EndStreamGeneratorRouter: return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) @classmethod - def _fetch_ends_dependencies(cls, - end_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict] - ) -> dict[str, list[str]]: + def _fetch_ends_dependencies( + cls, + end_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: """ Fetch end dependencies :param end_node_ids: end node ids @@ -106,20 +108,21 @@ class EndStreamGeneratorRouter: end_node_id=end_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) return end_dependencies @classmethod - def _recursive_fetch_end_dependencies(cls, - current_node_id: str, - end_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], - # type: ignore[name-defined] - end_dependencies: dict[str, list[str]] - ) -> None: + def _recursive_fetch_end_dependencies( + cls, + current_node_id: str, + end_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], + # type: ignore[name-defined] + end_dependencies: dict[str, list[str]], + ) -> None: """ Recursive fetch end dependencies :param current_node_id: current node id @@ -132,11 +135,11 @@ class EndStreamGeneratorRouter: reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id - source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') - if source_node_type in ( - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER, - ): + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") + if source_node_type in { + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, + }: end_dependencies[end_node_id].append(source_node_id) else: cls._recursive_fetch_end_dependencies( @@ -144,5 +147,5 @@ class EndStreamGeneratorRouter: end_node_id=end_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 4474c2a78a..0366d7965d 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -15,7 +15,6 @@ logger = logging.getLogger(__name__) class EndStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: super().__init__(graph, variable_pool) self.end_stream_param = graph.end_stream_param @@ -26,9 +25,7 @@ class EndStreamProcessor(StreamProcessor): self.has_outputed = False self.outputed_node_ids = set() - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: if isinstance(event, NodeRunStartedEvent): if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: @@ -38,7 +35,7 @@ class EndStreamProcessor(StreamProcessor): elif isinstance(event, NodeRunStreamChunkEvent): if event.in_iteration_id: if self.has_outputed and event.node_id not in self.outputed_node_ids: - event.chunk_content = '\n' + event.chunk_content + event.chunk_content = "\n" + event.chunk_content self.outputed_node_ids.add(event.node_id) self.has_outputed = True @@ -51,13 +48,13 @@ class EndStreamProcessor(StreamProcessor): ] else: stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) - self.current_stream_chunk_generating_node_ids[ - event.route_node_state.node_id - ] = stream_out_end_node_ids + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_end_node_ids + ) if stream_out_end_node_ids: if self.has_outputed and event.node_id not in self.outputed_node_ids: - event.chunk_content = '\n' + event.chunk_content + event.chunk_content = "\n" + event.chunk_content self.outputed_node_ids.add(event.node_id) self.has_outputed = True @@ -86,9 +83,9 @@ class EndStreamProcessor(StreamProcessor): self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} - def _generate_stream_outputs_when_node_finished(self, - event: NodeRunSucceededEvent - ) -> Generator[GraphEngineEvent, None, None]: + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: """ Generate stream outputs. :param event: node run succeeded event @@ -96,10 +93,12 @@ class EndStreamProcessor(StreamProcessor): """ for end_node_id, position in self.route_position.items(): # all depends on end node id not in rest node ids - if (event.route_node_state.node_id != end_node_id - and (end_node_id not in self.rest_node_ids - or not all(dep_id not in self.rest_node_ids - for dep_id in self.end_stream_param.end_dependencies[end_node_id]))): + if event.route_node_state.node_id != end_node_id and ( + end_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id] + ) + ): continue route_position = self.route_position[end_node_id] @@ -116,9 +115,7 @@ class EndStreamProcessor(StreamProcessor): if not value_selector: continue - value = self.variable_pool.get( - value_selector - ) + value = self.variable_pool.get(value_selector) if value is None: break @@ -128,7 +125,7 @@ class EndStreamProcessor(StreamProcessor): if text: current_node_id = value_selector[0] if self.has_outputed and current_node_id not in self.outputed_node_ids: - text = '\n' + text + text = "\n" + text self.outputed_node_ids.add(current_node_id) self.has_outputed = True @@ -165,8 +162,7 @@ class EndStreamProcessor(StreamProcessor): continue # all depends on end node id not in rest node ids - if all(dep_id not in self.rest_node_ids - for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]): if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): continue @@ -178,7 +174,7 @@ class EndStreamProcessor(StreamProcessor): break position += 1 - + if not value_selector: continue diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index a0edf7b579..c3270ac22a 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -8,6 +8,7 @@ class EndNodeData(BaseNodeData): """ END Node Data. """ + outputs: list[VariableSelector] @@ -15,11 +16,10 @@ class EndStreamParam(BaseModel): """ EndStreamParam entity """ + end_dependencies: dict[str, list[str]] = Field( - ..., - description="end dependencies (end node id -> dependent node ids)" + ..., description="end dependencies (end node id -> dependent node ids)" ) end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., - description="end stream variable selector mapping (end node id -> stream variable selectors)" + ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" ) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index c066d469d8..66dd1f2dc6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -7,32 +7,32 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal[None, 'basic', 'bearer', 'custom'] + type: Literal[None, "basic", "bearer", "custom"] api_key: Union[None, str] = None header: Union[None, str] = None class HttpRequestNodeAuthorization(BaseModel): - type: Literal['no-auth', 'api-key'] + type: Literal["no-auth", "api-key"] config: Optional[HttpRequestNodeAuthorizationConfig] = None - @field_validator('config', mode='before') + @field_validator("config", mode="before") @classmethod def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): """ Check config, if type is no-auth, config should be None, otherwise it should be a dict. """ - if values.data['type'] == 'no-auth': + if values.data["type"] == "no-auth": return None else: if not v or not isinstance(v, dict): - raise ValueError('config should be a dict') + raise ValueError("config should be a dict") return v class HttpRequestNodeBody(BaseModel): - type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json'] + type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"] data: Union[None, str] = None @@ -47,7 +47,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. """ - method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] + method: Literal["get", "post", "put", "patch", "delete", "head"] url: str authorization: HttpRequestNodeAuthorization headers: str diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index d16bff58bd..f8ab4e3132 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -6,8 +6,8 @@ from urllib.parse import urlencode import httpx -import core.helper.ssrf_proxy as ssrf_proxy from configs import dify_config +from core.helper import ssrf_proxy from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.http_request.entities import ( @@ -33,12 +33,12 @@ class HttpExecutorResponse: check if response is file """ content_type = self.get_content_type() - file_content_types = ['image', 'audio', 'video'] + file_content_types = ["image", "audio", "video"] return any(v in content_type for v in file_content_types) def get_content_type(self) -> str: - return self.headers.get('content-type', '') + return self.headers.get("content-type", "") def extract_file(self) -> tuple[str, bytes]: """ @@ -47,28 +47,28 @@ class HttpExecutorResponse: if self.is_file: return self.get_content_type(), self.body - return '', b'' + return "", b"" @property def content(self) -> str: if isinstance(self.response, httpx.Response): return self.response.text else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def body(self) -> bytes: if isinstance(self.response, httpx.Response): return self.response.content else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def status_code(self) -> int: if isinstance(self.response, httpx.Response): return self.response.status_code else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def size(self) -> int: @@ -77,11 +77,11 @@ class HttpExecutorResponse: @property def readable_size(self) -> str: if self.size < 1024: - return f'{self.size} bytes' + return f"{self.size} bytes" elif self.size < 1024 * 1024: - return f'{(self.size / 1024):.2f} KB' + return f"{(self.size / 1024):.2f} KB" else: - return f'{(self.size / 1024 / 1024):.2f} MB' + return f"{(self.size / 1024 / 1024):.2f} MB" class HttpExecutor: @@ -120,7 +120,7 @@ class HttpExecutor: """ check if body is json """ - if body and body.type == 'json' and body.data: + if body and body.type == "json" and body.data: try: json.loads(body.data) return True @@ -134,15 +134,15 @@ class HttpExecutor: """ Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` """ - kv_paris = convert_text.split('\n') + kv_paris = convert_text.split("\n") result = {} for kv in kv_paris: if not kv.strip(): continue - kv = kv.split(':', maxsplit=1) + kv = kv.split(":", maxsplit=1) if len(kv) == 1: - k, v = kv[0], '' + k, v = kv[0], "" else: k, v = kv result[k.strip()] = v @@ -166,31 +166,31 @@ class HttpExecutor: # check if it's a valid JSON is_valid_json = self._is_json_body(node_data.body) - body_data = node_data.body.data or '' + body_data = node_data.body.data or "" if body_data: body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - content_type_is_set = any(key.lower() == 'content-type' for key in self.headers) - if node_data.body.type == 'json' and not content_type_is_set: - self.headers['Content-Type'] = 'application/json' - elif node_data.body.type == 'x-www-form-urlencoded' and not content_type_is_set: - self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + content_type_is_set = any(key.lower() == "content-type" for key in self.headers) + if node_data.body.type == "json" and not content_type_is_set: + self.headers["Content-Type"] = "application/json" + elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: + self.headers["Content-Type"] = "application/x-www-form-urlencoded" - if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: body = self._to_dict(body_data) - if node_data.body.type == 'form-data': - self.files = {k: ('', v) for k, v in body.items()} - random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f'----WebKitFormBoundary{random_str(16)}' + if node_data.body.type == "form-data": + self.files = {k: ("", v) for k, v in body.items()} + random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)]) + self.boundary = f"----WebKitFormBoundary{random_str(16)}" - self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}' + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" else: self.body = urlencode(body) - elif node_data.body.type in ['json', 'raw-text']: + elif node_data.body.type in {"json", "raw-text"}: self.body = body_data - elif node_data.body.type == 'none': - self.body = '' + elif node_data.body.type == "none": + self.body = "" self.variable_selectors = ( server_url_variable_selectors @@ -202,23 +202,23 @@ class HttpExecutor: def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) headers = deepcopy(self.headers) or {} - if self.authorization.type == 'api-key': + if self.authorization.type == "api-key": if self.authorization.config is None: - raise ValueError('self.authorization config is required') + raise ValueError("self.authorization config is required") if authorization.config is None: - raise ValueError('authorization config is required') + raise ValueError("authorization config is required") if self.authorization.config.api_key is None: - raise ValueError('api_key is required') + raise ValueError("api_key is required") if not authorization.config.header: - authorization.config.header = 'Authorization' + authorization.config.header = "Authorization" - if self.authorization.config.type == 'bearer': - headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' - elif self.authorization.config.type == 'basic': - headers[authorization.config.header] = f'Basic {authorization.config.api_key}' - elif self.authorization.config.type == 'custom': + if self.authorization.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.authorization.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.authorization.config.type == "custom": headers[authorization.config.header] = authorization.config.api_key return headers @@ -230,10 +230,13 @@ class HttpExecutor: if isinstance(response, httpx.Response): executor_response = HttpExecutorResponse(response) else: - raise ValueError(f'Invalid response type {type(response)}') + raise ValueError(f"Invalid response type {type(response)}") - threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \ + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) if executor_response.size > threshold_size: raise ValueError( f'{"File" if executor_response.is_file else "Text"} size is too large,' @@ -248,17 +251,17 @@ class HttpExecutor: do http request depending on api bundle """ kwargs = { - 'url': self.server_url, - 'headers': headers, - 'params': self.params, - 'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write), - 'follow_redirects': True, + "url": self.server_url, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, } - if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'): + if self.method in {"get", "head", "post", "put", "delete", "patch"}: response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: - raise ValueError(f'Invalid http method {self.method}') + raise ValueError(f"Invalid http method {self.method}") return response def invoke(self) -> HttpExecutorResponse: @@ -280,15 +283,15 @@ class HttpExecutor: """ server_url = self.server_url if self.params: - server_url += f'?{urlencode(self.params)}' + server_url += f"?{urlencode(self.params)}" - raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' + raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" headers = self._assembling_headers() for k, v in headers.items(): # get authorization header - if self.authorization.type == 'api-key': - authorization_header = 'Authorization' + if self.authorization.type == "api-key": + authorization_header = "Authorization" if self.authorization.config and self.authorization.config.header: authorization_header = self.authorization.config.header @@ -296,21 +299,21 @@ class HttpExecutor: raw_request += f'{k}: {"*" * len(v)}\n' continue - raw_request += f'{k}: {v}\n' + raw_request += f"{k}: {v}\n" - raw_request += '\n' + raw_request += "\n" # if files, use multipart/form-data with boundary if self.files: boundary = self.boundary - raw_request += f'--{boundary}' + raw_request += f"--{boundary}" for k, v in self.files.items(): raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f'{v[1]}\n' - raw_request += f'--{boundary}' - raw_request += '--' + raw_request += f"{v[1]}\n" + raw_request += f"--{boundary}" + raw_request += "--" else: - raw_request += self.body or '' + raw_request += self.body or "" return raw_request @@ -328,9 +331,9 @@ class HttpExecutor: for variable_selector in variable_selectors: variable = variable_pool.get_any(variable_selector.value_selector) if variable is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace('\n', '\\n') + value = variable.replace('"', '\\"').replace("\n", "\\n") else: value = variable variable_value_mapping[variable_selector.variable] = value diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 3f68c8b1d0..cd40819126 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -31,18 +31,18 @@ class HttpRequestNode(BaseNode): @classmethod def get_default_config(cls, filters: dict | None = None) -> dict: return { - 'type': 'http-request', - 'config': { - 'method': 'get', - 'authorization': { - 'type': 'no-auth', + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", }, - 'body': {'type': 'none'}, - 'timeout': { + "body": {"type": "none"}, + "timeout": { **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - 'max_connect_timeout': dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - 'max_read_timeout': dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - 'max_write_timeout': dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, } @@ -52,9 +52,8 @@ class HttpRequestNode(BaseNode): # TODO: Switch to use segment directly if node_data.authorization.config and node_data.authorization.config.api_key: node_data.authorization.config.api_key = parser.convert_template( - template=node_data.authorization.config.api_key, - variable_pool=self.graph_runtime_state.variable_pool - ).text + template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool + ).text # init http executor http_executor = None @@ -62,7 +61,7 @@ class HttpRequestNode(BaseNode): http_executor = HttpExecutor( node_data=node_data, timeout=self._get_request_timeout(node_data), - variable_pool=self.graph_runtime_state.variable_pool + variable_pool=self.graph_runtime_state.variable_pool, ) # invoke http executor @@ -71,7 +70,7 @@ class HttpRequestNode(BaseNode): process_data = {} if http_executor: process_data = { - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), } return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -84,13 +83,13 @@ class HttpRequestNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ - 'status_code': response.status_code, - 'body': response.content if not files else '', - 'headers': response.headers, - 'files': files, + "status_code": response.status_code, + "body": response.content if not files else "", + "headers": response.headers, + "files": files, }, process_data={ - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), }, ) @@ -107,10 +106,7 @@ class HttpRequestNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -126,11 +122,11 @@ class HttpRequestNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping except Exception as e: - logging.exception(f'Failed to extract variable selector to variable mapping: {e}') + logging.exception(f"Failed to extract variable selector to variable mapping: {e}") return {} def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: @@ -144,7 +140,7 @@ class HttpRequestNode(BaseNode): # extract filename from url filename = path.basename(url) # extract extension if possible - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 338277ace1..54c1081fd3 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -15,6 +15,7 @@ class IfElseNodeData(BaseNodeData): """ Case entity representing a single logical condition group """ + case_id: str logical_operator: Literal["and", "or"] conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index ca87eecd0d..37384202d8 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -20,13 +20,9 @@ class IfElseNode(BaseNode): node_data = self.node_data node_data = cast(IfElseNodeData, node_data) - node_inputs: dict[str, list] = { - "conditions": [] - } + node_inputs: dict[str, list] = {"conditions": []} - process_datas: dict[str, list] = { - "condition_results": [] - } + process_datas: dict[str, list] = {"condition_results": []} input_conditions = [] final_result = False @@ -37,8 +33,7 @@ class IfElseNode(BaseNode): if node_data.cases: for case in node_data.cases: input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions + variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions ) # Apply the logical operator for the current case @@ -60,8 +55,7 @@ class IfElseNode(BaseNode): else: # Fallback to old structure if cases are not defined input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=node_data.conditions + variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions ) final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) @@ -69,21 +63,14 @@ class IfElseNode(BaseNode): selected_case_id = "true" if final_result else "false" process_datas["condition_results"].append( - { - "group": "default", - "results": group_result, - "final_result": final_result - } + {"group": "default", "results": group_result, "final_result": final_result} ) node_inputs["conditions"] = input_conditions except Exception as e: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=node_inputs, - process_data=process_datas, - error=str(e) + status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e) ) outputs = {"result": final_result, "selected_case_id": selected_case_id} @@ -92,18 +79,15 @@ class IfElseNode(BaseNode): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, process_data=process_datas, - edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default' - outputs=outputs + edge_source_handle=selected_case_id or "false", # Use case ID or 'default' + outputs=outputs, ) return data @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 5fc5a827ae..3c2c189159 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -7,21 +7,25 @@ class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. """ - parent_loop_id: Optional[str] = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector + + parent_loop_id: Optional[str] = None # redundant field, not used currently + iterator_selector: list[str] # variable selector + output_selector: list[str] # output selector class IterationStartNodeData(BaseNodeData): """ Iteration Start Node Data. """ + pass + class IterationState(BaseIterationState): """ Iteration State. """ + outputs: list[Any] = None current_output: Optional[Any] = None @@ -29,6 +33,7 @@ class IterationState(BaseIterationState): """ Data. """ + iterator_length: int def get_last_output(self) -> Optional[Any]: @@ -38,9 +43,9 @@ class IterationState(BaseIterationState): if self.outputs: return self.outputs[-1] return None - + def get_current_output(self) -> Optional[Any]: """ Get current output. """ - return self.current_output \ No newline at end of file + return self.current_output diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 93eff16c33..4d944e93db 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -16,6 +16,7 @@ from core.workflow.graph_engine.entities.event import ( IterationRunNextEvent, IterationRunStartedEvent, IterationRunSucceededEvent, + NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph @@ -33,6 +34,7 @@ class IterationNode(BaseNode): """ Iteration Node. """ + _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION @@ -45,31 +47,26 @@ class IterationNode(BaseNode): if not iterator_list_segment: raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") - + iterator_list_value = iterator_list_segment.to_object() if not isinstance(iterator_list_value, list): raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - inputs = { - "iterator_selector": iterator_list_value - } + inputs = {"iterator_selector": iterator_list_value} graph_config = self.graph_config - + if not self.node_data.start_node_id: - raise ValueError(f'field start_node_id in iteration {self.node_id} not found') + raise ValueError(f"field start_node_id in iteration {self.node_id} not found") root_node_id = self.node_data.start_node_id # init graph - iteration_graph = Graph.init( - graph_config=graph_config, - root_node_id=root_node_id - ) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) if not iteration_graph: - raise ValueError('iteration graph not found') + raise ValueError("iteration graph not found") leaf_node_ids = iteration_graph.get_leaf_node_ids() iteration_leaf_node_ids = [] @@ -97,26 +94,21 @@ class IterationNode(BaseNode): Condition( variable_selector=[self.node_id, "index"], comparison_operator="<", - value=str(len(iterator_list_value)) + value=str(len(iterator_list_value)), ) - ] - ) + ], + ), ) variable_pool = self.graph_runtime_state.variable_pool # append iteration variable (item, index) to variable pool - variable_pool.add( - [self.node_id, 'index'], - 0 - ) - variable_pool.add( - [self.node_id, 'item'], - iterator_list_value[0] - ) + variable_pool.add([self.node_id, "index"], 0) + variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine from core.workflow.graph_engine.graph_engine import GraphEngine + graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -130,7 +122,7 @@ class IterationNode(BaseNode): graph_config=graph_config, variable_pool=variable_pool, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, ) start_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -142,10 +134,8 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - metadata={ - "iterator_length": len(iterator_list_value) - }, - predecessor_node_id=self.previous_node_id + metadata={"iterator_length": len(iterator_list_value)}, + predecessor_node_id=self.previous_node_id, ) yield IterationRunNextEvent( @@ -154,7 +144,7 @@ class IterationNode(BaseNode): iteration_node_type=self.node_type, iteration_node_data=self.node_data, index=0, - pre_iteration_output=None + pre_iteration_output=None, ) outputs: list[Any] = [] @@ -165,7 +155,11 @@ class IterationNode(BaseNode): if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: event.in_iteration_id = self.node_id - if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START: + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): continue if isinstance(event, NodeRunSucceededEvent): @@ -176,7 +170,9 @@ class IterationNode(BaseNode): if NodeRunMetadataKey.ITERATION_ID not in metadata: metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index']) + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( + [self.node_id, "index"] + ) event.route_node_state.node_run_result.metadata = metadata yield event @@ -192,21 +188,15 @@ class IterationNode(BaseNode): variable_pool.remove_node(node_id) # move to next iteration - current_index = variable_pool.get([self.node_id, 'index']) + current_index = variable_pool.get([self.node_id, "index"]) if current_index is None: - raise ValueError(f'iteration {self.node_id} current index not found') + raise ValueError(f"iteration {self.node_id} current index not found") next_index = int(current_index.to_object()) + 1 - variable_pool.add( - [self.node_id, 'index'], - next_index - ) + variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): - variable_pool.add( - [self.node_id, 'item'], - iterator_list_value[next_index] - ) + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) yield IterationRunNextEvent( iteration_id=self.id, @@ -214,8 +204,9 @@ class IterationNode(BaseNode): iteration_node_type=self.node_type, iteration_node_data=self.node_data, index=next_index, - pre_iteration_output=jsonable_encoder( - current_iteration_output) if current_iteration_output else None + pre_iteration_output=jsonable_encoder(current_iteration_output) + if current_iteration_output + else None, ) elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): @@ -227,13 +218,9 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - }, + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, ) @@ -255,21 +242,14 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - } + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, ) yield RunCompletedEvent( run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'output': jsonable_encoder(outputs) - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)} ) ) except Exception as e: @@ -282,16 +262,11 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - }, + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=str(e), ) - yield RunCompletedEvent( run_result=NodeRunResult( @@ -301,15 +276,12 @@ class IterationNode(BaseNode): ) finally: # remove iteration variable (item, index) from variable pool after iteration run completed - variable_pool.remove([self.node_id, 'index']) - variable_pool.remove([self.node_id, 'item']) - + variable_pool.remove([self.node_id, "index"]) + variable_pool.remove([self.node_id, "item"]) + @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -319,36 +291,33 @@ class IterationNode(BaseNode): :return: """ variable_mapping = { - f'{node_id}.input_selector': node_data.iterator_selector, + f"{node_id}.input_selector": node_data.iterator_selector, } # init graph - iteration_graph = Graph.init( - graph_config=graph_config, - root_node_id=node_data.start_node_id - ) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) if not iteration_graph: - raise ValueError('iteration graph not found') - + raise ValueError("iteration graph not found") + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): - if sub_node_config.get('data', {}).get('iteration_id') != node_id: + if sub_node_config.get("data", {}).get("iteration_id") != node_id: continue # variable selector to variable mapping try: # Get node class from core.workflow.nodes.node_mapping import node_classes - node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type')) + + node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) if not node_cls: continue node_cls = cast(BaseNode, node_cls) - + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - config=sub_node_config + graph_config=graph_config, config=sub_node_config ) sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) except NotImplementedError: @@ -356,7 +325,8 @@ class IterationNode(BaseNode): # remove iteration variables sub_node_variable_mapping = { - sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items() + sub_node_id + "." + key: value + for key, value in sub_node_variable_mapping.items() if value[0] != node_id } @@ -364,8 +334,7 @@ class IterationNode(BaseNode): # remove variable out from iteration variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_graph.node_ids + key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids } - + return variable_mapping diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 25044cf3eb..88b9665ac6 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -11,6 +11,7 @@ class IterationStartNode(BaseNode): """ Iteration Start Node. """ + _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START @@ -18,16 +19,11 @@ class IterationStartNode(BaseNode): """ Run the node. """ - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED - ) - + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) + @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 7cf392277c..1cd88039b1 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -9,6 +9,7 @@ class RerankingModelConfig(BaseModel): """ Reranking Model Config. """ + provider: str model: str @@ -17,6 +18,7 @@ class VectorSetting(BaseModel): """ Vector Setting. """ + vector_weight: float embedding_provider_name: str embedding_model_name: str @@ -26,6 +28,7 @@ class KeywordSetting(BaseModel): """ Keyword Setting. """ + keyword_weight: float @@ -33,6 +36,7 @@ class WeightedScoreConfig(BaseModel): """ Weighted score Config. """ + vector_setting: VectorSetting keyword_setting: KeywordSetting @@ -41,17 +45,20 @@ class MultipleRetrievalConfig(BaseModel): """ Multiple Retrieval Config. """ + top_k: int score_threshold: Optional[float] = None - reranking_mode: str = 'reranking_model' + reranking_mode: str = "reranking_model" reranking_enable: bool = True reranking_model: Optional[RerankingModelConfig] = None weights: Optional[WeightedScoreConfig] = None + class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -62,6 +69,7 @@ class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. """ + model: ModelConfig @@ -69,9 +77,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - type: str = 'knowledge-retrieval' + + type: str = "knowledge-retrieval" query_variable_selector: list[str] dataset_ids: list[str] - retrieval_mode: Literal['single', 'multiple'] + retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2d1ac4731c..af55688a52 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,14 +24,11 @@ from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -45,62 +42,47 @@ class KnowledgeRetrievalNode(BaseNode): # extract variables variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) query = variable - variables = { - 'query': query - } + variables = {"query": query} if not query: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Query is required." + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." ) # retrieve knowledge try: - results = self._fetch_dataset_retriever( - node_data=node_data, query=query - ) - outputs = { - 'result': results - } + results = self._fetch_dataset_retriever(node_data=node_data, query=query) + outputs = {"result": results} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=None, - outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs ) except Exception as e: logger.exception("Error when running knowledge retrieval node") - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ - dict[str, Any]]: + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] dataset_ids = node_data.dataset_ids # Subquery: Count the number of available documents for each dataset - subquery = db.session.query( - Document.dataset_id, - func.count(Document.id).label('available_document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.dataset_id.in_(dataset_ids) - ).group_by(Document.dataset_id).having( - func.count(Document.id) > 0 - ).subquery() + subquery = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.dataset_id.in_(dataset_ids), + ) + .group_by(Document.dataset_id) + .having(func.count(Document.id) > 0) + .subquery() + ) - results = db.session.query(Dataset).join( - subquery, Dataset.id == subquery.c.dataset_id - ).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id.in_(dataset_ids) - ).all() + results = ( + db.session.query(Dataset) + .join(subquery, Dataset.id == subquery.c.dataset_id) + .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .all() + ) for dataset in results: # pass if dataset is not available @@ -117,16 +99,14 @@ class KnowledgeRetrievalNode(BaseNode): model_type_instance = cast(LargeLanguageModel, model_type_instance) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if model_schema: planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER all_documents = dataset_retrieval.single_retrieve( available_datasets=available_datasets, @@ -137,111 +117,108 @@ class KnowledgeRetrievalNode(BaseNode): query=query, model_config=model_config, model_instance=model_instance, - planning_strategy=planning_strategy + planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: - if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model': + if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": reranking_model = { - 'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider, - 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, } weights = None - elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score': + elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": reranking_model = None + vector_setting = node_data.multiple_retrieval_config.weights.vector_setting weights = { - 'vector_setting': { - "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight, - "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name, - "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name, + "vector_setting": { + "vector_weight": vector_setting.vector_weight, + "embedding_provider_name": vector_setting.embedding_provider_name, + "embedding_model_name": vector_setting.embedding_model_name, }, - 'keyword_setting': { + "keyword_setting": { "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - } + }, } else: reranking_model = None weights = None - all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, - self.user_from.value, - available_datasets, query, - node_data.multiple_retrieval_config.top_k, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_mode, - reranking_model, - weights, - node_data.multiple_retrieval_config.reranking_enable, - ) + all_documents = dataset_retrieval.multiple_retrieve( + self.app_id, + self.tenant_id, + self.user_id, + self.user_from.value, + available_datasets, + query, + node_data.multiple_retrieval_config.top_k, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.reranking_mode, + reranking_model, + weights, + node_data.multiple_retrieval_config.reranking_enable, + ) context_list = [] if all_documents: document_score_list = {} page_number_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] - # both 'page' and 'score' are metadata fields - if item.metadata.get('page'): - page_number_list[item.metadata['doc_id']] = item.metadata['page'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() resource_number = 1 if dataset and document: source = { - 'metadata': { - '_source': 'knowledge', - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'document_data_source_type': document.data_source_type, - 'page': page_number_list.get(segment.index_node_id, None), - 'segment_id': segment.id, - 'retriever_from': 'workflow', - 'score': document_score_list.get(segment.index_node_id, None), - 'segment_hit_count': segment.hit_count, - 'segment_word_count': segment.word_count, - 'segment_position': segment.position, - 'segment_index_node_hash': segment.index_node_hash, + "metadata": { + "_source": "knowledge", + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "document_data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": "workflow", + "score": document_score_list.get(segment.index_node_id, None), + "segment_hit_count": segment.hit_count, + "segment_word_count": segment.word_count, + "segment_position": segment.position, + "segment_index_node_hash": segment.index_node_hash, }, - 'title': document.name + "title": document.name, } if segment.answer: - source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: - source['content'] = segment.get_sign_content() + source["content"] = segment.get_sign_content() context_list.append(source) resource_number += 1 return context_list @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: KnowledgeRetrievalNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -251,11 +228,12 @@ class KnowledgeRetrievalNode(BaseNode): :return: """ variable_mapping = {} - variable_mapping[node_id + '.query'] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = node_data.query_variable_selector return variable_mapping - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data: KnowledgeRetrievalNodeData + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -266,10 +244,7 @@ class KnowledgeRetrievalNode(BaseNode): model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -280,8 +255,7 @@ class KnowledgeRetrievalNode(BaseNode): # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -297,19 +271,16 @@ class KnowledgeRetrievalNode(BaseNode): # model config completion_params = node_data.single_retrieval_config.model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data.single_retrieval_config.model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 1e48a10bc7..93ee0ac250 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -11,6 +11,7 @@ class ModelConfig(BaseModel): """ Model Config. """ + provider: str name: str mode: str @@ -21,6 +22,7 @@ class ContextConfig(BaseModel): """ Context Config. """ + enabled: bool variable_selector: Optional[list[str]] = None @@ -29,37 +31,47 @@ class VisionConfig(BaseModel): """ Vision Config. """ + class Configs(BaseModel): """ Configs. """ - detail: Literal['low', 'high'] + + detail: Literal["low", "high"] enabled: bool configs: Optional[Configs] = None + class PromptConfig(BaseModel): """ Prompt Config. """ + jinja2_variables: Optional[list[VariableSelector]] = None + class LLMNodeChatModelMessage(ChatModelMessage): """ LLM Node Chat Model Message. """ + jinja2_text: Optional[str] = None + class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): """ LLM Node Chat Model Prompt Template. """ + jinja2_text: Optional[str] = None + class LLMNodeData(BaseNodeData): """ LLM Node Data. """ + model: ModelConfig prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] prompt_config: Optional[PromptConfig] = None diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index f26ec1b0b5..3d336b0b0b 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -45,11 +45,11 @@ if TYPE_CHECKING: from core.file.file_obj import FileVar - class ModelInvokeCompleted(BaseModel): """ Model invoke completed """ + text: str usage: LLMUsage finish_reason: Optional[str] = None @@ -89,7 +89,7 @@ class LLMNode(BaseNode): files = self._fetch_files(node_data, variable_pool) if files: - node_inputs['#files#'] = [file.to_dict() for file in files] + node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value generator = self._fetch_context(node_data, variable_pool) @@ -100,7 +100,7 @@ class LLMNode(BaseNode): yield event if context: - node_inputs['#context#'] = context # type: ignore + node_inputs["#context#"] = context # type: ignore # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) @@ -111,24 +111,22 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) - if node_data.memory else None, + query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'model_provider': model_config.provider, - 'model_name': model_config.model, + "model_provider": model_config.provider, + "model_name": model_config.model, } # handle invoke result @@ -136,10 +134,10 @@ class LLMNode(BaseNode): node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, - stop=stop + stop=stop, ) - result_text = '' + result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: @@ -156,16 +154,12 @@ class LLMNode(BaseNode): status=WorkflowNodeExecutionStatus.FAILED, error=str(e), inputs=node_inputs, - process_data=process_data + process_data=process_data, ) ) return - outputs = { - 'text': result_text, - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} yield RunCompletedEvent( run_result=NodeRunResult( @@ -176,17 +170,19 @@ class LLMNode(BaseNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None) \ - -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Invoke large language model :param node_data_model: node data model @@ -206,9 +202,7 @@ class LLMNode(BaseNode): ) # handle invoke result - generator = self._handle_invoke_result( - invoke_result=invoke_result - ) + generator = self._handle_invoke_result(invoke_result=invoke_result) usage = LLMUsage.empty_usage() for event in generator: @@ -219,8 +213,9 @@ class LLMNode(BaseNode): # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ - -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + def _handle_invoke_result( + self, invoke_result: LLMResult | Generator + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Handle invoke result :param invoke_result: invoke result @@ -231,17 +226,14 @@ class LLMNode(BaseNode): model = None prompt_messages: list[PromptMessage] = [] - full_text = '' + full_text = "" usage = None finish_reason = None for result in invoke_result: text = result.delta.message.content full_text += text - yield RunStreamChunkEvent( - chunk_content=text, - from_variable_selector=[self.node_id, 'text'] - ) + yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) if not model: model = result.model @@ -258,15 +250,11 @@ class LLMNode(BaseNode): if not usage: usage = LLMUsage.empty_usage() - yield ModelInvokeCompleted( - text=full_text, - usage=usage, - finish_reason=finish_reason - ) + yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) - def _transform_chat_messages(self, - messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + def _transform_chat_messages( + self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ Transform chat messages @@ -275,13 +263,13 @@ class LLMNode(BaseNode): """ if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == 'jinja2' and messages.jinja2_text: + if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = messages.jinja2_text return messages for message in messages: - if message.edition_type == 'jinja2' and message.jinja2_text: + if message.edition_type == "jinja2" and message.jinja2_text: message.text = message.jinja2_text return messages @@ -300,17 +288,15 @@ class LLMNode(BaseNode): for variable_selector in node_data.prompt_config.jinja2_variables or []: variable = variable_selector.variable - value = variable_pool.get_any( - variable_selector.value_selector - ) + value = variable_pool.get_any(variable_selector.value_selector) def parse_dict(d: dict) -> str: """ Parse dict into string """ # check if it's a context structure - if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: - return d['content'] + if "metadata" in d and "_source" in d["metadata"] and "content" in d: + return d["content"] # else, parse the dict try: @@ -321,7 +307,7 @@ class LLMNode(BaseNode): if isinstance(value, str): value = value elif isinstance(value, list): - result = '' + result = "" for item in value: if isinstance(item, dict): result += parse_dict(item) @@ -331,7 +317,7 @@ class LLMNode(BaseNode): result += str(item) else: result += str(item) - result += '\n' + result += "\n" value = result.strip() elif isinstance(value, dict): value = parse_dict(value) @@ -366,18 +352,19 @@ class LLMNode(BaseNode): for variable_selector in variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value @@ -393,7 +380,7 @@ class LLMNode(BaseNode): if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) + files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value]) if not files: return [] @@ -415,29 +402,25 @@ class LLMNode(BaseNode): context_value = variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): - yield RunRetrieverResourceEvent( - retriever_resources=[], - context=context_value - ) + yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) elif isinstance(context_value, list): - context_str = '' + context_str = "" original_retriever_resource = [] for item in context_value: if isinstance(item, str): - context_str += item + '\n' + context_str += item + "\n" else: - if 'content' not in item: - raise ValueError(f'Invalid context structure: {item}') + if "content" not in item: + raise ValueError(f"Invalid context structure: {item}") - context_str += item['content'] + '\n' + context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip() + retriever_resources=original_retriever_resource, context=context_str.strip() ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: @@ -446,34 +429,37 @@ class LLMNode(BaseNode): :param context_dict: context dict :return: """ - if ('metadata' in context_dict and '_source' in context_dict['metadata'] - and context_dict['metadata']['_source'] == 'knowledge'): - metadata = context_dict.get('metadata', {}) + if ( + "metadata" in context_dict + and "_source" in context_dict["metadata"] + and context_dict["metadata"]["_source"] == "knowledge" + ): + metadata = context_dict.get("metadata", {}) source = { - 'position': metadata.get('position'), - 'dataset_id': metadata.get('dataset_id'), - 'dataset_name': metadata.get('dataset_name'), - 'document_id': metadata.get('document_id'), - 'document_name': metadata.get('document_name'), - 'data_source_type': metadata.get('document_data_source_type'), - 'segment_id': metadata.get('segment_id'), - 'retriever_from': metadata.get('retriever_from'), - 'score': metadata.get('score'), - 'hit_count': metadata.get('segment_hit_count'), - 'word_count': metadata.get('segment_word_count'), - 'segment_position': metadata.get('segment_position'), - 'index_node_hash': metadata.get('segment_index_node_hash'), - 'content': context_dict.get('content'), - 'page': metadata.get('page'), + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name": metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("document_data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), } return source return None - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data_model: node data model @@ -484,10 +470,7 @@ class LLMNode(BaseNode): model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -498,8 +481,7 @@ class LLMNode(BaseNode): # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -515,19 +497,16 @@ class LLMNode(BaseNode): # model config completion_params = node_data_model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data_model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") @@ -543,9 +522,9 @@ class LLMNode(BaseNode): stop=stop, ) - def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], - variable_pool: VariablePool, - model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + def _fetch_memory( + self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + ) -> Optional[TokenBufferMemory]: """ Fetch memory :param node_data_memory: node data memory @@ -556,35 +535,35 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None # get conversation - conversation = db.session.query(Conversation).filter( - Conversation.app_id == self.app_id, - Conversation.id == conversation_id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + .first() + ) if not conversation: return None - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) return memory - def _fetch_prompt_messages(self, node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt_messages( + self, + node_data: LLMNodeData, + query: Optional[str], + query_prompt_template: Optional[str], + inputs: dict[str, str], + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt messages :param node_data: node data @@ -601,7 +580,7 @@ class LLMNode(BaseNode): prompt_messages = prompt_transform.get_prompt( prompt_template=node_data.prompt_template, inputs=inputs, - query=query if query else '', + query=query or "", files=files, context=context, memory_config=node_data.memory, @@ -621,8 +600,11 @@ class LLMNode(BaseNode): if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: - if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance( - content_item, ImagePromptMessageContent): + if ( + vision_enabled + and content_item.type == PromptMessageContentType.IMAGE + and isinstance(content_item, ImagePromptMessageContent) + ): # Override vision config if LLM node has vision config if vision_detail: content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) @@ -632,15 +614,18 @@ class LLMNode(BaseNode): if len(prompt_message_content) > 1: prompt_message.content = prompt_message_content - elif (len(prompt_message_content) == 1 - and prompt_message_content[0].type == PromptMessageContentType.TEXT): + elif ( + len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT + ): prompt_message.content = prompt_message_content[0].data filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError("No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding.") + raise ValueError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." + ) return filtered_prompt_messages, stop @@ -678,7 +663,7 @@ class LLMNode(BaseNode): elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_instance.model: + if "gpt-4" in model_instance.model: used_quota = 20 else: used_quota = 1 @@ -689,16 +674,13 @@ class LLMNode(BaseNode): Provider.provider_name == model_instance.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -712,11 +694,11 @@ class LLMNode(BaseNode): variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: - if prompt.edition_type != 'jinja2': + if prompt.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) else: - if prompt_template.edition_type != 'jinja2': + if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() @@ -726,39 +708,38 @@ class LLMNode(BaseNode): memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector if node_data.context.enabled: - variable_mapping['#context#'] = node_data.context.variable_selector + variable_mapping["#context#"] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] + variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, list): for prompt in prompt_template: - if prompt.edition_type == 'jinja2': + if prompt.edition_type == "jinja2": enable_jinja = True break else: - if prompt_template.edition_type == 'jinja2': + if prompt_template.edition_type == "jinja2": enable_jinja = True if enable_jinja: for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} return variable_mapping @@ -775,26 +756,19 @@ class LLMNode(BaseNode): "prompt_templates": { "chat_model": { "prompts": [ - { - "role": "system", - "text": "You are a helpful AI assistant.", - "edition_type": "basic" - } + {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} ] }, "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "prompt": { "text": "Here is the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic" + " XML tags.\n\n\n{{" + "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic", }, - "stop": ["Human:"] - } + "stop": ["Human:"], + }, } - } + }, } diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 8a5684551e..a8a0debe64 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,4 +1,3 @@ - from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState @@ -7,7 +6,8 @@ class LoopNodeData(BaseIterationNodeData): Loop Node Data. """ + class LoopState(BaseIterationState): """ Loop State. - """ \ No newline at end of file + """ diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 526404e30d..fbc68b79cb 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -10,6 +10,7 @@ class LoopNode(BaseNode): """ Loop Node. """ + _node_data_cls = LoopNodeData _node_type = NodeType.LOOP @@ -21,14 +22,16 @@ class LoopNode(BaseNode): """ Get conditions. """ - node_id = node_config.get('id') + node_id = node_config.get("id") if not node_id: return [] # TODO waiting for implementation - return [Condition( - variable_selector=[node_id, 'index'], - comparison_operator="≤", - value_type="value_selector", - value_selector=[] - )] + return [ + Condition( + variable_selector=[node_id, "index"], + comparison_operator="≤", + value_type="value_selector", + value_selector=[], + ) + ] diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 7bb123b126..5697d7c049 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -8,47 +8,52 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str completion_params: dict[str, Any] = {} + class ParameterConfig(BaseModel): """ Parameter Config. """ + name: str - type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]'] + type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"] options: Optional[list[str]] = None description: str required: bool - @field_validator('name', mode='before') + @field_validator("name", mode="before") @classmethod def validate_name(cls, value) -> str: if not value: - raise ValueError('Parameter name is required') - if value in ['__reason', '__is_success']: - raise ValueError('Invalid parameter name, __reason and __is_success are reserved') + raise ValueError("Parameter name is required") + if value in {"__reason", "__is_success"}: + raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value + class ParameterExtractorNodeData(BaseNodeData): """ Parameter Extractor Node Data. """ + model: ModelConfig query: list[str] parameters: list[ParameterConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None - reasoning_mode: Literal['function_call', 'prompt'] + reasoning_mode: Literal["function_call", "prompt"] - @field_validator('reasoning_mode', mode='before') + @field_validator("reasoning_mode", mode="before") @classmethod def set_reasoning_mode(cls, v) -> str: - return v or 'function_call' + return v or "function_call" def get_parameter_json_schema(self) -> dict: """ @@ -56,32 +61,26 @@ class ParameterExtractorNodeData(BaseNodeData): :return: parameter json schema """ - parameters = { - 'type': 'object', - 'properties': {}, - 'required': [] - } + parameters = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema = { - 'description': parameter.description - } + parameter_schema = {"description": parameter.description} - if parameter.type in ['string', 'select']: - parameter_schema['type'] = 'string' - elif parameter.type.startswith('array'): - parameter_schema['type'] = 'array' + if parameter.type in {"string", "select"}: + parameter_schema["type"] = "string" + elif parameter.type.startswith("array"): + parameter_schema["type"] = "array" nested_type = parameter.type[6:-1] - parameter_schema['items'] = {'type': nested_type} + parameter_schema["items"] = {"type": nested_type} else: - parameter_schema['type'] = parameter.type + parameter_schema["type"] = parameter.type - if parameter.type == 'select': - parameter_schema['enum'] = parameter.options + if parameter.type == "select": + parameter_schema["enum"] = parameter.options + + parameters["properties"][parameter.name] = parameter_schema - parameters['properties'][parameter.name] = parameter_schema - if parameter.required: - parameters['required'].append(parameter.name) + parameters["required"].append(parameter.name) - return parameters \ No newline at end of file + return parameters diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2e65705f10..a6454bd1cd 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -45,6 +45,7 @@ class ParameterExtractorNode(LLMNode): """ Parameter Extractor Node. """ + _node_data_cls = ParameterExtractorNodeData _node_type = NodeType.PARAMETER_EXTRACTOR @@ -57,11 +58,8 @@ class ParameterExtractorNode(LLMNode): "model": { "prompt_templates": { "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "stop": ["Human:"] + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, + "stop": ["Human:"], } } } @@ -78,9 +76,9 @@ class ParameterExtractorNode(LLMNode): query = variable inputs = { - 'query': query, - 'parameters': jsonable_encoder(node_data.parameters), - 'instruction': jsonable_encoder(node_data.instruction), + "query": query, + "parameters": jsonable_encoder(node_data.parameters), + "instruction": jsonable_encoder(node_data.instruction), } model_instance, model_config = self._fetch_model_config(node_data.model) @@ -95,30 +93,29 @@ class ParameterExtractorNode(LLMNode): # fetch memory memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ - and node_data.reasoning_mode == 'function_call': - # use function call + if ( + set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} + and node_data.reasoning_mode == "function_call" + ): + # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( node_data, query, self.graph_runtime_state.variable_pool, model_config, memory ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, - query, - self.graph_runtime_state.variable_pool, - model_config, - memory) + prompt_messages = self._generate_prompt_engineering_prompt( + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory + ) prompt_message_tools = [] process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': None, - 'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - 'tool_call': None, + "usage": None, + "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), + "tool_call": None, } try: @@ -129,20 +126,17 @@ class ParameterExtractorNode(LLMNode): tools=prompt_message_tools, stop=model_config.stop, ) - process_data['usage'] = jsonable_encoder(usage) - process_data['tool_call'] = jsonable_encoder(tool_call) - process_data['llm_text'] = text + process_data["usage"] = jsonable_encoder(usage) + process_data["tool_call"] = jsonable_encoder(tool_call) + process_data["llm_text"] = text except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 0, - '__reason': str(e) - }, + outputs={"__is_success": 0, "__reason": str(e)}, error=str(e), - metadata={} + metadata={}, ) error = None @@ -167,24 +161,23 @@ class ParameterExtractorNode(LLMNode): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 1 if not error else 0, - '__reason': error, - **result - }, + outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + stop: list[str], + ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: """ Invoke large language model :param node_data_model: node data model @@ -217,32 +210,35 @@ class ParameterExtractorNode(LLMNode): return text, usage, tool_call - def _generate_function_call_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: + def _generate_function_call_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps( - node_data.get_parameter_json_schema())) + query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( + content=query, structure=json.dumps(node_data.get_parameter_json_schema()) + ) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_function_calling_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -255,124 +251,125 @@ class ParameterExtractorNode(LLMNode): example_messages = [] for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: id = uuid.uuid4().hex - example_messages.extend([ - UserPromptMessage(content=example['user']['query']), - AssistantPromptMessage( - content=example['assistant']['text'], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example['assistant']['function_call']['name'], - arguments=json.dumps(example['assistant']['function_call']['parameters'] - ) - )) - ] - ), - ToolPromptMessage( - content='Great! You have called the function with the correct parameters.', - tool_call_id=id - ), - AssistantPromptMessage( - content='I have extracted the parameters, let\'s move on.', - ) - ]) + example_messages.extend( + [ + UserPromptMessage(content=example["user"]["query"]), + AssistantPromptMessage( + content=example["assistant"]["text"], + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=example["assistant"]["function_call"]["name"], + arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), + ), + ) + ], + ), + ToolPromptMessage( + content="Great! You have called the function with the correct parameters.", tool_call_id=id + ), + AssistantPromptMessage( + content="I have extracted the parameters, let's move on.", + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) # generate tool tool = PromptMessageTool( name=FUNCTION_CALLING_EXTRACTOR_NAME, - description='Extract parameters from the natural language text', + description="Extract parameters from the natural language text", parameters=node_data.get_parameter_json_schema(), ) return prompt_messages, [tool] - def _generate_prompt_engineering_prompt(self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_prompt( + self, + data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate prompt engineering prompt. """ model_mode = ModelMode.value_of(data.model.mode) if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory) elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory) else: raise ValueError(f"Invalid model mode: {model_mode}") - def _generate_prompt_engineering_completion_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_completion_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate completion prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_prompt_engineering_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, - inputs={ - 'structure': json.dumps(node_data.get_parameter_json_schema()) - }, - query='', + inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _generate_prompt_engineering_chat_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_chat_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate chat prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") prompt_template = self._get_prompt_engineering_prompt_template( node_data, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), - text=query + structure=json.dumps(node_data.get_parameter_json_schema()), text=query ), - variable_pool, memory, rest_token + variable_pool, + memory, + rest_token, ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -384,18 +381,23 @@ class ParameterExtractorNode(LLMNode): # add example messages before last user message example_messages = [] for example in CHAT_EXAMPLE: - example_messages.extend([ - UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example['user']['json']), - text=example['user']['query'], - )), - AssistantPromptMessage( - content=json.dumps(example['assistant']['json']), - ) - ]) + example_messages.extend( + [ + UserPromptMessage( + content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(example["user"]["json"]), + text=example["user"]["query"], + ) + ), + AssistantPromptMessage( + content=json.dumps(example["assistant"]["json"]), + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) return prompt_messages @@ -410,28 +412,28 @@ class ParameterExtractorNode(LLMNode): if parameter.required and parameter.name not in result: raise ValueError(f"Parameter {parameter.name} is required") - if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options: + if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: raise ValueError(f"Invalid `select` value for parameter {parameter.name}") - if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float): + if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): raise ValueError(f"Invalid `number` value for parameter {parameter.name}") - if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool): + if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") - if parameter.type == 'string' and not isinstance(result.get(parameter.name), str): + if parameter.type == "string" and not isinstance(result.get(parameter.name), str): raise ValueError(f"Invalid `string` value for parameter {parameter.name}") - if parameter.type.startswith('array'): + if parameter.type.startswith("array"): if not isinstance(result.get(parameter.name), list): raise ValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in result.get(parameter.name): - if nested_type == 'number' and not isinstance(item, int | float): + if nested_type == "number" and not isinstance(item, int | float): raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") - if nested_type == 'string' and not isinstance(item, str): + if nested_type == "string" and not isinstance(item, str): raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") - if nested_type == 'object' and not isinstance(item, dict): + if nested_type == "object" and not isinstance(item, dict): raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result @@ -443,12 +445,12 @@ class ParameterExtractorNode(LLMNode): for parameter in data.parameters: if parameter.name in result: # transform value - if parameter.type == 'number': + if parameter.type == "number": if isinstance(result[parameter.name], int | float): transformed_result[parameter.name] = result[parameter.name] elif isinstance(result[parameter.name], str): try: - if '.' in result[parameter.name]: + if "." in result[parameter.name]: result[parameter.name] = float(result[parameter.name]) else: result[parameter.name] = int(result[parameter.name]) @@ -465,40 +467,40 @@ class ParameterExtractorNode(LLMNode): # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ['string', 'select']: + elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith('array'): + elif parameter.type.startswith("array"): if isinstance(result[parameter.name], list): nested_type = parameter.type[6:-1] transformed_result[parameter.name] = [] for item in result[parameter.name]: - if nested_type == 'number': + if nested_type == "number": if isinstance(item, int | float): transformed_result[parameter.name].append(item) elif isinstance(item, str): try: - if '.' in item: + if "." in item: transformed_result[parameter.name].append(float(item)) else: transformed_result[parameter.name].append(int(item)) except ValueError: pass - elif nested_type == 'string': + elif nested_type == "string": if isinstance(item, str): transformed_result[parameter.name].append(item) - elif nested_type == 'object': + elif nested_type == "object": if isinstance(item, dict): transformed_result[parameter.name].append(item) if parameter.name not in transformed_result: - if parameter.type == 'number': + if parameter.type == "number": transformed_result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ['string', 'select']: - transformed_result[parameter.name] = '' - elif parameter.type.startswith('array'): + elif parameter.type in {"string", "select"}: + transformed_result[parameter.name] = "" + elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] return transformed_result @@ -514,24 +516,24 @@ class ParameterExtractorNode(LLMNode): """ stack = [] for i, c in enumerate(text): - if c == '{' or c == '[': + if c in {"{", "["}: stack.append(c) - elif c == '}' or c == ']': + elif c in {"}", "]"}: # check if stack is empty if not stack: return text[:i] # check if the last element in stack is matching - if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['): + if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): stack.pop() if not stack: - return text[:i + 1] + return text[: i + 1] else: return text[:i] return None # extract json from the text for idx in range(len(result)): - if result[idx] == '{' or result[idx] == '[': + if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: try: @@ -554,12 +556,12 @@ class ParameterExtractorNode(LLMNode): """ result = {} for parameter in data.parameters: - if parameter.type == 'number': + if parameter.type == "number": result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ['string', 'select']: - result[parameter.name] = '' + elif parameter.type in {"string", "select"}: + result[parameter.name] = "" return result @@ -575,71 +577,76 @@ class ParameterExtractorNode(LLMNode): return variable_template_parser.format(inputs) - def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: + def _get_function_calling_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: raise ValueError(f"Model mode {model_mode} not support.") - def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: - + def _get_prompt_engineering_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str, - text=input_text, - instruction=instruction) - .replace('{γγγ', '') - .replace('}γγγ', '') + text=COMPLETION_GENERATE_JSON_PROMPT.format( + histories=memory_str, text=input_text, instruction=instruction + ) + .replace("{γγγ", "") + .replace("}γγγ", "") ) else: raise ValueError(f"Model mode {model_mode} not support.") - def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) model_instance, model_config = self._fetch_model_config(node_data.model) @@ -659,12 +666,12 @@ class ParameterExtractorNode(LLMNode): prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 @@ -673,26 +680,28 @@ class ParameterExtractorNode(LLMNode): model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, - prompt_messages - ) + 1000 # add 1000 to ensure tool call messages + curr_message_tokens = ( + model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000 + ) # add 1000 to ensure tool call messages max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config. """ @@ -703,10 +712,7 @@ class ParameterExtractorNode(LLMNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -715,17 +721,13 @@ class ParameterExtractorNode(LLMNode): :param node_data: node data :return: """ - variable_mapping = { - 'query': node_data.query - } + variable_mapping = {"query": node_data.query} if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) for selector in variable_template_parser.extract_variable_selectors(): variable_mapping[selector.variable] = selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} return variable_mapping diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index 499c58d505..58fcecc53b 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,4 +1,4 @@ -FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters' +FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. ### Task @@ -23,7 +23,7 @@ Steps: To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. ### Final Output Produce well-formatted function calls in json without XML tags, as shown in the example. -""" +""" # noqa: E501 FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. @@ -33,63 +33,52 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr \x7bstructure\x7d -""" +""" # noqa: E501 -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True +FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + }, }, + "required": ["location"], }, - 'required': ['location'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters." + " in this case, I need to call the function with the location parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'location': 'San Francisco' - } - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters." + " in this case, I need to call the function with the food parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'food': 'apple pie' - } - } - } -}] +] COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: Some extra information are provided below, I should always follow the instructions as possible as I can. @@ -130,7 +119,7 @@ Inside XML tags, there is a text that I should extract parameters ### Answer I should always output a valid JSON object. Output nothing other than the JSON object. ```JSON -""" +""" # noqa: E501 CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. The structure of the JSON object you can found in the instructions. @@ -161,46 +150,33 @@ Inside XML tags, there is a text that you should convert to a JSON """ -CHAT_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'json': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True - } +CHAT_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "json": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + } + }, + "required": ["location"], }, - 'required': ['location'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'location': 'San Francisco' - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'json': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "json": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'result': 'apple pie' - } - } -}] \ No newline at end of file +] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index c0b0a8b696..40f7ce7582 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -8,8 +8,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -20,6 +21,7 @@ class ClassConfig(BaseModel): """ Class Config. """ + id: str name: str @@ -28,8 +30,9 @@ class QuestionClassifierNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ + query_variable_selector: list[str] - type: str = 'question-classifier' + type: str = "question-classifier" model: ModelConfig classes: list[ClassConfig] instruction: Optional[str] = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index ecab8db9b6..2ae58bc5f7 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -45,34 +45,25 @@ class QuestionClassifierNode(LLMNode): # extract variables variable = variable_pool.get(node_data.query_variable_selector) query = variable.value if variable else None - variables = { - 'query': query - } + variables = {"query": query} # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) # fetch instruction - instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else '' + instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else "" node_data.instruction = instruction # fetch prompt messages prompt_messages, stop = self._fetch_prompt( - node_data=node_data, - context='', - query=query, - memory=memory, - model_config=model_config + node_data=node_data, context="", query=query, memory=memory, model_config=model_config ) # handle invoke result generator = self._invoke_llm( - node_data_model=node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop + node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) - result_text = '' + result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: @@ -87,8 +78,8 @@ class QuestionClassifierNode(LLMNode): try: result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) - if 'category_name' in result_text_json and 'category_id' in result_text_json: - category_id_result = result_text_json['category_id'] + if "category_name" in result_text_json and "category_id" in result_text_json: + category_id_result = result_text_json["category_id"] classes = node_data.classes classes_map = {class_.id: class_.name for class_ in classes} category_ids = [_class.id for _class in classes] @@ -100,17 +91,14 @@ class QuestionClassifierNode(LLMNode): logging.error(f"Failed to parse result text: {result_text}") try: process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } - outputs = { - 'class_name': category_name + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, } + outputs = {"class_name": category_name} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -121,9 +109,9 @@ class QuestionClassifierNode(LLMNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) except ValueError as e: @@ -134,17 +122,14 @@ class QuestionClassifierNode(LLMNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -153,7 +138,7 @@ class QuestionClassifierNode(LLMNode): :param node_data: node data :return: """ - variable_mapping = {'query': node_data.query_variable_selector} + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) @@ -161,10 +146,8 @@ class QuestionClassifierNode(LLMNode): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } - + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + return variable_mapping @classmethod @@ -174,19 +157,16 @@ class QuestionClassifierNode(LLMNode): :param filters: filter by node config parameters. :return: """ - return { - "type": "question-classifier", - "config": { - "instructions": "" - } - } + return {"type": "question-classifier", "config": {"instructions": ""}} - def _fetch_prompt(self, node_data: QuestionClassifierNodeData, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt( + self, + node_data: QuestionClassifierNodeData, + query: str, + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt :param node_data: node data @@ -202,118 +182,122 @@ class QuestionClassifierNode(LLMNode): prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: QuestionClassifierNodeData, + query: str, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + def _get_prompt_template( + self, + node_data: QuestionClassifierNodeData, + query: str, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: - category = { - 'category_id': class_.id, - 'category_name': class_.name - } + category = {"category_id": class_.id, "category_name": class_.name} categories.append(category) - instruction = node_data.instruction if node_data.instruction else '' + instruction = node_data.instruction or "" input_text = query - memory_str = '' + memory_str = "" if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) prompt_messages = [] if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) + role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) user_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_1 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 ) prompt_messages.append(user_prompt_message_1) assistant_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) user_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_2 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 ) prompt_messages.append(user_prompt_message_2) assistant_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) user_prompt_message_3 = ChatModelMessage( role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction) + text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( + input_text=input_text, + categories=json.dumps(categories, ensure_ascii=False), + classification_instructions=instruction, + ), ) prompt_messages.append(user_prompt_message_3) return prompt_messages elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, - input_text=input_text, - categories=json.dumps(categories), - classification_instructions=instruction, - ensure_ascii=False) + text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( + histories=memory_str, + input_text=input_text, + categories=json.dumps(categories), + classification_instructions=instruction, + ensure_ascii=False, + ) ) else: @@ -329,14 +313,12 @@ class QuestionClassifierNode(LLMNode): variable = variable_pool.get(variable_selector.value_selector) variable_value = variable.value if variable else None if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - instruction = prompt_template.format( - prompt_inputs - ) + instruction = prompt_template.format(prompt_inputs) return instruction diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index e0de148cc2..ce32b01aa4 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -1,5 +1,3 @@ - - QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ ### Job Description', You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. @@ -14,13 +12,13 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ {histories} -""" +""" # noqa: E501 QUESTION_CLASSIFIER_USER_PROMPT_1 = """ { "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}], "classification_instructions": ["classify the text based on the feedback provided by customer"]} -""" +""" # noqa: E501 QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ ```json @@ -34,7 +32,7 @@ QUESTION_CLASSIFIER_USER_PROMPT_2 = """ {"input_text": ["bad service, slow to bring the food"], "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}], "classification_instructions": []} -""" +""" # noqa: E501 QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ ```json @@ -75,4 +73,4 @@ Here is the chat histories between human and assistant, inside Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index d9099a8118..e934d69fa3 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,5 +1,3 @@ - - from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -8,5 +6,6 @@ class TemplateTransformNodeData(BaseNodeData): """ Code Node Data. """ + variables: list[VariableSelector] - template: str \ No newline at end of file + template: str diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index b14a394a0a..32c99e0d1c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -2,13 +2,13 @@ import os from collections.abc import Mapping, Sequence from typing import Any, Optional, cast -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) class TemplateTransformNode(BaseNode): @@ -24,15 +24,7 @@ class TemplateTransformNode(BaseNode): """ return { "type": "template-transform", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - } - ], - "template": "{{ arg1 }}" - } + "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } def _run(self) -> NodeRunResult: @@ -51,38 +43,25 @@ class TemplateTransformNode(BaseNode): # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=node_data.template, - inputs=variables - ) - except CodeExecutionException as e: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) + language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables ) + except CodeExecutionError as e: + return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters" + error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters", ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs={ - 'output': result['result'] - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: TemplateTransformNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -92,5 +71,6 @@ class TemplateTransformNode(BaseNode): :return: """ return { - node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 335a41061c..1a408d96cb 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -10,45 +10,46 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ToolEntity(BaseModel): provider_id: str provider_type: ToolProviderType - provider_name: str # redundancy + provider_name: str # redundancy tool_name: str - tool_label: str # redundancy + tool_label: str # redundancy tool_configurations: dict[str, Any] - @field_validator('tool_configurations', mode='before') + @field_validator("tool_configurations", mode="before") @classmethod def validate_tool_configurations(cls, value, values: ValidationInfo): if not isinstance(value, dict): - raise ValueError('tool_configurations must be a dictionary') - - for key in values.data.get('tool_configurations', {}).keys(): - value = values.data.get('tool_configurations', {}).get(key) + raise ValueError("tool_configurations must be a dictionary") + + for key in values.data.get("tool_configurations", {}): + value = values.data.get("tool_configurations", {}).get(key) if not isinstance(value, str | int | float | bool): - raise ValueError(f'{key} must be a string') - + raise ValueError(f"{key} must be a string") + return value + class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] - type: Literal['mixed', 'variable', 'constant'] + type: Literal["mixed", "variable", "constant"] - @field_validator('type', mode='before') + @field_validator("type", mode="before") @classmethod def check_type(cls, value, validation_info: ValidationInfo): typ = value - value = validation_info.data.get('value') - if typ == 'mixed' and not isinstance(value, str): - raise ValueError('value must be a string') - elif typ == 'variable': + value = validation_info.data.get("value") + if typ == "mixed" and not isinstance(value, str): + raise ValueError("value must be a string") + elif typ == "variable": if not isinstance(value, list): - raise ValueError('value must be a list') + raise ValueError("value must be a list") for val in value: if not isinstance(val, str): - raise ValueError('value must be a list of strings') - elif typ == 'constant' and not isinstance(value, str | int | float | bool): - raise ValueError('value must be a string, int, float, or bool') + raise ValueError("value must be a list of strings") + elif typ == "constant" and not isinstance(value, str | int | float | bool): + raise ValueError("value must be a string, int, float, or bool") return typ """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 1f32c7b8bd..73c22bc700 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -35,10 +35,7 @@ class ToolNode(BaseNode): node_data = cast(ToolNodeData, self.node_data) # fetch tool icon - tool_info = { - 'provider_type': node_data.provider_type.value, - 'provider_id': node_data.provider_id - } + tool_info = {"provider_type": node_data.provider_type.value, "provider_id": node_data.provider_id} # get tool runtime try: @@ -50,10 +47,8 @@ class ToolNode(BaseNode): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to get tool runtime: {str(e)}' + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to get tool runtime: {str(e)}", ) ) return @@ -61,15 +56,13 @@ class ToolNode(BaseNode): # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] parameters = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data + tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data ) parameters_for_log = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, - for_log=True + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=node_data, + for_log=True, ) try: @@ -86,10 +79,8 @@ class ToolNode(BaseNode): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to invoke tool: {str(e)}', + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool: {str(e)}", ) ) return @@ -126,12 +117,10 @@ class ToolNode(BaseNode): result[parameter_name] = None continue if parameter.type == ToolParameter.ToolParameterType.FILE: - result[parameter_name] = [ - v.to_dict() for v in self._fetch_files(variable_pool) - ] + result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)] else: tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == 'variable': + if tool_input.type == "variable": parameter_value_segment = variable_pool.get(tool_input.value) if not parameter_value_segment: raise Exception("input variable dose not exists") @@ -147,14 +136,16 @@ class ToolNode(BaseNode): return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _transform_message(self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], - parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]: + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + ) -> Generator[RunEvent, None, None]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -169,66 +160,65 @@ class ToolNode(BaseNode): files: list[FileVar] = [] text = "" json: list[dict] = [] - + variables: dict[str, Any] = {} for message in message_stream: - if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - message.type == ToolInvokeMessage.MessageType.IMAGE: + if message.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert message.meta url = message.message.text ext = path.splitext(url)[1] - mimetype = message.meta.get('mime_type', 'image/jpeg') - filename = message.save_as or url.split('/')[-1] - transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) + mimetype = message.meta.get("mime_type", "image/jpeg") + filename = message.save_as or url.split("/")[-1] + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) # get tool file id - tool_file_id = url.split('/')[-1].split('.')[0] - files.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=url, - related_id=tool_file_id, - filename=filename, - extension=ext, - mime_type=mimetype, - )) + tool_file_id = url.split("/")[-1].split(".")[0] + files.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=url, + related_id=tool_file_id, + filename=filename, + extension=ext, + mime_type=mimetype, + ) + ) elif message.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split('/')[-1].split('.')[0] - files.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - filename=message.save_as, - extension=path.splitext(message.save_as)[1], - mime_type=message.meta.get('mime_type', 'application/octet-stream'), - )) + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + files.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file_id, + filename=message.save_as, + extension=path.splitext(message.save_as)[1], + mime_type=message.meta.get("mime_type", "application/octet-stream"), + ) + ) elif message.type == ToolInvokeMessage.MessageType.TEXT: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - text += message.message.text + '\n' + text += message.message.text + "\n" yield RunStreamChunkEvent( - chunk_content=message.message.text, - from_variable_selector=[self.node_id, 'text'] + chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] ) elif message.type == ToolInvokeMessage.MessageType.JSON: assert isinstance(message, ToolInvokeMessage.JsonMessage) json.append(message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - stream_text = f'Link: {message.message.text}\n' + stream_text = f"Link: {message.message.text}\n" text += stream_text - yield RunStreamChunkEvent( - chunk_content=stream_text, - from_variable_selector=[self.node_id, 'text'] - ) + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) elif message.type == ToolInvokeMessage.MessageType.VARIABLE: assert isinstance(message.message, ToolInvokeMessage.VariableMessage) variable_name = message.message.variable_name @@ -241,8 +231,7 @@ class ToolNode(BaseNode): variables[variable_name] += variable_value yield RunStreamChunkEvent( - chunk_content=variable_value, - from_variable_selector=[self.node_id, variable_name] + chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] ) else: variables[variable_name] = variable_value @@ -250,25 +239,15 @@ class ToolNode(BaseNode): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'text': text, - 'files': files, - 'json': json, - **variables - }, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - inputs=parameters_for_log + outputs={"text": text, "files": files, "json": json, **variables}, + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + inputs=parameters_for_log, ) ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -280,18 +259,16 @@ class ToolNode(BaseNode): result = {} for parameter_name in node_data.tool_parameters: input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': + if input.type == "mixed": assert isinstance(input.value, str) selectors = VariableTemplateParser(input.value).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector - elif input.type == 'variable': + elif input.type == "variable": result[parameter_name] = input.value - elif input.type == 'constant': + elif input.type == "constant": pass - result = { - node_id + '.' + key: value for key, value in result.items() - } + result = {node_id + "." + key: value for key, value in result.items()} return result diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index e5de38dc0f..eb893a04e3 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,5 +1,3 @@ - - from typing import Literal, Optional from pydantic import BaseModel @@ -11,23 +9,27 @@ class AdvancedSettings(BaseModel): """ Advanced setting. """ + group_enabled: bool class Group(BaseModel): """ Group. """ - output_type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] + + output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] variables: list[list[str]] group_name: str groups: list[Group] + class VariableAssignerNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - type: str = 'variable-assigner' + + type: str = "variable-assigner" output_type: str variables: list[list[str]] advanced_settings: Optional[AdvancedSettings] = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 6944d9e82d..f03eae257a 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -21,13 +21,9 @@ class VariableAggregatorNode(BaseNode): for selector in node_data.variables: variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs = { - "output": variable - } + outputs = {"output": variable} - inputs = { - '.'.join(selector[1:]): variable - } + inputs = {".".join(selector[1:]): variable} break else: for group in node_data.advanced_settings.groups: @@ -35,24 +31,15 @@ class VariableAggregatorNode(BaseNode): variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs[group.group_name] = { - 'output': variable - } - inputs['.'.join(selector[1:])] = variable + outputs[group.group_name] = {"output": variable} + inputs[".".join(selector[1:])] = variable break - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - inputs=inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index d791d51523..83da4bdc79 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -2,7 +2,7 @@ from .node import VariableAssignerNode from .node_data import VariableAssignerData, WriteMode __all__ = [ - 'VariableAssignerNode', - 'VariableAssignerData', - 'WriteMode', + "VariableAssignerNode", + "VariableAssignerData", + "WriteMode", ] diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index b2f32c6aaa..3969299795 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -24,43 +24,43 @@ class VariableAssignerNode(BaseNode): # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') + raise VariableAssignerNodeError("assigned variable not found") match data.write_mode: case WriteMode.OVER_WRITE: income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) + raise VariableAssignerNodeError("input value not found") + updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') + raise VariableAssignerNodeError("input value not found") updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) + updated_variable = original_variable.model_copy(update={"value": updated_value}) case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}") # Over write the variable. self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. - conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id']) + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') + raise VariableAssignerNodeError("conversation_id not found") update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ - 'value': income_value.to_object(), + "value": income_value.to_object(), }, ) @@ -72,7 +72,7 @@ def update_conversation_variable(conversation_id: str, variable: Variable): with Session(db.engine) as session: row = session.scalar(stmt) if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') + raise VariableAssignerNodeError("conversation variable not found in the database") row.data = variable.model_dump_json() session.commit() @@ -84,8 +84,8 @@ def get_zero_value(t: SegmentType): case SegmentType.OBJECT: return factory.build_segment({}) case SegmentType.STRING: - return factory.build_segment('') + return factory.build_segment("") case SegmentType.NUMBER: return factory.build_segment(0) case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') + raise VariableAssignerNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index b3652b6802..8ac8eadf7c 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -6,14 +6,14 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' + OVER_WRITE = "over-write" + APPEND = "append" + CLEAR = "clear" class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' + title: str = "Variable Assigner" + desc: Optional[str] = "Assign a value to a variable" assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index e195730a31..b8e8b881a5 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -7,11 +7,26 @@ class Condition(BaseModel): """ Condition entity """ + variable_selector: list[str] comparison_operator: Literal[ # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", ] value: Optional[str] = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 5ff61aab3d..395ee82478 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -15,9 +15,7 @@ class ConditionProcessor: index = 0 for condition in conditions: index += 1 - actual_value = variable_pool.get_any( - condition.variable_selector - ) + actual_value = variable_pool.get_any(condition.variable_selector) expected_value = None if condition.value is not None: @@ -25,9 +23,7 @@ class ConditionProcessor: variable_selectors = variable_template_parser.extract_variable_selectors() if variable_selectors: for variable_selector in variable_selectors: - value = variable_pool.get_any( - variable_selector.value_selector - ) + value = variable_pool.get_any(variable_selector.value_selector) expected_value = variable_template_parser.format({variable_selector.variable: value}) if expected_value is None: @@ -40,7 +36,7 @@ class ConditionProcessor: { "actual_value": actual_value, "expected_value": expected_value, - "comparison_operator": comparison_operator + "comparison_operator": comparison_operator, } ) @@ -50,10 +46,10 @@ class ConditionProcessor: return input_conditions, group_result def evaluate_condition( - self, - actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], - comparison_operator: str, - expected_value: Optional[str] = None + self, + actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], + comparison_operator: str, + expected_value: Optional[str] = None, ) -> bool: """ Evaluate condition @@ -109,7 +105,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') + raise ValueError("Invalid actual value type: string or array") if expected_value not in actual_value: return False @@ -126,7 +122,7 @@ class ConditionProcessor: return True if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') + raise ValueError("Invalid actual value type: string or array") if expected_value in actual_value: return False @@ -143,7 +139,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if not actual_value.startswith(expected_value): return False @@ -160,7 +156,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if not actual_value.endswith(expected_value): return False @@ -177,7 +173,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if actual_value != expected_value: return False @@ -194,7 +190,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if actual_value == expected_value: return False @@ -231,7 +227,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -253,7 +249,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -275,7 +271,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -297,7 +293,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -308,8 +304,9 @@ class ConditionProcessor: return False return True - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], - expected_value: str | int | float) -> bool: + def _assert_greater_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: """ Assert greater than or equal :param actual_value: actual value @@ -320,7 +317,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -331,8 +328,9 @@ class ConditionProcessor: return False return True - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], - expected_value: str | int | float) -> bool: + def _assert_less_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: """ Assert less than or equal :param actual_value: actual value @@ -343,7 +341,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index c43fde172c..fd0e48b862 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -5,7 +5,7 @@ from typing import Any from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool -REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}') +REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: @@ -20,7 +20,7 @@ def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name']) key_selectors = filter( lambda t: len(t[1]) >= 2, - ((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)), + ((key, selector.replace("#", "").split(".")) for key, selector in zip(variable_keys, variable_keys)), ) inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} @@ -29,13 +29,13 @@ def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: # return original matched string if key not found value = inputs.get(key, match.group(0)) if value is None: - value = '' + value = "" value = str(value) # remove template variables if required - return re.sub(REGEX, r'{\1}', value) + return re.sub(REGEX, r"{\1}", value) result = re.sub(REGEX, replacer, template) - result = re.sub(r'<\|.*?\|>', '', result) + result = re.sub(r"<\|.*?\|>", "", result) return result @@ -101,8 +101,8 @@ class VariableTemplateParser: """ variable_selectors = [] for variable_key in self.variable_keys: - remove_hash = variable_key.replace('#', '') - split_result = remove_hash.split('.') + remove_hash = variable_key.replace("#", "") + split_result = remove_hash.split(".") if len(split_result) < 2: continue @@ -127,7 +127,7 @@ class VariableTemplateParser: value = inputs.get(key, match.group(0)) # return original matched string if key not found if value is None: - value = '' + value = "" # convert the value to string if isinstance(value, list | dict | bool | int | float): value = str(value) @@ -136,7 +136,7 @@ class VariableTemplateParser: return VariableTemplateParser.remove_template_variables(value) prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str): @@ -149,4 +149,4 @@ class VariableTemplateParser: Returns: The text with template variables removed. """ - return re.sub(REGEX, r'{\1}', text) + return re.sub(REGEX, r"{\1}", text) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index a359bd606e..74a598ada5 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -6,7 +6,7 @@ from typing import Any, Optional, cast from configs import dify_config from core.app.app_config.entities import FileExtraConfig -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback @@ -33,19 +33,19 @@ logger = logging.getLogger(__name__) class WorkflowEntry: def __init__( - self, - tenant_id: str, - app_id: str, - workflow_id: str, - workflow_type: WorkflowType, - graph_config: Mapping[str, Any], - graph: Graph, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, - variable_pool: VariablePool, - thread_pool_id: Optional[str] = None + self, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None, ) -> None: """ Init workflow entry @@ -65,7 +65,7 @@ class WorkflowEntry: # check call depth workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) # init workflow run state self.graph_engine = GraphEngine( @@ -82,13 +82,13 @@ class WorkflowEntry: variable_pool=variable_pool, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, - thread_pool_id=thread_pool_id + thread_pool_id=thread_pool_id, ) def run( - self, - *, - callbacks: Sequence[WorkflowCallback], + self, + *, + callbacks: Sequence[WorkflowCallback], ) -> Generator[GraphEngineEvent, None, None]: """ :param callbacks: workflow callbacks @@ -101,30 +101,20 @@ class WorkflowEntry: for event in generator: if callbacks: for callback in callbacks: - callback.on_event( - event=event - ) + callback.on_event(event=event) yield event - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except Exception as e: logger.exception("Unknown Error when workflow entry running") if callbacks: for callback in callbacks: - callback.on_event( - event=GraphRunFailedEvent( - error=str(e) - ) - ) + callback.on_event(event=GraphRunFailedEvent(error=str(e))) return @classmethod def single_step_run( - cls, - workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict + cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -137,30 +127,30 @@ class WorkflowEntry: # fetch node info from workflow graph graph = workflow.graph_dict if not graph: - raise ValueError('workflow graph not found') + raise ValueError("workflow graph not found") - nodes = graph.get('nodes') + nodes = graph.get("nodes") if not nodes: - raise ValueError('nodes not found in workflow graph') + raise ValueError("nodes not found in workflow graph") # fetch node config from node id node_config = None for node in nodes: - if node.get('id') == node_id: + if node.get("id") == node_id: node_config = node break if not node_config: - raise ValueError('node id not found in workflow graph') + raise ValueError("node id not found in workflow graph") # Get node class - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) node_cls = cast(type[BaseNode], node_cls) if not node_cls: - raise ValueError(f'Node class not found for node type {node_type}') - + raise ValueError(f"Node class not found for node type {node_type}") + # init variable pool variable_pool = VariablePool( system_variables={}, @@ -169,9 +159,7 @@ class WorkflowEntry: ) # init graph - graph = Graph.init( - graph_config=workflow.graph_dict - ) + graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state node_instance: BaseNode = node_cls( @@ -186,21 +174,17 @@ class WorkflowEntry: user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, - call_depth=0 + call_depth=0, ), graph=graph, - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter() - ) + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), ) try: # variable selector to variable mapping try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=node_config + graph_config=workflow.graph_dict, config=node_config ) except NotImplementedError: variable_mapping = {} @@ -211,7 +195,7 @@ class WorkflowEntry: variable_pool=variable_pool, tenant_id=workflow.tenant_id, node_type=node_type, - node_data=node_instance.node_data + node_data=node_instance.node_data, ) # run node @@ -219,10 +203,7 @@ class WorkflowEntry: return node_instance, generator except Exception as e: - raise WorkflowNodeRunFailedError( - node_instance=node_instance, - error=str(e) - ) + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @classmethod def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: @@ -259,21 +240,20 @@ class WorkflowEntry: variable_pool: VariablePool, tenant_id: str, node_type: NodeType, - node_data: BaseNodeData + node_data: BaseNodeData, ) -> None: for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable - node_variable_list = node_variable.split('.') + node_variable_list = node_variable.split(".") if len(node_variable_list) < 1: - raise ValueError(f'Invalid node variable {node_variable}') - - node_variable_key = '.'.join(node_variable_list[1:]) + raise ValueError(f"Invalid node variable {node_variable}") - if ( - node_variable_key not in user_inputs - and node_variable not in user_inputs - ) and not variable_pool.get(variable_selector): - raise ValueError(f'Variable key {node_variable} not found in user inputs.') + node_variable_key = ".".join(node_variable_list[1:]) + + if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get( + variable_selector + ): + raise ValueError(f"Variable key {node_variable} not found in user inputs.") # fetch variable node id from variable selector variable_node_id = variable_selector[0] @@ -294,16 +274,17 @@ class WorkflowEntry: detail = node_data.vision.configs.detail if node_data.vision.configs else None for item in input_value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + if isinstance(item, dict) and "type" in item and item["type"] == "image": + transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) file = FileVar( tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get("upload_file_id") + if transfer_method == FileTransferMethod.LOCAL_FILE + else None, + extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None), ) new_value.append(file) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 72a135e73d..54f6a76e16 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -5,7 +5,7 @@ import time import click from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db from models.dataset import Document @@ -43,7 +43,7 @@ def handle(sender, **kwargs): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 59375b1a0b..de7c0f4dfe 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -18,8 +18,7 @@ def handle(sender, **kwargs): added_dataset_ids = dataset_ids else: old_dataset_ids = set() - for app_dataset_join in app_dataset_joins: - old_dataset_ids.add(app_dataset_join.dataset_id) + old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids removed_dataset_ids = old_dataset_ids - dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 333b85ecb2..c5e98e263f 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -22,8 +22,7 @@ def handle(sender, **kwargs): added_dataset_ids = dataset_ids else: old_dataset_ids = set() - for app_dataset_join in app_dataset_joins: - old_dataset_ids.add(app_dataset_join.dataset_id) + old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids removed_dataset_ids = old_dataset_ids - dataset_ids diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 3b7b0a37f4..c2dc736038 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -1,18 +1,29 @@ import openai import sentry_sdk +from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException +def before_send(event, hint): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if parse_error.defaultErrorResponse in str(exc_value): + return None + + return event + + def init_app(app): if app.config.get("SENTRY_DSN"): sentry_sdk.init( dsn=app.config.get("SENTRY_DSN"), integrations=[FlaskIntegration(), CeleryIntegration()], - ignore_errors=[HTTPException, ValueError, openai.APIStatusError], + ignore_errors=[HTTPException, ValueError, openai.APIStatusError, parse_error.defaultErrorResponse], traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), environment=app.config.get("DEPLOY_ENV"), release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", + before_send=before_send, ) diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index 9ed1fcf0b4..c42f946fa8 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -5,7 +5,7 @@ from collections.abc import Generator from contextlib import closing from flask import Flask -from google.cloud import storage as GoogleCloudStorage +from google.cloud import storage as google_cloud_storage from extensions.storage.base_storage import BaseStorage @@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage): service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) - self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) + self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) else: - self.client = GoogleCloudStorage.Client() + self.client = google_cloud_storage.Client() def save(self, filename, data): bucket = self.client.get_bucket(self.bucket_name) diff --git a/api/extensions/storage/local_storage.py b/api/extensions/storage/local_storage.py index 46ee4bf80f..f833ae85dc 100644 --- a/api/extensions/storage/local_storage.py +++ b/api/extensions/storage/local_storage.py @@ -1,6 +1,7 @@ import os import shutil from collections.abc import Generator +from pathlib import Path from flask import Flask @@ -26,8 +27,7 @@ class LocalStorage(BaseStorage): folder = os.path.dirname(filename) os.makedirs(folder, exist_ok=True) - with open(os.path.join(os.getcwd(), filename), "wb") as f: - f.write(data) + Path(os.path.join(os.getcwd(), filename)).write_bytes(data) def load_once(self, filename: str) -> bytes: if not self.folder or self.folder.endswith("/"): @@ -38,9 +38,7 @@ class LocalStorage(BaseStorage): if not os.path.exists(filename): raise FileNotFoundError("File not found") - with open(filename, "rb") as f: - data = f.read() - + data = Path(filename).read_bytes() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 2d306edb40..83f9c74e33 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord from Crypto.Util.strxor import strxor -class PKCS1OAEP_Cipher: +class PKCS1OAepCipher: """Cipher object for PKCS#1 v1.5 OAEP. Do not create directly: use :func:`new` instead.""" @@ -204,7 +204,8 @@ class PKCS1OAEP_Cipher: def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): - """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption. + """Return a cipher object :class:`PKCS1OAEP_Cipher` + that can be used to perform PKCS#1 OAEP encryption or decryption. :param key: The key object to use to encrypt or decrypt the message. @@ -237,4 +238,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): if randfunc is None: randfunc = Random.get_random_bytes - return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc) + return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc) diff --git a/api/libs/helper.py b/api/libs/helper.py index bb3ac302e3..9fad17f872 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -73,7 +73,6 @@ def alphanumeric(value: str): raise ValueError(f"{value} is not a valid alphanumeric value") - def timestamp_value(timestamp): try: int_timestamp = int(timestamp) @@ -85,7 +84,7 @@ def timestamp_value(timestamp): raise ValueError(error) -class str_len: +class StrLen: """Restrict input to an integer in a range (inclusive)""" def __init__(self, max_length, argument="argument"): @@ -103,7 +102,7 @@ class str_len: return value -class float_range: +class FloatRange: """Restrict input to an float in a range (inclusive)""" def __init__(self, low, high, argument="argument"): @@ -122,7 +121,7 @@ class float_range: return value -class datetime_string: +class DatetimeString: def __init__(self, format, argument="argument"): self.format = format self.argument = argument @@ -146,7 +145,6 @@ def _get_float(value): raise ValueError("{} is not a valid float".format(value)) - def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string @@ -157,7 +155,7 @@ def timezone(timezone_string): def generate_string(n): letters_digits = string.ascii_letters + string.digits - result = '' + result = "" for i in range(n): result += random.choice(letters_digits) @@ -224,7 +222,7 @@ class TokenManager: key = cls._get_token_key(token, token_type) token_data_json = redis_client.get(key) if token_data_json is None: - logging.warning(f'{token_type} token {token} not found with key {key}') + logging.warning(f"{token_type} token {token} not found with key {key}") return None token_data = json.loads(token_data_json) return token_data @@ -252,7 +250,7 @@ class RateLimiter: self.time_window = time_window def _get_key(self, email: str) -> str: - return f'{self.prefix}:{email}' + return f"{self.prefix}:{email}" def is_rate_limited(self, email: str) -> bool: key = self._get_key(email) diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41d6905899..185ff3f95e 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -1,6 +1,6 @@ import json -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError def parse_json_markdown(json_string: str) -> dict: @@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"Got invalid JSON object. Error: {e}") for key in expected_keys: if key not in json_obj: - raise OutputParserException( - f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" + raise OutputParserError( + f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}" ) return json_obj diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 6da1a6d39b..05a73b09b7 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -158,7 +158,7 @@ class NotionOAuth(OAuthDataSource): page_icon = page_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: @@ -191,7 +191,7 @@ class NotionOAuth(OAuthDataSource): page_icon = database_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: diff --git a/api/libs/rsa.py b/api/libs/rsa.py index a578bf3e56..637bcc4a1d 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -4,9 +4,9 @@ from Crypto.Cipher import AES from Crypto.PublicKey import RSA from Crypto.Random import get_random_bytes -import libs.gmpy2_pkcs10aep_cipher as gmpy2_pkcs10aep_cipher from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from libs import gmpy2_pkcs10aep_cipher def generate_key_pair(tenant_id): diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py index 0fba6a87eb..8cd4ec552b 100644 --- a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py +++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py @@ -24,6 +24,7 @@ def upgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False)) + def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.drop_column('label') diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py index bfda7d619c..92f41f0abd 100644 --- a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -21,6 +21,7 @@ def upgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False)) + def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.drop_column('version') diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py index 2365766837..fcca705d21 100644 --- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -99,7 +99,7 @@ def upgrade(): id=id, tenant_id=tenant_id, user_id=user_id, - provider='google', + provider='google', encrypted_credentials=encrypted_credentials, created_at=created_at, updated_at=updated_at diff --git a/api/models/__init__.py b/api/models/__init__.py index 4012611471..30ceef057e 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -4,7 +4,7 @@ from .model import App, AppMode, Message from .types import StringUUID from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus -__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] +__all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"] class CreatedByRole(Enum): @@ -12,11 +12,11 @@ class CreatedByRole(Enum): Enum class for createdByRole """ - ACCOUNT = 'account' - END_USER = 'end_user' + ACCOUNT = "account" + END_USER = "end_user" @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': + def value_of(cls, value: str) -> "CreatedByRole": """ Get value of given mode. @@ -26,4 +26,4 @@ class CreatedByRole(Enum): for role in cls: if role.value == value: return role - raise ValueError(f'invalid createdByRole value {value}') + raise ValueError(f"invalid createdByRole value {value}") diff --git a/api/models/account.py b/api/models/account.py index 67d940b7b7..60b4f11aad 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -9,21 +9,18 @@ from .types import StringUUID class AccountStatus(str, enum.Enum): - PENDING = 'pending' - UNINITIALIZED = 'uninitialized' - ACTIVE = 'active' - BANNED = 'banned' - CLOSED = 'closed' + PENDING = "pending" + UNINITIALIZED = "uninitialized" + ACTIVE = "active" + BANNED = "banned" + CLOSED = "closed" class Account(UserMixin, db.Model): - __tablename__ = 'accounts' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_pkey'), - db.Index('account_email_idx', 'email') - ) + __tablename__ = "accounts" + __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) @@ -34,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def is_password_set(self): @@ -65,11 +62,13 @@ class Account(UserMixin, db.Model): @current_tenant_id.setter def current_tenant_id(self, value: str): try: - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == value) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.account_id == self.id) \ + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == value) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.account_id == self.id) .one_or_none() + ) if tenant_account_join: tenant, ta = tenant_account_join @@ -91,20 +90,18 @@ class Account(UserMixin, db.Model): @classmethod def get_by_openid(cls, provider: str, open_id: str) -> db.Model: - account_integrate = db.session.query(AccountIntegrate). \ - filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id). \ - one_or_none() + account_integrate = ( + db.session.query(AccountIntegrate) + .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + .one_or_none() + ) if account_integrate: - return db.session.query(Account). \ - filter(Account.id == account_integrate.account_id). \ - one_or_none() + return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return None def get_integrates(self) -> list[db.Model]: ai = db.Model - return db.session.query(ai).filter( - ai.account_id == self.id - ).all() + return db.session.query(ai).filter(ai.account_id == self.id).all() # check current_user.current_tenant.current_role in ['admin', 'owner'] @property @@ -123,61 +120,75 @@ class Account(UserMixin, db.Model): def is_dataset_operator(self): return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR + class TenantStatus(str, enum.Enum): - NORMAL = 'normal' - ARCHIVE = 'archive' + NORMAL = "normal" + ARCHIVE = "archive" class TenantAccountRole(str, enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - EDITOR = 'editor' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + EDITOR = "editor" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" @staticmethod def is_valid_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.NORMAL, TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } @staticmethod def is_privileged_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} - + @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, - TenantAccountRole.DATASET_OPERATOR} - + return role and role in { + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + @staticmethod def is_editing_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.DATASET_OPERATOR, + } + class Tenant(db.Model): - __tablename__ = 'tenants' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_pkey'), - ) + __tablename__ = "tenants" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def get_accounts(self) -> list[Account]: - return db.session.query(Account).filter( - Account.id == TenantAccountJoin.account_id, - TenantAccountJoin.tenant_id == self.id - ).all() + return ( + db.session.query(Account) + .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) + .all() + ) @property def custom_config_dict(self) -> dict: @@ -189,54 +200,54 @@ class Tenant(db.Model): class TenantAccountJoinRole(enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" class TenantAccountJoin(db.Model): - __tablename__ = 'tenant_account_joins' + __tablename__ = "tenant_account_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'), - db.Index('tenant_account_join_account_id_idx', 'account_id'), - db.Index('tenant_account_join_tenant_id_idx', 'tenant_id'), - db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') + db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + db.Index("tenant_account_join_account_id_idx", "account_id"), + db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) - current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - role = db.Column(db.String(16), nullable=False, server_default='normal') + current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class AccountIntegrate(db.Model): - __tablename__ = 'account_integrates' + __tablename__ = "account_integrates" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_integrate_pkey'), - db.UniqueConstraint('account_id', 'provider', name='unique_account_provider'), - db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') + db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class InvitationCode(db.Model): - __tablename__ = 'invitation_codes' + __tablename__ = "invitation_codes" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='invitation_code_pkey'), - db.Index('invitation_codes_batch_idx', 'batch'), - db.Index('invitation_codes_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + db.Index("invitation_codes_batch_idx", "batch"), + db.Index("invitation_codes_code_idx", "code", "status"), ) id = db.Column(db.Integer, nullable=False) @@ -247,4 +258,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 7f69323628..97173747af 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -6,22 +6,22 @@ from .types import StringUUID class APIBasedExtensionPoint(enum.Enum): - APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' - PING = 'ping' - APP_MODERATION_INPUT = 'app.moderation.input' - APP_MODERATION_OUTPUT = 'app.moderation.output' + APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" + PING = "ping" + APP_MODERATION_INPUT = "app.moderation.input" + APP_MODERATION_OUTPUT = "app.moderation.output" class APIBasedExtension(db.Model): - __tablename__ = 'api_based_extensions' + __tablename__ = "api_based_extensions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'), - db.Index('api_based_extension_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/dataset.py b/api/models/dataset.py index bf3f12a2c5..a2d2a3454d 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -24,37 +24,34 @@ from .types import StringUUID class DatasetPermissionEnum(str, enum.Enum): - ONLY_ME = 'only_me' - ALL_TEAM = 'all_team_members' - PARTIAL_TEAM = 'partial_members' + ONLY_ME = "only_me" + ALL_TEAM = "all_team_members" + PARTIAL_TEAM = "partial_members" + class Dataset(db.Model): - __tablename__ = 'datasets' + __tablename__ = "datasets" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_pkey'), - db.Index('dataset_tenant_idx', 'tenant_id'), - db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="dataset_pkey"), + db.Index("dataset_tenant_idx", "tenant_id"), + db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) - INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] + INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=True) - provider = db.Column(db.String(255), nullable=False, - server_default=db.text("'vendor'::character varying")) - permission = db.Column(db.String(255), nullable=False, - server_default=db.text("'only_me'::character varying")) + provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) + permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) data_source_type = db.Column(db.String(255)) indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -62,8 +59,9 @@ class Dataset(db.Model): @property def dataset_keyword_table(self): - dataset_keyword_table = db.session.query(DatasetKeywordTable).filter( - DatasetKeywordTable.dataset_id == self.id).first() + dataset_keyword_table = ( + db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() + ) if dataset_keyword_table: return dataset_keyword_table @@ -79,13 +77,19 @@ class Dataset(db.Model): @property def latest_process_rule(self): - return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \ - .order_by(DatasetProcessRule.created_at.desc()).first() + return ( + DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) + .order_by(DatasetProcessRule.created_at.desc()) + .first() + ) @property def app_count(self): - return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id, - App.id == AppDatasetJoin.app_id).scalar() + return ( + db.session.query(func.count(AppDatasetJoin.id)) + .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) + .scalar() + ) @property def document_count(self): @@ -93,30 +97,40 @@ class Dataset(db.Model): @property def available_document_count(self): - return db.session.query(func.count(Document.id)).filter( - Document.dataset_id == self.id, - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False - ).scalar() + return ( + db.session.query(func.count(Document.id)) + .filter( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) + .scalar() + ) @property def available_segment_count(self): - return db.session.query(func.count(DocumentSegment.id)).filter( - DocumentSegment.dataset_id == self.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).scalar() + return ( + db.session.query(func.count(DocumentSegment.id)) + .filter( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .scalar() + ) @property def word_count(self): - return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ - .filter(Document.dataset_id == self.id).scalar() + return ( + Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) + .filter(Document.dataset_id == self.id) + .scalar() + ) @property def doc_form(self): - document = db.session.query(Document).filter( - Document.dataset_id == self.id).first() + document = db.session.query(Document).filter(Document.dataset_id == self.id).first() if document: return document.doc_form return None @@ -124,76 +138,68 @@ class Dataset(db.Model): @property def retrieval_model_dict(self): default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } - return self.retrieval_model if self.retrieval_model else default_retrieval_model + return self.retrieval_model or default_retrieval_model @property def tags(self): - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'knowledge' - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "knowledge", + ) + .all() + ) - return tags if tags else [] + return tags or [] @staticmethod def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") - return f'Vector_index_{normalized_dataset_id}_Node' + return f"Vector_index_{normalized_dataset_id}_Node" class DatasetProcessRule(db.Model): - __tablename__ = 'dataset_process_rules' + __tablename__ = "dataset_process_rules" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'), - db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) - mode = db.Column(db.String(255), nullable=False, - server_default=db.text("'automatic'::character varying")) + mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - MODES = ['automatic', 'custom'] - PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails'] + MODES = ["automatic", "custom"] + PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES = { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, } def to_dict(self): return { - 'id': self.id, - 'dataset_id': self.dataset_id, - 'mode': self.mode, - 'rules': self.rules_dict, - 'created_by': self.created_by, - 'created_at': self.created_at, + "id": self.id, + "dataset_id": self.dataset_id, + "mode": self.mode, + "rules": self.rules_dict, + "created_by": self.created_by, + "created_at": self.created_at, } @property @@ -205,17 +211,16 @@ class DatasetProcessRule(db.Model): class Document(db.Model): - __tablename__ = 'documents' + __tablename__ = "documents" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_pkey'), - db.Index('document_dataset_id_idx', 'dataset_id'), - db.Index('document_is_paused_idx', 'is_paused'), - db.Index('document_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_pkey"), + db.Index("document_dataset_id_idx", "dataset_id"), + db.Index("document_is_paused_idx", "is_paused"), + db.Index("document_tenant_idx", "tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) @@ -227,8 +232,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -250,7 +254,7 @@ class Document(db.Model): completed_at = db.Column(db.DateTime, nullable=True) # pause - is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) paused_by = db.Column(StringUUID, nullable=True) paused_at = db.Column(db.DateTime, nullable=True) @@ -259,44 +263,39 @@ class Document(db.Model): stopped_at = db.Column(db.DateTime, nullable=True) # basic fields - indexing_status = db.Column(db.String( - 255), nullable=False, server_default=db.text("'waiting'::character varying")) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - archived = db.Column(db.Boolean, nullable=False, - server_default=db.text('false')) + archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) - doc_form = db.Column(db.String( - 255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) doc_language = db.Column(db.String(255), nullable=True) - DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl'] + DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @property def display_status(self): status = None - if self.indexing_status == 'waiting': - status = 'queuing' - elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused: - status = 'paused' - elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']: - status = 'indexing' - elif self.indexing_status == 'error': - status = 'error' - elif self.indexing_status == 'completed' and not self.archived and self.enabled: - status = 'available' - elif self.indexing_status == 'completed' and not self.archived and not self.enabled: - status = 'disabled' - elif self.indexing_status == 'completed' and self.archived: - status = 'archived' + if self.indexing_status == "waiting": + status = "queuing" + elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused: + status = "paused" + elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}: + status = "indexing" + elif self.indexing_status == "error": + status = "error" + elif self.indexing_status == "completed" and not self.archived and self.enabled: + status = "available" + elif self.indexing_status == "completed" and not self.archived and not self.enabled: + status = "disabled" + elif self.indexing_status == "completed" and self.archived: + status = "archived" return status @property @@ -313,24 +312,26 @@ class Document(db.Model): @property def data_source_detail_dict(self): if self.data_source_info: - if self.data_source_type == 'upload_file': + if self.data_source_type == "upload_file": data_source_info_dict = json.loads(self.data_source_info) - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info_dict['upload_file_id']). \ - one_or_none() + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) + .one_or_none() + ) if file_detail: return { - 'upload_file': { - 'id': file_detail.id, - 'name': file_detail.name, - 'size': file_detail.size, - 'extension': file_detail.extension, - 'mime_type': file_detail.mime_type, - 'created_by': file_detail.created_by, - 'created_at': file_detail.created_at.timestamp() + "upload_file": { + "id": file_detail.id, + "name": file_detail.name, + "size": file_detail.size, + "extension": file_detail.extension, + "mime_type": file_detail.mime_type, + "created_by": file_detail.created_by, + "created_at": file_detail.created_at.timestamp(), } } - elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl': + elif self.data_source_type in {"notion_import", "website_crawl"}: return json.loads(self.data_source_info) return {} @@ -356,120 +357,123 @@ class Document(db.Model): @property def hit_count(self): - return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \ - .filter(DocumentSegment.document_id == self.id).scalar() + return ( + DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) + .filter(DocumentSegment.document_id == self.id) + .scalar() + ) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'dataset_id': self.dataset_id, - 'position': self.position, - 'data_source_type': self.data_source_type, - 'data_source_info': self.data_source_info, - 'dataset_process_rule_id': self.dataset_process_rule_id, - 'batch': self.batch, - 'name': self.name, - 'created_from': self.created_from, - 'created_by': self.created_by, - 'created_api_request_id': self.created_api_request_id, - 'created_at': self.created_at, - 'processing_started_at': self.processing_started_at, - 'file_id': self.file_id, - 'word_count': self.word_count, - 'parsing_completed_at': self.parsing_completed_at, - 'cleaning_completed_at': self.cleaning_completed_at, - 'splitting_completed_at': self.splitting_completed_at, - 'tokens': self.tokens, - 'indexing_latency': self.indexing_latency, - 'completed_at': self.completed_at, - 'is_paused': self.is_paused, - 'paused_by': self.paused_by, - 'paused_at': self.paused_at, - 'error': self.error, - 'stopped_at': self.stopped_at, - 'indexing_status': self.indexing_status, - 'enabled': self.enabled, - 'disabled_at': self.disabled_at, - 'disabled_by': self.disabled_by, - 'archived': self.archived, - 'archived_reason': self.archived_reason, - 'archived_by': self.archived_by, - 'archived_at': self.archived_at, - 'updated_at': self.updated_at, - 'doc_type': self.doc_type, - 'doc_metadata': self.doc_metadata, - 'doc_form': self.doc_form, - 'doc_language': self.doc_language, - 'display_status': self.display_status, - 'data_source_info_dict': self.data_source_info_dict, - 'average_segment_length': self.average_segment_length, - 'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - 'dataset': self.dataset.to_dict() if self.dataset else None, - 'segment_count': self.segment_count, - 'hit_count': self.hit_count + "id": self.id, + "tenant_id": self.tenant_id, + "dataset_id": self.dataset_id, + "position": self.position, + "data_source_type": self.data_source_type, + "data_source_info": self.data_source_info, + "dataset_process_rule_id": self.dataset_process_rule_id, + "batch": self.batch, + "name": self.name, + "created_from": self.created_from, + "created_by": self.created_by, + "created_api_request_id": self.created_api_request_id, + "created_at": self.created_at, + "processing_started_at": self.processing_started_at, + "file_id": self.file_id, + "word_count": self.word_count, + "parsing_completed_at": self.parsing_completed_at, + "cleaning_completed_at": self.cleaning_completed_at, + "splitting_completed_at": self.splitting_completed_at, + "tokens": self.tokens, + "indexing_latency": self.indexing_latency, + "completed_at": self.completed_at, + "is_paused": self.is_paused, + "paused_by": self.paused_by, + "paused_at": self.paused_at, + "error": self.error, + "stopped_at": self.stopped_at, + "indexing_status": self.indexing_status, + "enabled": self.enabled, + "disabled_at": self.disabled_at, + "disabled_by": self.disabled_by, + "archived": self.archived, + "archived_reason": self.archived_reason, + "archived_by": self.archived_by, + "archived_at": self.archived_at, + "updated_at": self.updated_at, + "doc_type": self.doc_type, + "doc_metadata": self.doc_metadata, + "doc_form": self.doc_form, + "doc_language": self.doc_language, + "display_status": self.display_status, + "data_source_info_dict": self.data_source_info_dict, + "average_segment_length": self.average_segment_length, + "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, + "dataset": self.dataset.to_dict() if self.dataset else None, + "segment_count": self.segment_count, + "hit_count": self.hit_count, } @classmethod def from_dict(cls, data: dict): return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - dataset_id=data.get('dataset_id'), - position=data.get('position'), - data_source_type=data.get('data_source_type'), - data_source_info=data.get('data_source_info'), - dataset_process_rule_id=data.get('dataset_process_rule_id'), - batch=data.get('batch'), - name=data.get('name'), - created_from=data.get('created_from'), - created_by=data.get('created_by'), - created_api_request_id=data.get('created_api_request_id'), - created_at=data.get('created_at'), - processing_started_at=data.get('processing_started_at'), - file_id=data.get('file_id'), - word_count=data.get('word_count'), - parsing_completed_at=data.get('parsing_completed_at'), - cleaning_completed_at=data.get('cleaning_completed_at'), - splitting_completed_at=data.get('splitting_completed_at'), - tokens=data.get('tokens'), - indexing_latency=data.get('indexing_latency'), - completed_at=data.get('completed_at'), - is_paused=data.get('is_paused'), - paused_by=data.get('paused_by'), - paused_at=data.get('paused_at'), - error=data.get('error'), - stopped_at=data.get('stopped_at'), - indexing_status=data.get('indexing_status'), - enabled=data.get('enabled'), - disabled_at=data.get('disabled_at'), - disabled_by=data.get('disabled_by'), - archived=data.get('archived'), - archived_reason=data.get('archived_reason'), - archived_by=data.get('archived_by'), - archived_at=data.get('archived_at'), - updated_at=data.get('updated_at'), - doc_type=data.get('doc_type'), - doc_metadata=data.get('doc_metadata'), - doc_form=data.get('doc_form'), - doc_language=data.get('doc_language') + id=data.get("id"), + tenant_id=data.get("tenant_id"), + dataset_id=data.get("dataset_id"), + position=data.get("position"), + data_source_type=data.get("data_source_type"), + data_source_info=data.get("data_source_info"), + dataset_process_rule_id=data.get("dataset_process_rule_id"), + batch=data.get("batch"), + name=data.get("name"), + created_from=data.get("created_from"), + created_by=data.get("created_by"), + created_api_request_id=data.get("created_api_request_id"), + created_at=data.get("created_at"), + processing_started_at=data.get("processing_started_at"), + file_id=data.get("file_id"), + word_count=data.get("word_count"), + parsing_completed_at=data.get("parsing_completed_at"), + cleaning_completed_at=data.get("cleaning_completed_at"), + splitting_completed_at=data.get("splitting_completed_at"), + tokens=data.get("tokens"), + indexing_latency=data.get("indexing_latency"), + completed_at=data.get("completed_at"), + is_paused=data.get("is_paused"), + paused_by=data.get("paused_by"), + paused_at=data.get("paused_at"), + error=data.get("error"), + stopped_at=data.get("stopped_at"), + indexing_status=data.get("indexing_status"), + enabled=data.get("enabled"), + disabled_at=data.get("disabled_at"), + disabled_by=data.get("disabled_by"), + archived=data.get("archived"), + archived_reason=data.get("archived_reason"), + archived_by=data.get("archived_by"), + archived_at=data.get("archived_at"), + updated_at=data.get("updated_at"), + doc_type=data.get("doc_type"), + doc_metadata=data.get("doc_metadata"), + doc_form=data.get("doc_form"), + doc_language=data.get("doc_language"), ) + class DocumentSegment(db.Model): - __tablename__ = 'document_segments' + __tablename__ = "document_segments" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_segment_pkey'), - db.Index('document_segment_dataset_id_idx', 'dataset_id'), - db.Index('document_segment_document_id_idx', 'document_id'), - db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'), - db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'), - db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'), - db.Index('document_segment_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_segment_pkey"), + db.Index("document_segment_dataset_id_idx", "dataset_id"), + db.Index("document_segment_document_id_idx", "document_id"), + db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"), + db.Index("document_segment_tenant_idx", "tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False) @@ -486,18 +490,14 @@ class DocumentSegment(db.Model): # basic fields hit_count = db.Column(db.Integer, nullable=False, default=0) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, - server_default=db.text("'waiting'::character varying")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -513,17 +513,19 @@ class DocumentSegment(db.Model): @property def previous_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position - 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) + .first() + ) @property def next_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position + 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) + .first() + ) def get_sign_content(self): pattern = r"/files/([a-f0-9\-]+)/image-preview" @@ -535,7 +537,7 @@ class DocumentSegment(db.Model): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -546,21 +548,20 @@ class DocumentSegment(db.Model): # Reconstruct the text with signed URLs offset = 0 for start, end, signed_url in signed_urls: - text = text[:start + offset] + signed_url + text[end + offset:] + text = text[: start + offset] + signed_url + text[end + offset :] offset += len(signed_url) - (end - start) return text - class AppDatasetJoin(db.Model): - __tablename__ = 'app_dataset_joins' + __tablename__ = "app_dataset_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'), - db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -571,13 +572,13 @@ class AppDatasetJoin(db.Model): class DatasetQuery(db.Model): - __tablename__ = 'dataset_queries' + __tablename__ = "dataset_queries" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_query_pkey'), - db.Index('dataset_query_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) content = db.Column(db.Text, nullable=False) source = db.Column(db.String(255), nullable=False) @@ -588,17 +589,18 @@ class DatasetQuery(db.Model): class DatasetKeywordTable(db.Model): - __tablename__ = 'dataset_keyword_tables' + __tablename__ = "dataset_keyword_tables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), - db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False, unique=True) keyword_table = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.String(255), nullable=False, - server_default=db.text("'database'::character varying")) + data_source_type = db.Column( + db.String(255), nullable=False, server_default=db.text("'database'::character varying") + ) @property def keyword_table_dict(self): @@ -614,19 +616,17 @@ class DatasetKeywordTable(db.Model): return dct # get dataset - dataset = Dataset.query.filter_by( - id=self.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=self.dataset_id).first() if not dataset: return None - if self.data_source_type == 'database': + if self.data_source_type == "database": return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None else: - file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt' + file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt" try: keyword_table_text = storage.load_once(file_key) if keyword_table_text: - return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder) + return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder) return None except Exception as e: logging.exception(str(e)) @@ -634,21 +634,21 @@ class DatasetKeywordTable(db.Model): class Embedding(db.Model): - __tablename__ = 'embeddings' + __tablename__ = "embeddings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='embedding_pkey'), - db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx'), - db.Index('created_at_idx', 'created_at') + db.PrimaryKeyConstraint("id", name="embedding_pkey"), + db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) - model_name = db.Column(db.String(255), nullable=False, - server_default=db.text("'text-embedding-ada-002'::character varying")) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + model_name = db.Column( + db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - provider_name = db.Column(db.String(255), nullable=False, - server_default=db.text("''::character varying")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -658,33 +658,32 @@ class Embedding(db.Model): class DatasetCollectionBinding(db.Model): - __tablename__ = 'dataset_collection_bindings' + __tablename__ = "dataset_collection_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'), - db.Index('provider_model_name_idx', 'provider_name', 'model_name') - + db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class DatasetPermission(db.Model): - __tablename__ = 'dataset_permissions' + __tablename__ = "dataset_permissions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'), - db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'), - db.Index('idx_dataset_permissions_account_id', 'account_id'), - db.Index('idx_dataset_permissions_tenant_id', 'tenant_id') + db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + db.Index("idx_dataset_permissions_account_id", "account_id"), + db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) dataset_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) - has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/model.py b/api/models/model.py index 3eec2539d3..f4e7686849 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -21,25 +21,23 @@ from .types import StringUUID class DifySetup(db.Model): - __tablename__ = 'dify_setups' - __table_args__ = ( - db.PrimaryKeyConstraint('version', name='dify_setup_pkey'), - ) + __tablename__ = "dify_setups" + __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class AppMode(Enum): - COMPLETION = 'completion' - WORKFLOW = 'workflow' - CHAT = 'chat' - ADVANCED_CHAT = 'advanced-chat' - AGENT_CHAT = 'agent-chat' - CHANNEL = 'channel' + COMPLETION = "completion" + WORKFLOW = "workflow" + CHAT = "chat" + ADVANCED_CHAT = "advanced-chat" + AGENT_CHAT = "agent-chat" + CHANNEL = "channel" @classmethod - def value_of(cls, value: str) -> 'AppMode': + def value_of(cls, value: str) -> "AppMode": """ Get value of given mode. @@ -49,21 +47,19 @@ class AppMode(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class IconType(Enum): IMAGE = "image" EMOJI = "emoji" -class App(db.Model): - __tablename__ = 'apps' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_pkey'), - db.Index('app_tenant_id_idx', 'tenant_id') - ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) +class App(db.Model): + __tablename__ = "apps" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) @@ -76,17 +72,17 @@ class App(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - api_rph = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -98,7 +94,7 @@ class App(db.Model): if app_model_config: return app_model_config.pre_prompt else: - return '' + return "" @property def site(self): @@ -106,24 +102,24 @@ class App(db.Model): return site @property - def app_model_config(self) -> Optional['AppModelConfig']: + def app_model_config(self) -> Optional["AppModelConfig"]: if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() return None @property - def workflow(self) -> Optional['Workflow']: + def workflow(self) -> Optional["Workflow"]: if self.workflow_id: from .workflow import Workflow + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() return None @property def api_base_url(self): - return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" @property def tenant(self): @@ -137,8 +133,9 @@ class App(db.Model): return False if not app_model_config.agent_mode: return False - if self.app_model_config.agent_mode_dict.get('enabled', False) \ - and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( + "strategy", "" + ) in {"function_call", "react"}: self.mode = AppMode.AGENT_CHAT.value db.session.commit() return True @@ -160,16 +157,16 @@ class App(db.Model): if not app_model_config.agent_mode: return [] agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) + tools = agent_mode.get("tools", []) provider_ids = [] for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api': + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api": # check if provider id is a uuid string, if not, skip try: uuid.UUID(provider_id) @@ -181,8 +178,7 @@ class App(db.Model): return [] api_providers = db.session.execute( - text('SELECT id FROM tool_api_providers WHERE id IN :provider_ids'), - {'provider_ids': tuple(provider_ids)} + text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)} ).fetchall() deleted_tools = [] @@ -191,44 +187,43 @@ class App(db.Model): for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api' and provider_id not in current_api_provider_ids: - deleted_tools.append(tool['tool_name']) + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api" and provider_id not in current_api_provider_ids: + deleted_tools.append(tool["tool_name"]) return deleted_tools @property def tags(self): - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'app' - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "app", + ) + .all() + ) - return tags if tags else [] + return tags or [] class AppModelConfig(Base): - __tablename__ = 'app_model_configs' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_model_config_pkey'), - db.Index('app_app_id_idx', 'app_id') - ) + __tablename__ = "app_model_configs" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -264,28 +259,29 @@ class AppModelConfig(Base): @property def suggested_questions_after_answer_dict(self) -> dict: - return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \ + return ( + json.loads(self.suggested_questions_after_answer) + if self.suggested_questions_after_answer else {"enabled": False} + ) @property def speech_to_text_dict(self) -> dict: - return json.loads(self.speech_to_text) if self.speech_to_text \ - else {"enabled": False} + return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} @property def text_to_speech_dict(self) -> dict: - return json.loads(self.text_to_speech) if self.text_to_speech \ - else {"enabled": False} + return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} @property def retriever_resource_dict(self) -> dict: - return json.loads(self.retriever_resource) if self.retriever_resource \ - else {"enabled": True} + return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} @property def annotation_reply_dict(self) -> dict: - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == self.app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -294,8 +290,8 @@ class AppModelConfig(Base): "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } else: @@ -307,13 +303,15 @@ class AppModelConfig(Base): @property def sensitive_word_avoidance_dict(self) -> dict: - return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \ + return ( + json.loads(self.sensitive_word_avoidance) + if self.sensitive_word_avoidance else {"enabled": False, "type": "", "configs": []} + ) @property def external_data_tools_list(self) -> list[dict]: - return json.loads(self.external_data_tools) if self.external_data_tools \ - else [] + return json.loads(self.external_data_tools) if self.external_data_tools else [] @property def user_input_form_list(self) -> dict: @@ -321,8 +319,11 @@ class AppModelConfig(Base): @property def agent_mode_dict(self) -> dict: - return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], - "prompt": None} + return ( + json.loads(self.agent_mode) + if self.agent_mode + else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + ) @property def chat_prompt_config_dict(self) -> dict: @@ -336,19 +337,28 @@ class AppModelConfig(Base): def dataset_configs_dict(self) -> dict: if self.dataset_configs: dataset_configs = json.loads(self.dataset_configs) - if 'retrieval_model' not in dataset_configs: - return {'retrieval_model': 'single'} + if "retrieval_model" not in dataset_configs: + return {"retrieval_model": "single"} else: return dataset_configs return { - 'retrieval_model': 'multiple', - } + "retrieval_model": "multiple", + } @property def file_upload_dict(self) -> dict: - return json.loads(self.file_upload) if self.file_upload else { - "image": {"enabled": False, "number_limits": 3, "detail": "high", - "transfer_methods": ["remote_url", "local_file"]}} + return ( + json.loads(self.file_upload) + if self.file_upload + else { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + } + ) def to_dict(self) -> dict: return { @@ -371,44 +381,53 @@ class AppModelConfig(Base): "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, "dataset_configs": self.dataset_configs_dict, - "file_upload": self.file_upload_dict + "file_upload": self.file_upload_dict, } def from_model_config_dict(self, model_config: dict): - self.opening_statement = model_config.get('opening_statement') - self.suggested_questions = json.dumps(model_config['suggested_questions']) \ - if model_config.get('suggested_questions') else None - self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ - if model_config.get('suggested_questions_after_answer') else None - self.speech_to_text = json.dumps(model_config['speech_to_text']) \ - if model_config.get('speech_to_text') else None - self.text_to_speech = json.dumps(model_config['text_to_speech']) \ - if model_config.get('text_to_speech') else None - self.more_like_this = json.dumps(model_config['more_like_this']) \ - if model_config.get('more_like_this') else None - self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ - if model_config.get('sensitive_word_avoidance') else None - self.external_data_tools = json.dumps(model_config['external_data_tools']) \ - if model_config.get('external_data_tools') else None - self.model = json.dumps(model_config['model']) \ - if model_config.get('model') else None - self.user_input_form = json.dumps(model_config['user_input_form']) \ - if model_config.get('user_input_form') else None - self.dataset_query_variable = model_config.get('dataset_query_variable') - self.pre_prompt = model_config['pre_prompt'] - self.agent_mode = json.dumps(model_config['agent_mode']) \ - if model_config.get('agent_mode') else None - self.retriever_resource = json.dumps(model_config['retriever_resource']) \ - if model_config.get('retriever_resource') else None - self.prompt_type = model_config.get('prompt_type', 'simple') - self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \ - if model_config.get('chat_prompt_config') else None - self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \ - if model_config.get('completion_prompt_config') else None - self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \ - if model_config.get('dataset_configs') else None - self.file_upload = json.dumps(model_config.get('file_upload')) \ - if model_config.get('file_upload') else None + self.opening_statement = model_config.get("opening_statement") + self.suggested_questions = ( + json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + ) + self.suggested_questions_after_answer = ( + json.dumps(model_config["suggested_questions_after_answer"]) + if model_config.get("suggested_questions_after_answer") + else None + ) + self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None + self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None + self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.sensitive_word_avoidance = ( + json.dumps(model_config["sensitive_word_avoidance"]) + if model_config.get("sensitive_word_avoidance") + else None + ) + self.external_data_tools = ( + json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + ) + self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.user_input_form = ( + json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + ) + self.dataset_query_variable = model_config.get("dataset_query_variable") + self.pre_prompt = model_config["pre_prompt"] + self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.retriever_resource = ( + json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + ) + self.prompt_type = model_config.get("prompt_type", "simple") + self.chat_prompt_config = ( + json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None + ) + self.completion_prompt_config = ( + json.dumps(model_config.get("completion_prompt_config")) + if model_config.get("completion_prompt_config") + else None + ) + self.dataset_configs = ( + json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None + ) + self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None return self def copy(self): @@ -433,21 +452,21 @@ class AppModelConfig(Base): chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload + file_upload=self.file_upload, ) return new_app_model_config class RecommendedApp(db.Model): - __tablename__ = 'recommended_apps' + __tablename__ = "recommended_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='recommended_app_pkey'), - db.Index('recommended_app_app_id_idx', 'app_id'), - db.Index('recommended_app_is_listed_idx', 'is_listed', 'language') + db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + db.Index("recommended_app_app_id_idx", "app_id"), + db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) @@ -458,8 +477,8 @@ class RecommendedApp(db.Model): is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -468,22 +487,22 @@ class RecommendedApp(db.Model): class InstalledApp(db.Model): - __tablename__ = 'installed_apps' + __tablename__ = "installed_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='installed_app_pkey'), - db.Index('installed_app_tenant_id_idx', 'tenant_id'), - db.Index('installed_app_app_id_idx', 'app_id'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') + db.PrimaryKeyConstraint("id", name="installed_app_pkey"), + db.Index("installed_app_tenant_id_idx", "tenant_id"), + db.Index("installed_app_app_id_idx", "app_id"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) app_owner_tenant_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False, default=0) - is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -497,13 +516,13 @@ class InstalledApp(db.Model): class Conversation(Base): - __tablename__ = 'conversations' + __tablename__ = "conversations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='conversation_pkey'), - db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') + db.PrimaryKeyConstraint("id", name="conversation_pkey"), + db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) @@ -515,7 +534,7 @@ class Conversation(Base): inputs = db.Column(db.JSON) introduction = db.Column(db.Text) system_instruction = db.Column(db.Text) - system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) @@ -524,13 +543,17 @@ class Conversation(Base): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - messages: Mapped[list["Message"]] = relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations: Mapped[list["MessageAnnotation"]] = relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") + messages: Mapped[list["Message"]] = relationship( + "Message", backref="conversation", lazy="select", passive_deletes="all" + ) + message_annotations: Mapped[list["MessageAnnotation"]] = relationship( + "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" + ) - is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property def model_config(self): @@ -543,23 +566,24 @@ class Conversation(Base): if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - if 'model' in override_model_configs: + if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) model_config = app_model_config.to_dict() else: - model_config['configs'] = override_model_configs + model_config["configs"] = override_model_configs else: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == self.app_model_config_id).first() - + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + ) + if not app_model_config: raise ValueError("app config not found") model_config = app_model_config.to_dict() - model_config['model_id'] = self.model_id - model_config['provider'] = self.model_provider + model_config["model_id"] = self.model_id + model_config["provider"] = self.model_provider return model_config @@ -572,7 +596,7 @@ class Conversation(Base): if first_message: return first_message.query else: - return '' + return "" @property def annotated(self): @@ -588,31 +612,51 @@ class Conversation(Base): @property def user_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def admin_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def first_message(self): @@ -646,33 +690,33 @@ class Conversation(Base): class Message(Base): - __tablename__ = 'messages' + __tablename__ = "messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_pkey'), - db.Index('message_app_id_idx', 'app_id', 'created_at'), - db.Index('message_conversation_id_idx', 'conversation_id'), - db.Index('message_end_user_idx', 'app_id', 'from_source', 'from_end_user_id'), - db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'), - db.Index('message_workflow_run_id_idx', 'conversation_id', 'workflow_run_id') + db.PrimaryKeyConstraint("id", name="message_pkey"), + db.Index("message_app_id_idx", "app_id", "created_at"), + db.Index("message_conversation_id_idx", "conversation_id"), + db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), + db.Index("message_account_idx", "app_id", "from_source", "from_account_id"), + db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) inputs: Mapped[str] = mapped_column(db.JSON) query: Mapped[str] = mapped_column(db.Text, nullable=False) message: Mapped[str] = mapped_column(db.JSON, nullable=False) - message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) answer = db.Column(db.Text, nullable=False) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) @@ -682,9 +726,9 @@ class Message(Base): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @property @@ -692,7 +736,7 @@ class Message(Base): if not self.answer: return self.answer - pattern = r'\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)' + pattern = r"\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)" matches = re.findall(pattern, self.answer) if not matches: @@ -708,9 +752,9 @@ class Message(Base): re_sign_file_url_answer = self.answer for url in urls: - if 'files/tools' in url: + if "files/tools" in url: # get tool file id - tool_file_id_pattern = r'\/files\/tools\/([\.\w-]+)?\?timestamp=' + tool_file_id_pattern = r"\/files\/tools\/([\.\w-]+)?\?timestamp=" result = re.search(tool_file_id_pattern, url) if not result: continue @@ -718,25 +762,24 @@ class Message(Base): tool_file_id = result.group(1) # get extension - if '.' in tool_file_id: - split_result = tool_file_id.split('.') - extension = f'.{split_result[-1]}' + if "." in tool_file_id: + split_result = tool_file_id.split(".") + extension = f".{split_result[-1]}" if len(extension) > 10: - extension = '.bin' + extension = ".bin" tool_file_id = split_result[0] else: - extension = '.bin' + extension = ".bin" if not tool_file_id: continue sign_url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, - extension=extension + tool_file_id=tool_file_id, extension=extension ) else: # get upload file id - upload_file_id_pattern = r'\/files\/([\w-]+)\/image-preview?\?timestamp=' + upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" result = re.search(upload_file_id_pattern, url) if not result: continue @@ -754,14 +797,20 @@ class Message(Base): @property def user_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'user').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") + .first() + ) return feedback @property def admin_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'admin').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") + .first() + ) return feedback @property @@ -776,11 +825,15 @@ class Message(Base): @property def annotation_hit_history(self): - annotation_history = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.message_id == self.id).first()) + annotation_history = ( + db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first() + ) if annotation_history: - annotation = (db.session.query(MessageAnnotation). - filter(MessageAnnotation.id == annotation_history.annotation_id).first()) + annotation = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.id == annotation_history.annotation_id) + .first() + ) return annotation return None @@ -788,8 +841,9 @@ class Message(Base): def app_model_config(self): conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() if conversation: - return db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id).first() + return ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() + ) return None @@ -803,13 +857,21 @@ class Message(Base): @property def agent_thoughts(self): - return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ - .order_by(MessageAgentThought.position.asc()).all() + return ( + db.session.query(MessageAgentThought) + .filter(MessageAgentThought.message_id == self.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) @property def retriever_resources(self): - return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \ - .order_by(DatasetRetrieverResource.position.asc()).all() + return ( + db.session.query(DatasetRetrieverResource) + .filter(DatasetRetrieverResource.message_id == self.id) + .order_by(DatasetRetrieverResource.position.asc()) + .all() + ) @property def message_files(self): @@ -822,39 +884,39 @@ class Message(Base): files = [] for message_file in message_files: url = message_file.url - if message_file.type == 'image': - if message_file.transfer_method == 'local_file': - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == message_file.upload_file_id - ).first()) - - url = UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=True + if message_file.type == "image": + if message_file.transfer_method == "local_file": + upload_file = ( + db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first() ) - if message_file.transfer_method == 'tool_file': + + url = UploadFileParser.get_image_data(upload_file=upload_file, force_url=True) + if message_file.transfer_method == "tool_file": # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url - url = ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=tool_file_id, extension=extension) + url = ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=tool_file_id, extension=extension + ) - files.append({ - 'id': message_file.id, - 'type': message_file.type, - 'url': url, - 'belongs_to': message_file.belongs_to if message_file.belongs_to else 'user' - }) + files.append( + { + "id": message_file.id, + "type": message_file.type, + "url": url, + "belongs_to": message_file.belongs_to or "user", + } + ) return files @@ -862,64 +924,65 @@ class Message(Base): def workflow_run(self): if self.workflow_run_id: from .workflow import WorkflowRun + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return None def to_dict(self) -> dict: return { - 'id': self.id, - 'app_id': self.app_id, - 'conversation_id': self.conversation_id, - 'inputs': self.inputs, - 'query': self.query, - 'message': self.message, - 'answer': self.answer, - 'status': self.status, - 'error': self.error, - 'message_metadata': self.message_metadata_dict, - 'from_source': self.from_source, - 'from_end_user_id': self.from_end_user_id, - 'from_account_id': self.from_account_id, - 'created_at': self.created_at.isoformat(), - 'updated_at': self.updated_at.isoformat(), - 'agent_based': self.agent_based, - 'workflow_run_id': self.workflow_run_id + "id": self.id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "inputs": self.inputs, + "query": self.query, + "message": self.message, + "answer": self.answer, + "status": self.status, + "error": self.error, + "message_metadata": self.message_metadata_dict, + "from_source": self.from_source, + "from_end_user_id": self.from_end_user_id, + "from_account_id": self.from_account_id, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + "agent_based": self.agent_based, + "workflow_run_id": self.workflow_run_id, } @classmethod def from_dict(cls, data: dict): return cls( - id=data['id'], - app_id=data['app_id'], - conversation_id=data['conversation_id'], - inputs=data['inputs'], - query=data['query'], - message=data['message'], - answer=data['answer'], - status=data['status'], - error=data['error'], - message_metadata=json.dumps(data['message_metadata']), - from_source=data['from_source'], - from_end_user_id=data['from_end_user_id'], - from_account_id=data['from_account_id'], - created_at=data['created_at'], - updated_at=data['updated_at'], - agent_based=data['agent_based'], - workflow_run_id=data['workflow_run_id'] + id=data["id"], + app_id=data["app_id"], + conversation_id=data["conversation_id"], + inputs=data["inputs"], + query=data["query"], + message=data["message"], + answer=data["answer"], + status=data["status"], + error=data["error"], + message_metadata=json.dumps(data["message_metadata"]), + from_source=data["from_source"], + from_end_user_id=data["from_end_user_id"], + from_account_id=data["from_account_id"], + created_at=data["created_at"], + updated_at=data["updated_at"], + agent_based=data["agent_based"], + workflow_run_id=data["workflow_run_id"], ) class MessageFeedback(db.Model): - __tablename__ = 'message_feedbacks' + __tablename__ = "message_feedbacks" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_feedback_pkey'), - db.Index('message_feedback_app_idx', 'app_id'), - db.Index('message_feedback_message_idx', 'message_id', 'from_source'), - db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating') + db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + db.Index("message_feedback_app_idx", "app_id"), + db.Index("message_feedback_message_idx", "message_id", "from_source"), + db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) @@ -928,8 +991,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def from_account(self): @@ -938,14 +1001,14 @@ class MessageFeedback(db.Model): class MessageFile(Base): - __tablename__ = 'message_files' + __tablename__ = "message_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_file_pkey'), - db.Index('message_file_message_idx', 'message_id'), - db.Index('message_file_created_by_idx', 'created_by') + db.PrimaryKeyConstraint("id", name="message_file_pkey"), + db.Index("message_file_message_idx", "message_id"), + db.Index("message_file_created_by_idx", "created_by"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -954,28 +1017,28 @@ class MessageFile(Base): upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True) created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class MessageAnnotation(Base): - __tablename__ = 'message_annotations' + __tablename__ = "message_annotations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_annotation_pkey'), - db.Index('message_annotation_app_idx', 'app_id'), - db.Index('message_annotation_conversation_idx', 'conversation_id'), - db.Index('message_annotation_message_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + db.Index("message_annotation_app_idx", "app_id"), + db.Index("message_annotation_conversation_idx", "conversation_id"), + db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) + conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) message_id = db.Column(StringUUID, nullable=True) question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def account(self): @@ -989,32 +1052,35 @@ class MessageAnnotation(Base): class AppAnnotationHitHistory(db.Model): - __tablename__ = 'app_annotation_hit_histories' + __tablename__ = "app_annotation_hit_histories" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'), - db.Index('app_annotation_hit_histories_app_idx', 'app_id'), - db.Index('app_annotation_hit_histories_account_idx', 'account_id'), - db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'), - db.Index('app_annotation_hit_histories_message_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + db.Index("app_annotation_hit_histories_app_idx", "app_id"), + db.Index("app_annotation_hit_histories_account_idx", "account_id"), + db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) annotation_id = db.Column(StringUUID, nullable=False) source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - score = db.Column(Float, nullable=False, server_default=db.text('0')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) annotation_content = db.Column(db.Text, nullable=False) @property def account(self): - account = (db.session.query(Account) - .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) - .filter(MessageAnnotation.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) + .filter(MessageAnnotation.id == self.annotation_id) + .first() + ) return account @property @@ -1024,89 +1090,99 @@ class AppAnnotationHitHistory(db.Model): class AppAnnotationSetting(db.Model): - __tablename__ = 'app_annotation_settings' + __tablename__ = "app_annotation_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'), - db.Index('app_annotation_settings_app_idx', 'app_id') + db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) + score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def created_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def updated_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def collection_binding_detail(self): from .dataset import DatasetCollectionBinding - collection_binding_detail = (db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == self.collection_binding_id).first()) + + collection_binding_detail = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == self.collection_binding_id) + .first() + ) return collection_binding_detail class OperationLog(db.Model): - __tablename__ = 'operation_logs' + __tablename__ = "operation_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='operation_log_pkey'), - db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action') + db.PrimaryKeyConstraint("id", name="operation_log_pkey"), + db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class EndUser(UserMixin, db.Model): - __tablename__ = 'end_users' + __tablename__ = "end_users" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='end_user_pkey'), - db.Index('end_user_session_id_idx', 'session_id', 'type'), - db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'), + db.PrimaryKeyConstraint("id", name="end_user_pkey"), + db.Index("end_user_session_id_idx", "session_id", "type"), + db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) - is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class Site(db.Model): - __tablename__ = 'sites' + __tablename__ = "sites" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='site_pkey'), - db.Index('site_app_id_idx', 'app_id'), - db.Index('site_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="site_pkey"), + db.Index("site_app_id_idx", "app_id"), + db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) icon_type = db.Column(db.String(255), nullable=True) @@ -1115,20 +1191,20 @@ class Site(db.Model): description = db.Column(db.Text) default_language = db.Column(db.String(255), nullable=False) chat_color_theme = db.Column(db.String(255)) - chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) copyright = db.Column(db.String(255)) privacy_policy = db.Column(db.String(255)) - show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) custom_disclaimer = db.Column(db.String(255), nullable=True) customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) - prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) code = db.Column(db.String(255)) @staticmethod @@ -1142,26 +1218,25 @@ class Site(db.Model): @property def app_base_url(self): - return ( - dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip('/')) + return dify_config.APP_WEB_URL or request.url_root.rstrip("/") class ApiToken(db.Model): - __tablename__ = 'api_tokens' + __tablename__ = "api_tokens" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_token_pkey'), - db.Index('api_token_app_id_type_idx', 'app_id', 'type'), - db.Index('api_token_token_idx', 'token', 'type'), - db.Index('api_token_tenant_idx', 'tenant_id', 'type') + db.PrimaryKeyConstraint("id", name="api_token_pkey"), + db.Index("api_token_app_id_type_idx", "app_id", "type"), + db.Index("api_token_token_idx", "token", "type"), + db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @staticmethod def generate_api_key(prefix, n): @@ -1174,13 +1249,13 @@ class ApiToken(db.Model): class UploadFile(db.Model): - __tablename__ = 'upload_files' + __tablename__ = "upload_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='upload_file_pkey'), - db.Index('upload_file_tenant_idx', 'tenant_id') + db.PrimaryKeyConstraint("id", name="upload_file_pkey"), + db.Index("upload_file_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) storage_type = db.Column(db.String(255), nullable=False) key = db.Column(db.String(255), nullable=False) @@ -1190,38 +1265,38 @@ class UploadFile(db.Model): mime_type = db.Column(db.String(255), nullable=True) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + used = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by = db.Column(StringUUID, nullable=True) used_at = db.Column(db.DateTime, nullable=True) hash = db.Column(db.String(255), nullable=True) class ApiRequest(db.Model): - __tablename__ = 'api_requests' + __tablename__ = "api_requests" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_request_pkey'), - db.Index('api_request_token_idx', 'tenant_id', 'api_token_id') + db.PrimaryKeyConstraint("id", name="api_request_pkey"), + db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) api_token_id = db.Column(StringUUID, nullable=False) path = db.Column(db.String(255), nullable=False) request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class MessageChain(db.Model): - __tablename__ = 'message_chains' + __tablename__ = "message_chains" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_chain_pkey'), - db.Index('message_chain_message_id_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_chain_pkey"), + db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) input = db.Column(db.Text, nullable=True) @@ -1230,14 +1305,14 @@ class MessageChain(db.Model): class MessageAgentThought(db.Model): - __tablename__ = 'message_agent_thoughts' + __tablename__ = "message_agent_thoughts" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_agent_thought_pkey'), - db.Index('message_agent_thought_message_id_idx', 'message_id'), - db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'), + db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + db.Index("message_agent_thought_message_id_idx", "message_id"), + db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) message_chain_id = db.Column(StringUUID, nullable=True) position = db.Column(db.Integer, nullable=False) @@ -1252,12 +1327,12 @@ class MessageAgentThought(db.Model): message = db.Column(db.Text, nullable=True) message_token = db.Column(db.Integer, nullable=True) message_unit_price = db.Column(db.Numeric, nullable=True) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_files = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) answer_token = db.Column(db.Integer, nullable=True) answer_unit_price = db.Column(db.Numeric, nullable=True) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) tokens = db.Column(db.Integer, nullable=True) total_price = db.Column(db.Numeric, nullable=True) currency = db.Column(db.String, nullable=True) @@ -1314,9 +1389,7 @@ class MessageAgentThought(db.Model): result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: return {} @@ -1337,22 +1410,20 @@ class MessageAgentThought(db.Model): result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: if self.observation: return dict.fromkeys(tools, self.observation) class DatasetRetrieverResource(db.Model): - __tablename__ = 'dataset_retriever_resources' + __tablename__ = "dataset_retriever_resources" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'), - db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) @@ -1373,57 +1444,57 @@ class DatasetRetrieverResource(db.Model): class Tag(db.Model): - __tablename__ = 'tags' + __tablename__ = "tags" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_pkey'), - db.Index('tag_type_idx', 'type'), - db.Index('tag_name_idx', 'name'), + db.PrimaryKeyConstraint("id", name="tag_pkey"), + db.Index("tag_type_idx", "type"), + db.Index("tag_name_idx", "name"), ) - TAG_TYPE_LIST = ['knowledge', 'app'] + TAG_TYPE_LIST = ["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TagBinding(db.Model): - __tablename__ = 'tag_bindings' + __tablename__ = "tag_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_binding_pkey'), - db.Index('tag_bind_target_id_idx', 'target_id'), - db.Index('tag_bind_tag_id_idx', 'tag_id'), + db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + db.Index("tag_bind_target_id_idx", "target_id"), + db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TraceAppConfig(db.Model): - __tablename__ = 'trace_app_config' + __tablename__ = "trace_app_config" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tracing_app_config_pkey'), - db.Index('trace_app_config_app_id_idx', 'app_id'), + db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), + db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) - is_active = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): - return self.tracing_config if self.tracing_config else {} + return self.tracing_config or {} @property def tracing_config_str(self): @@ -1431,11 +1502,11 @@ class TraceAppConfig(db.Model): def to_dict(self): return { - 'id': self.id, - 'app_id': self.app_id, - 'tracing_provider': self.tracing_provider, - 'tracing_config': self.tracing_config_dict, + "id": self.id, + "app_id": self.app_id, + "tracing_provider": self.tracing_provider, + "tracing_config": self.tracing_config_dict, "is_active": self.is_active, - "created_at": self.created_at.__str__() if self.created_at else None, - 'updated_at': self.updated_at.__str__() if self.updated_at else None, + "created_at": str(self.created_at) if self.created_at else None, + "updated_at": str(self.updated_at) if self.updated_at else None, } diff --git a/api/models/provider.py b/api/models/provider.py index 5d92ee6eb6..644915e781 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,8 +6,8 @@ from .types import StringUUID class ProviderType(Enum): - CUSTOM = 'custom' - SYSTEM = 'system' + CUSTOM = "custom" + SYSTEM = "system" @staticmethod def value_of(value): @@ -18,13 +18,13 @@ class ProviderType(Enum): class ProviderQuotaType(Enum): - PAID = 'paid' + PAID = "paid" """hosted paid quota""" - FREE = 'free' + FREE = "free" """third-party free quota""" - TRIAL = 'trial' + TRIAL = "trial" """hosted trial quota""" @staticmethod @@ -39,36 +39,42 @@ class Provider(db.Model): """ Provider model representing the API providers and their configurations. """ - __tablename__ = 'providers' + + __tablename__ = "providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_pkey'), - db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') + db.PrimaryKeyConstraint("id", name="provider_pkey"), + db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used = db.Column(db.DateTime, nullable=True) quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def __repr__(self): - return f"" + return ( + f"" + ) @property def token_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_config is not None @property @@ -86,118 +92,123 @@ class ProviderModel(db.Model): """ Provider model representing the API provider_models and their configurations. """ - __tablename__ = 'provider_models' + + __tablename__ = "provider_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_pkey'), - db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') + db.PrimaryKeyConstraint("id", name="provider_model_pkey"), + db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantDefaultModel(db.Model): - __tablename__ = 'tenant_default_models' + __tablename__ = "tenant_default_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'), - db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPreferredModelProvider(db.Model): - __tablename__ = 'tenant_preferred_model_providers' + __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'), - db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderOrder(db.Model): - __tablename__ = 'provider_orders' + __tablename__ = "provider_orders" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_order_pkey'), - db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="provider_order_pkey"), + db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) account_id = db.Column(StringUUID, nullable=False) payment_product_id = db.Column(db.String(191), nullable=False) payment_id = db.Column(db.String(191)) transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1')) + quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) currency = db.Column(db.String(40)) total_amount = db.Column(db.Integer) payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderModelSetting(db.Model): """ Provider model settings for record the model enabled status and load balancing status. """ - __tablename__ = 'provider_model_settings' + + __tablename__ = "provider_model_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'), - db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class LoadBalancingModelConfig(db.Model): """ Configurations for load balancing models. """ - __tablename__ = 'load_balancing_model_configs' + + __tablename__ = "load_balancing_model_configs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'), - db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/source.py b/api/models/source.py index adc00028be..07695f06e6 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -8,48 +8,48 @@ from .types import StringUUID class DataSourceOauthBinding(db.Model): - __tablename__ = 'data_source_oauth_bindings' + __tablename__ = "data_source_oauth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='source_binding_pkey'), - db.Index('source_binding_tenant_id_idx', 'tenant_id'), - db.Index('source_info_idx', "source_info", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="source_binding_pkey"), + db.Index("source_binding_tenant_id_idx", "tenant_id"), + db.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(db.Model): - __tablename__ = 'data_source_api_key_auth_bindings' + __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'), - db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'), - db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'), + db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'category': self.category, - 'provider': self.provider, - 'credentials': json.loads(self.credentials), - 'created_at': self.created_at.timestamp(), - 'updated_at': self.updated_at.timestamp(), - 'disabled': self.disabled + "id": self.id, + "tenant_id": self.tenant_id, + "category": self.category, + "provider": self.provider, + "credentials": json.loads(self.credentials), + "created_at": self.created_at.timestamp(), + "updated_at": self.updated_at.timestamp(), + "disabled": self.disabled, } diff --git a/api/models/task.py b/api/models/task.py index 618d831d8e..57b147c78d 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -8,15 +8,18 @@ from extensions.ext_database import db class CeleryTask(db.Model): """Task result/status.""" - __tablename__ = 'celery_taskmeta' + __tablename__ = "celery_taskmeta" - id = db.Column(db.Integer, db.Sequence('task_id_sequence'), - primary_key=True, autoincrement=True) + id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = db.Column(db.String(155), unique=True) status = db.Column(db.String(50), default=states.PENDING) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) + date_done = db.Column( + db.DateTime, + default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + nullable=True, + ) traceback = db.Column(db.Text, nullable=True) name = db.Column(db.String(155), nullable=True) args = db.Column(db.LargeBinary, nullable=True) @@ -29,11 +32,9 @@ class CeleryTask(db.Model): class CeleryTaskSet(db.Model): """TaskSet result.""" - __tablename__ = 'celery_tasksetmeta' + __tablename__ = "celery_tasksetmeta" - id = db.Column(db.Integer, db.Sequence('taskset_id_sequence'), - autoincrement=True, primary_key=True) + id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) diff --git a/api/models/tool.py b/api/models/tool.py index 79a70c6b1f..a81bb65174 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -7,7 +7,7 @@ from .types import StringUUID class ToolProviderName(Enum): - SERPAPI = 'serpapi' + SERPAPI = "serpapi" @staticmethod def value_of(value): @@ -18,25 +18,25 @@ class ToolProviderName(Enum): class ToolProvider(db.Model): - __tablename__ = 'tool_providers' + __tablename__ = "tool_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_provider_pkey'), - db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), + db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) tool_name = db.Column(db.String(40), nullable=False) encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def credentials_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_credentials is not None @property diff --git a/api/models/tools.py b/api/models/tools.py index 27d1001740..485e16b228 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -16,15 +16,16 @@ class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. """ - __tablename__ = 'tool_builtin_providers' + + __tablename__ = "tool_builtin_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), + db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), # one tenant can only have one tool provider with the same name - db.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), ) # id of the tool provider - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -33,24 +34,30 @@ class BuiltinToolProvider(Base): provider: Mapped[str] = mapped_column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) @property def credentials(self) -> dict: return json.loads(self.encrypted_credentials) + class ApiToolProvider(Base): """ The table stores the api providers. """ - __tablename__ = 'tool_api_providers' + + __tablename__ = "tool_api_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_api_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider') + db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider name = db.Column(db.String(40), nullable=False) # icon @@ -73,21 +80,21 @@ class ApiToolProvider(Base): # custom_disclaimer custom_disclaimer = db.Column(db.String(255), nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) - + @property def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] - + @property def credentials(self) -> dict: return json.loads(self.credentials_str) - + @property def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() @@ -96,17 +103,19 @@ class ApiToolProvider(Base): def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + class ToolLabelBinding(Base): """ The table stores the labels for tools. """ - __tablename__ = 'tool_label_bindings' + + __tablename__ = "tool_label_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'), - db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), + db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) # tool type @@ -114,28 +123,30 @@ class ToolLabelBinding(Base): # label name label_name: Mapped[str] = mapped_column(db.String(40), nullable=False) + class WorkflowToolProvider(Base): """ The table stores the workflow providers. """ - __tablename__ = 'tool_workflow_providers' + + __tablename__ = "tool_workflow_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), + db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider name: Mapped[str] = mapped_column(db.String(40), nullable=False) # label of the workflow provider - label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='') + label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") # icon icon: Mapped[str] = mapped_column(db.String(255), nullable=False) # app id of the workflow provider app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # version of the workflow provider - version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='') + version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id @@ -143,12 +154,16 @@ class WorkflowToolProvider(Base): # description of the provider description: Mapped[str] = mapped_column(db.Text, nullable=False) # parameter configuration - parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default='[]') + parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default='') + privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) @property def user(self) -> Account | None: @@ -157,28 +172,25 @@ class WorkflowToolProvider(Base): @property def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() - + @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(**config) - for config in json.loads(self.parameter_configuration) - ] - + return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + @property def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() + class ToolModelInvoke(db.Model): """ store the invoke logs from tool invoke """ - __tablename__ = "tool_model_invokes" - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'), - ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + __tablename__ = "tool_model_invokes" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # who invoke this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -196,29 +208,31 @@ class ToolModelInvoke(db.Model): # invoke response model_response = db.Column(db.Text, nullable=False) - prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + class ToolConversationVariables(db.Model): """ store the conversation variables from tool invoke """ + __tablename__ = "tool_conversation_variables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey'), + db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), # add index for user_id and conversation_id - db.Index('user_id_idx', 'user_id'), - db.Index('conversation_id_idx', 'conversation_id'), + db.Index("user_id_idx", "user_id"), + db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -228,25 +242,27 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def variables(self) -> dict: return json.loads(self.variables_str) + class ToolFile(Base): """ store the file created by agent """ + __tablename__ = "tool_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_file_pkey'), + db.PrimaryKeyConstraint("id", name="tool_file_pkey"), # add index for conversation_id - db.Index('tool_file_conversation_id_idx', 'conversation_id'), + db.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) # tenant id @@ -259,4 +275,3 @@ class ToolFile(Base): mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) # original url original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) - \ No newline at end of file diff --git a/api/models/types.py b/api/models/types.py index 1614ec2018..cb6773e70c 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -9,13 +9,13 @@ class StringUUID(TypeDecorator): def process_bind_param(self, value, dialect): if value is None: return value - elif dialect.name == 'postgresql': + elif dialect.name == "postgresql": return str(value) else: return value.hex def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': + if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) else: return dialect.type_descriptor(CHAR(36)) @@ -23,4 +23,4 @@ class StringUUID(TypeDecorator): def process_result_value(self, value, dialect): if value is None: return value - return str(value) \ No newline at end of file + return str(value) diff --git a/api/models/web.py b/api/models/web.py index 0e901d5f84..bc088c185d 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,3 @@ - from extensions.ext_database import db from .model import Message @@ -6,18 +5,18 @@ from .types import StringUUID class SavedMessage(db.Model): - __tablename__ = 'saved_messages' + __tablename__ = "saved_messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='saved_message_pkey'), - db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="saved_message_pkey"), + db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def message(self): @@ -25,15 +24,15 @@ class SavedMessage(db.Model): class PinnedConversation(db.Model): - __tablename__ = 'pinned_conversations' + __tablename__ = "pinned_conversations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='pinned_conversation_pkey'), - db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/workflow.py b/api/models/workflow.py index e78b5666bc..9c93ea4cea 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -22,11 +22,12 @@ class CreatedByRole(Enum): """ Created By Role Enum """ - ACCOUNT = 'account' - END_USER = 'end_user' + + ACCOUNT = "account" + END_USER = "end_user" @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': + def value_of(cls, value: str) -> "CreatedByRole": """ Get value of given mode. @@ -36,18 +37,19 @@ class CreatedByRole(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid created by role value {value}') + raise ValueError(f"invalid created by role value {value}") class WorkflowType(Enum): """ Workflow Type Enum """ - WORKFLOW = 'workflow' - CHAT = 'chat' + + WORKFLOW = "workflow" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'WorkflowType': + def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. @@ -57,10 +59,10 @@ class WorkflowType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow type value {value}') + raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': + def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. @@ -68,6 +70,7 @@ class WorkflowType(Enum): :return: workflow type """ from models.model import AppMode + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT @@ -105,13 +108,13 @@ class Workflow(db.Model): - updated_at (timestamp) `optional` Last update time """ - __tablename__ = 'workflows' + __tablename__ = "workflows" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_pkey'), - db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), + db.PrimaryKeyConstraint("id", name="workflow_pkey"), + db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = db.Column(StringUUID, nullable=False) type: Mapped[str] = db.Column(db.String(255), nullable=False) @@ -119,15 +122,31 @@ class Workflow(db.Model): graph: Mapped[str] = db.Column(db.Text) features: Mapped[str] = db.Column(db.Text) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) updated_by: Mapped[str] = db.Column(StringUUID) updated_at: Mapped[datetime] = db.Column(db.DateTime) - _environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') - _conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + _environment_variables: Mapped[str] = db.Column( + "environment_variables", db.Text, nullable=False, server_default="{}" + ) + _conversation_variables: Mapped[str] = db.Column( + "conversation_variables", db.Text, nullable=False, server_default="{}" + ) - def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str, - features: str, created_by: str, environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable]): + def __init__( + self, + *, + tenant_id: str, + app_id: str, + type: str, + version: str, + graph: str, + features: str, + created_by: str, + environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], + ): self.tenant_id = tenant_id self.app_id = app_id self.type = type @@ -160,22 +179,20 @@ class Workflow(db.Model): return [] graph_dict = self.graph_dict - if 'nodes' not in graph_dict: + if "nodes" not in graph_dict: return [] - start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) + start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None) if not start_node: return [] # get user_input_form from start node - variables = start_node.get('data', {}).get('variables', []) + variables = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] for variable in variables: - old_structure_variables.append({ - variable['type']: variable - }) + old_structure_variables.append({variable["type"]: variable}) return old_structure_variables @@ -188,25 +205,24 @@ class Workflow(db.Model): :return: hash """ - entity = { - 'graph': self.graph_dict, - 'features': self.features_dict - } + entity = {"graph": self.graph_dict, "features": self.features_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property def tool_published(self) -> bool: from models.tools import WorkflowToolProvider - return db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.app_id == self.app_id - ).first() is not None + + return ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.app_id == self.app_id).first() + is not None + ) @property def environment_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: - self._environment_variables = '{}' + self._environment_variables = "{}" tenant_id = contexts.tenant_id.get() @@ -215,9 +231,7 @@ class Workflow(db.Model): # decrypt secret variables value decrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -230,19 +244,18 @@ class Workflow(db.Model): value = list(value) if any(var for var in value if not var.id): - raise ValueError('environment variable require a unique id') + raise ValueError("environment variable require a unique id") - # Compare inputs and origin variables, if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). + # Compare inputs and origin variables, + # if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). origin_variables_dictionary = {var.id: var for var in self.environment_variables} for i, variable in enumerate(value): if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE: - value[i] = origin_variables_dictionary[variable.id].model_copy(update={'name': variable.name}) + value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value encrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -256,15 +269,15 @@ class Workflow(db.Model): def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: environment_variables = list(self.environment_variables) environment_variables = [ - v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={'value': ''}) + v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) for v in environment_variables ] result = { - 'graph': self.graph_dict, - 'features': self.features_dict, - 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], - 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], + "graph": self.graph_dict, + "features": self.features_dict, + "environment_variables": [var.model_dump(mode="json") for var in environment_variables], + "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], } return result @@ -272,7 +285,7 @@ class Workflow(db.Model): def conversation_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._conversation_variables is None: - self._conversation_variables = '{}' + self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] @@ -290,11 +303,12 @@ class WorkflowRunTriggeredFrom(Enum): """ Workflow Run Triggered From Enum """ - DEBUGGING = 'debugging' - APP_RUN = 'app-run' + + DEBUGGING = "debugging" + APP_RUN = "app-run" @classmethod - def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': + def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom": """ Get value of given mode. @@ -304,20 +318,21 @@ class WorkflowRunTriggeredFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow run triggered from value {value}') + raise ValueError(f"invalid workflow run triggered from value {value}") class WorkflowRunStatus(Enum): """ Workflow Run Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' - STOPPED = 'stopped' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + STOPPED = "stopped" @classmethod - def value_of(cls, value: str) -> 'WorkflowRunStatus': + def value_of(cls, value: str) -> "WorkflowRunStatus": """ Get value of given mode. @@ -327,7 +342,7 @@ class WorkflowRunStatus(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow run status value {value}') + raise ValueError(f"invalid workflow run status value {value}") class WorkflowRun(db.Model): @@ -368,14 +383,14 @@ class WorkflowRun(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_runs' + __tablename__ = "workflow_runs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), - db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), - db.Index('workflow_run_tenant_app_sequence_idx', 'tenant_id', 'app_id', 'sequence_number'), + db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) sequence_number = db.Column(db.Integer, nullable=False) @@ -388,26 +403,25 @@ class WorkflowRun(db.Model): status = db.Column(db.String(255), nullable=False) outputs = db.Column(db.Text) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - total_steps = db.Column(db.Integer, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) + total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def graph_dict(self): @@ -422,12 +436,12 @@ class WorkflowRun(db.Model): return json.loads(self.outputs) if self.outputs else None @property - def message(self) -> Optional['Message']: + def message(self) -> Optional["Message"]: from models.model import Message - return db.session.query(Message).filter( - Message.app_id == self.app_id, - Message.workflow_run_id == self.id - ).first() + + return ( + db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + ) @property def workflow(self): @@ -435,51 +449,51 @@ class WorkflowRun(db.Model): def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'app_id': self.app_id, - 'sequence_number': self.sequence_number, - 'workflow_id': self.workflow_id, - 'type': self.type, - 'triggered_from': self.triggered_from, - 'version': self.version, - 'graph': self.graph_dict, - 'inputs': self.inputs_dict, - 'status': self.status, - 'outputs': self.outputs_dict, - 'error': self.error, - 'elapsed_time': self.elapsed_time, - 'total_tokens': self.total_tokens, - 'total_steps': self.total_steps, - 'created_by_role': self.created_by_role, - 'created_by': self.created_by, - 'created_at': self.created_at, - 'finished_at': self.finished_at, + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "sequence_number": self.sequence_number, + "workflow_id": self.workflow_id, + "type": self.type, + "triggered_from": self.triggered_from, + "version": self.version, + "graph": self.graph_dict, + "inputs": self.inputs_dict, + "status": self.status, + "outputs": self.outputs_dict, + "error": self.error, + "elapsed_time": self.elapsed_time, + "total_tokens": self.total_tokens, + "total_steps": self.total_steps, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + "finished_at": self.finished_at, } @classmethod - def from_dict(cls, data: dict) -> 'WorkflowRun': + def from_dict(cls, data: dict) -> "WorkflowRun": return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - app_id=data.get('app_id'), - sequence_number=data.get('sequence_number'), - workflow_id=data.get('workflow_id'), - type=data.get('type'), - triggered_from=data.get('triggered_from'), - version=data.get('version'), - graph=json.dumps(data.get('graph')), - inputs=json.dumps(data.get('inputs')), - status=data.get('status'), - outputs=json.dumps(data.get('outputs')), - error=data.get('error'), - elapsed_time=data.get('elapsed_time'), - total_tokens=data.get('total_tokens'), - total_steps=data.get('total_steps'), - created_by_role=data.get('created_by_role'), - created_by=data.get('created_by'), - created_at=data.get('created_at'), - finished_at=data.get('finished_at'), + id=data.get("id"), + tenant_id=data.get("tenant_id"), + app_id=data.get("app_id"), + sequence_number=data.get("sequence_number"), + workflow_id=data.get("workflow_id"), + type=data.get("type"), + triggered_from=data.get("triggered_from"), + version=data.get("version"), + graph=json.dumps(data.get("graph")), + inputs=json.dumps(data.get("inputs")), + status=data.get("status"), + outputs=json.dumps(data.get("outputs")), + error=data.get("error"), + elapsed_time=data.get("elapsed_time"), + total_tokens=data.get("total_tokens"), + total_steps=data.get("total_steps"), + created_by_role=data.get("created_by_role"), + created_by=data.get("created_by"), + created_at=data.get("created_at"), + finished_at=data.get("finished_at"), ) @@ -487,11 +501,12 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): """ Workflow Node Execution Triggered From Enum """ - SINGLE_STEP = 'single-step' - WORKFLOW_RUN = 'workflow-run' + + SINGLE_STEP = "single-step" + WORKFLOW_RUN = "workflow-run" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': + def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": """ Get value of given mode. @@ -501,19 +516,20 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution triggered from value {value}') + raise ValueError(f"invalid workflow node execution triggered from value {value}") class WorkflowNodeExecutionStatus(Enum): """ Workflow Node Execution Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': + def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": """ Get value of given mode. @@ -523,7 +539,7 @@ class WorkflowNodeExecutionStatus(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution status value {value}') + raise ValueError(f"invalid workflow node execution status value {value}") class WorkflowNodeExecution(db.Model): @@ -574,18 +590,31 @@ class WorkflowNodeExecution(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_node_executions' + __tablename__ = "workflow_node_executions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), - db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'workflow_run_id'), - db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'node_id'), - db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'node_execution_id'), + db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + db.Index( + "workflow_node_execution_workflow_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "workflow_run_id", + ), + db.Index( + "workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id" + ), + db.Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_execution_id", + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -602,9 +631,9 @@ class WorkflowNodeExecution(db.Model): outputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -612,15 +641,14 @@ class WorkflowNodeExecution(db.Model): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def inputs_dict(self): @@ -641,15 +669,17 @@ class WorkflowNodeExecution(db.Model): @property def extras(self): from core.tools.tool_manager import ToolManager + extras = {} if self.execution_metadata_dict: from core.workflow.entities.node_entities import NodeType - if self.node_type == NodeType.TOOL.value and 'tool_info' in self.execution_metadata_dict: - tool_info = self.execution_metadata_dict['tool_info'] - extras['icon'] = ToolManager.get_tool_icon( + + if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: + tool_info = self.execution_metadata_dict["tool_info"] + extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, - provider_type=tool_info['provider_type'], - provider_id=tool_info['provider_id'] + provider_type=tool_info["provider_type"], + provider_id=tool_info["provider_id"], ) return extras @@ -659,12 +689,13 @@ class WorkflowAppLogCreatedFrom(Enum): """ Workflow App Log Created From Enum """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - INSTALLED_APP = 'installed-app' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': + def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. @@ -674,7 +705,7 @@ class WorkflowAppLogCreatedFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow app log created from value {value}') + raise ValueError(f"invalid workflow app log created from value {value}") class WorkflowAppLog(db.Model): @@ -706,13 +737,13 @@ class WorkflowAppLog(db.Model): - created_at (timestamp) Creation time """ - __tablename__ = 'workflow_app_logs' + __tablename__ = "workflow_app_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'), - db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), + db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -720,7 +751,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def workflow_run(self): @@ -729,26 +760,27 @@ class WorkflowAppLog(db.Model): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None class ConversationVariable(db.Model): - __tablename__ = 'workflow_conversation_variables' + __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: self.id = id @@ -757,7 +789,7 @@ class ConversationVariable(db.Model): self.data = data @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": obj = cls( id=variable.id, app_id=app_id, diff --git a/api/poetry.lock b/api/poetry.lock index 103423e5c7..191db600e4 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -569,13 +569,13 @@ typing-extensions = ">=4.6.0" [[package]] name = "azure-ai-ml" -version = "1.19.0" +version = "1.20.0" description = "Microsoft Azure Machine Learning Client Library for Python" optional = false python-versions = ">=3.7" files = [ - {file = "azure-ai-ml-1.19.0.tar.gz", hash = "sha256:94bb1afbb0497e539ae75455fc4a51b6942b5b68b3a275727ecce6ceb250eff9"}, - {file = "azure_ai_ml-1.19.0-py3-none-any.whl", hash = "sha256:f0385af06efbeae1f83113613e45343508d1288fd2f05857619e7c7d4d4f5302"}, + {file = "azure-ai-ml-1.20.0.tar.gz", hash = "sha256:6432a0da1b7250cb0db5a1c33202e0419935e19ea32d4c2b3220705f8f1d4101"}, + {file = "azure_ai_ml-1.20.0-py3-none-any.whl", hash = "sha256:c7eb3c5ccf82a6ee94403c3e5060763decd38cf03ff2620a4a6577526e605104"}, ] [package.dependencies] @@ -809,17 +809,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.148" +version = "1.35.17" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.148-py3-none-any.whl", hash = "sha256:d63d36e5a34533ba69188d56f96da132730d5e9932c4e11c02d79319cd1afcec"}, - {file = "boto3-1.34.148.tar.gz", hash = "sha256:2058397f0a92c301e3116e9e65fbbc70ea49270c250882d65043d19b7c6e2d17"}, + {file = "boto3-1.35.17-py3-none-any.whl", hash = "sha256:67268aa6c4043e9fdeb4ab3c1e9032f44a6fa168c789af5e351f63f1f8880a2f"}, + {file = "boto3-1.35.17.tar.gz", hash = "sha256:4a32db8793569ee5f13c5bf3efb260193353cb8946bf6426e3c330b61c68e59d"}, ] [package.dependencies] -botocore = ">=1.34.148,<1.35.0" +botocore = ">=1.35.17,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -828,13 +828,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.162" +version = "1.35.17" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, - {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, + {file = "botocore-1.35.17-py3-none-any.whl", hash = "sha256:a93f773ca93139529b5d36730b382dbee63ab4c7f26129aa5c84835255ca999d"}, + {file = "botocore-1.35.17.tar.gz", hash = "sha256:0d35d03ea647b5d464c7f77bdab6fb23ae5d49752b13cf97ab84444518c7b1bd"}, ] [package.dependencies] @@ -843,7 +843,7 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.21.2)"] +crt = ["awscrt (==0.21.5)"] [[package]] name = "bottleneck" @@ -1049,13 +1049,13 @@ beautifulsoup4 = "*" [[package]] name = "build" -version = "1.2.1" +version = "1.2.2" description = "A simple, correct Python build frontend" optional = false python-versions = ">=3.8" files = [ - {file = "build-1.2.1-py3-none-any.whl", hash = "sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4"}, - {file = "build-1.2.1.tar.gz", hash = "sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d"}, + {file = "build-1.2.2-py3-none-any.whl", hash = "sha256:277ccc71619d98afdd841a0e96ac9fe1593b823af481d3b0cea748e8894e0613"}, + {file = "build-1.2.2.tar.gz", hash = "sha256:119b2fb462adef986483438377a13b2f42064a2a3a4161f24a0cca698a07ac8c"}, ] [package.dependencies] @@ -2241,57 +2241,57 @@ typing_extensions = ">=4.0,<5.0" [[package]] name = "duckdb" -version = "1.0.0" +version = "1.1.0" description = "DuckDB in-process database" optional = false python-versions = ">=3.7.0" files = [ - {file = "duckdb-1.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4a8ce2d1f9e1c23b9bab3ae4ca7997e9822e21563ff8f646992663f66d050211"}, - {file = "duckdb-1.0.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:19797670f20f430196e48d25d082a264b66150c264c1e8eae8e22c64c2c5f3f5"}, - {file = "duckdb-1.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:b71c342090fe117b35d866a91ad6bffce61cd6ff3e0cff4003f93fc1506da0d8"}, - {file = "duckdb-1.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25dd69f44ad212c35ae2ea736b0e643ea2b70f204b8dff483af1491b0e2a4cec"}, - {file = "duckdb-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8da5f293ecb4f99daa9a9352c5fd1312a6ab02b464653a0c3a25ab7065c45d4d"}, - {file = "duckdb-1.0.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3207936da9967ddbb60644ec291eb934d5819b08169bc35d08b2dedbe7068c60"}, - {file = "duckdb-1.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1128d6c9c33e883b1f5df6b57c1eb46b7ab1baf2650912d77ee769aaa05111f9"}, - {file = "duckdb-1.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:02310d263474d0ac238646677feff47190ffb82544c018b2ff732a4cb462c6ef"}, - {file = "duckdb-1.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:75586791ab2702719c284157b65ecefe12d0cca9041da474391896ddd9aa71a4"}, - {file = "duckdb-1.0.0-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:83bb415fc7994e641344f3489e40430ce083b78963cb1057bf714ac3a58da3ba"}, - {file = "duckdb-1.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:bee2e0b415074e84c5a2cefd91f6b5ebeb4283e7196ba4ef65175a7cef298b57"}, - {file = "duckdb-1.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa5a4110d2a499312609544ad0be61e85a5cdad90e5b6d75ad16b300bf075b90"}, - {file = "duckdb-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fa389e6a382d4707b5f3d1bc2087895925ebb92b77e9fe3bfb23c9b98372fdc"}, - {file = "duckdb-1.0.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7ede6f5277dd851f1a4586b0c78dc93f6c26da45e12b23ee0e88c76519cbdbe0"}, - {file = "duckdb-1.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0b88cdbc0d5c3e3d7545a341784dc6cafd90fc035f17b2f04bf1e870c68456e5"}, - {file = "duckdb-1.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd1693cdd15375156f7fff4745debc14e5c54928589f67b87fb8eace9880c370"}, - {file = "duckdb-1.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:c65a7fe8a8ce21b985356ee3ec0c3d3b3b2234e288e64b4cfb03356dbe6e5583"}, - {file = "duckdb-1.0.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:e5a8eda554379b3a43b07bad00968acc14dd3e518c9fbe8f128b484cf95e3d16"}, - {file = "duckdb-1.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:a1b6acdd54c4a7b43bd7cb584975a1b2ff88ea1a31607a2b734b17960e7d3088"}, - {file = "duckdb-1.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a677bb1b6a8e7cab4a19874249d8144296e6e39dae38fce66a80f26d15e670df"}, - {file = "duckdb-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:752e9d412b0a2871bf615a2ede54be494c6dc289d076974eefbf3af28129c759"}, - {file = "duckdb-1.0.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3aadb99d098c5e32d00dc09421bc63a47134a6a0de9d7cd6abf21780b678663c"}, - {file = "duckdb-1.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83b7091d4da3e9301c4f9378833f5ffe934fb1ad2b387b439ee067b2c10c8bb0"}, - {file = "duckdb-1.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:6a8058d0148b544694cb5ea331db44f6c2a00a7b03776cc4dd1470735c3d5ff7"}, - {file = "duckdb-1.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e40cb20e5ee19d44bc66ec99969af791702a049079dc5f248c33b1c56af055f4"}, - {file = "duckdb-1.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7bce1bc0de9af9f47328e24e6e7e39da30093179b1c031897c042dd94a59c8e"}, - {file = "duckdb-1.0.0-cp37-cp37m-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8355507f7a04bc0a3666958f4414a58e06141d603e91c0fa5a7c50e49867fb6d"}, - {file = "duckdb-1.0.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:39f1a46f5a45ad2886dc9b02ce5b484f437f90de66c327f86606d9ba4479d475"}, - {file = "duckdb-1.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a6d29ba477b27ae41676b62c8fae8d04ee7cbe458127a44f6049888231ca58fa"}, - {file = "duckdb-1.0.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:1bea713c1925918714328da76e79a1f7651b2b503511498ccf5e007a7e67d49e"}, - {file = "duckdb-1.0.0-cp38-cp38-macosx_12_0_universal2.whl", hash = "sha256:bfe67f3bcf181edbf6f918b8c963eb060e6aa26697d86590da4edc5707205450"}, - {file = "duckdb-1.0.0-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:dbc6093a75242f002be1d96a6ace3fdf1d002c813e67baff52112e899de9292f"}, - {file = "duckdb-1.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba1881a2b11c507cee18f8fd9ef10100be066fddaa2c20fba1f9a664245cd6d8"}, - {file = "duckdb-1.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:445d0bb35087c522705c724a75f9f1c13f1eb017305b694d2686218d653c8142"}, - {file = "duckdb-1.0.0-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:224553432e84432ffb9684f33206572477049b371ce68cc313a01e214f2fbdda"}, - {file = "duckdb-1.0.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:d3914032e47c4e76636ad986d466b63fdea65e37be8a6dfc484ed3f462c4fde4"}, - {file = "duckdb-1.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:af9128a2eb7e1bb50cd2c2020d825fb2946fdad0a2558920cd5411d998999334"}, - {file = "duckdb-1.0.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dd2659a5dbc0df0de68f617a605bf12fe4da85ba24f67c08730984a0892087e8"}, - {file = "duckdb-1.0.0-cp39-cp39-macosx_12_0_universal2.whl", hash = "sha256:ac5a4afb0bc20725e734e0b2c17e99a274de4801aff0d4e765d276b99dad6d90"}, - {file = "duckdb-1.0.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:2c5a53bee3668d6e84c0536164589d5127b23d298e4c443d83f55e4150fafe61"}, - {file = "duckdb-1.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b980713244d7708b25ee0a73de0c65f0e5521c47a0e907f5e1b933d79d972ef6"}, - {file = "duckdb-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21cbd4f9fe7b7a56eff96c3f4d6778770dd370469ca2212eddbae5dd63749db5"}, - {file = "duckdb-1.0.0-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ed228167c5d49888c5ef36f6f9cbf65011c2daf9dcb53ea8aa7a041ce567b3e4"}, - {file = "duckdb-1.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:46d8395fbcea7231fd5032a250b673cc99352fef349b718a23dea2c0dd2b8dec"}, - {file = "duckdb-1.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:6ad1fc1a4d57e7616944166a5f9417bdbca1ea65c490797e3786e3a42e162d8a"}, - {file = "duckdb-1.0.0.tar.gz", hash = "sha256:a2a059b77bc7d5b76ae9d88e267372deff19c291048d59450c431e166233d453"}, + {file = "duckdb-1.1.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5e4cbc408e6e41146dea89b9044dae7356e353db0c96b183e5583ee02bc6ae5d"}, + {file = "duckdb-1.1.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:6370ae27ec8167ccfbefb94f58ad9fdc7bac142399960549d6d367f233189868"}, + {file = "duckdb-1.1.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:4e1c3414f7fd01f4810dc8b335deffc91933a159282d65fef11c1286bc0ded04"}, + {file = "duckdb-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6bc2a58689adf5520303c5f68b065b9f980bd31f1366c541b8c7490abaf55cd"}, + {file = "duckdb-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d02be208d2885ca085d4c852b911493b8cdac9d6eae893259da32bd72a437c25"}, + {file = "duckdb-1.1.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:655df442ceebfc6f3fd6c8766e04b60d44dddedfa90275d794f9fab2d3180879"}, + {file = "duckdb-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6e183729bb64be7798ccbfda6283ebf423c869268c25af2b56929e48f763be2f"}, + {file = "duckdb-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:61fb838da51e07ceb0222c4406b059b90e10efcc453c19a3650b73c0112138c4"}, + {file = "duckdb-1.1.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:7807e2f0d3344668e433f0dc1f54bfaddd410589611393e9a7ed56f8dec9514f"}, + {file = "duckdb-1.1.0-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:3da30b7b466f710d52caa1fdc3ef0bf4176ad7f115953cd9f8b0fbf0f723778f"}, + {file = "duckdb-1.1.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:b9b6a77ef0183f561b1fc2945fcc762a71570ffd33fea4e3a855d413ed596fe4"}, + {file = "duckdb-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16243e66a9fd0e64ee265f2634d137adc6593f54ddf3ef55cb8a29e1decf6e54"}, + {file = "duckdb-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42b910a149e00f40a1766dc74fa309d4255b912a5d2fdcc387287658048650f6"}, + {file = "duckdb-1.1.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:47849d546dc4238c0f20e95fe53b621aa5b08684e68fff91fd84a7092be91a17"}, + {file = "duckdb-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:11ec967b67159361ceade34095796a8d19368ea5c30cad988f44896b082b0816"}, + {file = "duckdb-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:510b5885ed6c267b9c0e1e7c6138fdffc2dd6f934a5a95b76da85da127213338"}, + {file = "duckdb-1.1.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:657bc7ac64d5faf069a782ae73afac51ef30ae2e5d0e09ce6a09d03db84ab35e"}, + {file = "duckdb-1.1.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:89f3de8cba57d19b41cd3c47dd06d979bd2a2ffead115480e37afbe72b02896d"}, + {file = "duckdb-1.1.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:f6486323ab20656d22ffa8f3c6e109dde30d0b327b7c831f22ebcfe747f97fb0"}, + {file = "duckdb-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78a4510f82431ee3f14db689fe8727a4a9062c8f2fbb3bcfe3bfad3c1a198004"}, + {file = "duckdb-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64bf2a6e23840d662bd2ac09206a9bd4fa657418884d69e5c352d4456dc70b3c"}, + {file = "duckdb-1.1.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:23fc9aa0af74e3803ed90c8d98280fd5bcac8c940592bf6288e8fd60fb051d00"}, + {file = "duckdb-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1f3aea31341ce400640dd522e4399b941f66df17e39884f446638fe958d6117c"}, + {file = "duckdb-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:3db4ab31c20de4edaef152930836b38e7662cd71370748fdf2c38ba9cf854dc4"}, + {file = "duckdb-1.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3b6b4fe1edfe35f64f403a9f0ab75258cee35abd964356893ee37424174b7e4"}, + {file = "duckdb-1.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aad02f50d5a2020822d1638fc1a9bcf082056f11d2e15ccfc1c1ed4d0f85a3be"}, + {file = "duckdb-1.1.0-cp37-cp37m-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eb66e9e7391801928ea134dcab12d2e4c97f2ce0391c603a3e480bbb15830bc8"}, + {file = "duckdb-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:069fb7bca459e31edb32a61f0eea95d7a8a766bef7b8318072563abf8e939593"}, + {file = "duckdb-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e39f9b7b62e64e10d421ff04480290a70129c38067d1a4f600e9212b10542c5a"}, + {file = "duckdb-1.1.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:55ef98bcc7ba745752607f1b926e8d9b7ce32c42c423bbad10c44820aefe23a7"}, + {file = "duckdb-1.1.0-cp38-cp38-macosx_12_0_universal2.whl", hash = "sha256:e2a08175e43b865c1e9611efd18cacd29ddd69093de442b1ebdf312071df7719"}, + {file = "duckdb-1.1.0-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:0e3644b1f034012d82b9baa12a7ea306fe71dc6623731b28c753c4a617ff9499"}, + {file = "duckdb-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:211a33c1ddb5cc609f75eb43772b0b03b45d2fa89bec107e4715267ca907806a"}, + {file = "duckdb-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e74b6f8a5145abbf7e6c1a2a61f0adbcd493c19b358f524ec9a3cebdf362abb"}, + {file = "duckdb-1.1.0-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:58f1633dd2c5af5088ae2d119418e200855d0699d84f2fae9d46d30f404bcead"}, + {file = "duckdb-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:d18caea926b1e301c29b140418fca697aad728129e269b4f82c2795a184549e1"}, + {file = "duckdb-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:cd9fb1408942411ad360f8414bc3fbf0091c396ca903d947a10f2e31324d5cbd"}, + {file = "duckdb-1.1.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bd11bc899cebf5ff936d1276a2dfb7b7db08aba3bcc42924afeafc2163bddb43"}, + {file = "duckdb-1.1.0-cp39-cp39-macosx_12_0_universal2.whl", hash = "sha256:53825a63193c582a78c152ea53de8d145744ddbeea18f452625a82ebc33eb14a"}, + {file = "duckdb-1.1.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:29dc18087de47563b3859a6b98bbed96e1c96ce5db829646dc3b16a916997e7d"}, + {file = "duckdb-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecb19319883564237a7a03a104dbe7f445e73519bb67108fcab3d19b6b91fe30"}, + {file = "duckdb-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aac2fcabe2d5072c252d0b3087365f431de812d8199705089fb073e4d039d19c"}, + {file = "duckdb-1.1.0-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d89eaaa5df8a57e7d2bc1f4c46493bb1fee319a00155f2015810ad2ace6570ae"}, + {file = "duckdb-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d86a6926313913cd2cc7e08816d3e7f72ba340adf2959279b1a80058be6526d9"}, + {file = "duckdb-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8333f3e85fa2a0f1c222b752c2bd42ea875235ff88492f7bcbb6867d0f644eb"}, + {file = "duckdb-1.1.0.tar.gz", hash = "sha256:b4d4c12b1f98732151bd31377753e0da1a20f6423016d2d097d2e31953ec7c23"}, ] [[package]] @@ -2429,13 +2429,13 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.113.0" +version = "0.114.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.113.0-py3-none-any.whl", hash = "sha256:c8d364485b6361fb643d53920a18d58a696e189abcb901ec03b487e35774c476"}, - {file = "fastapi-0.113.0.tar.gz", hash = "sha256:b7cf9684dc154dfc93f8b718e5850577b529889096518df44defa41e73caf50f"}, + {file = "fastapi-0.114.1-py3-none-any.whl", hash = "sha256:5d4746f6e4b7dff0b4f6b6c6d5445645285f662fe75886e99af7ee2d6b58bb3e"}, + {file = "fastapi-0.114.1.tar.gz", hash = "sha256:1d7bbbeabbaae0acb0c22f0ab0b040f642d3093ca3645f8c876b6f91391861d8"}, ] [package.dependencies] @@ -2524,19 +2524,19 @@ sgmllib3k = "*" [[package]] name = "filelock" -version = "3.15.4" +version = "3.16.0" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, - {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, + {file = "filelock-3.16.0-py3-none-any.whl", hash = "sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609"}, + {file = "filelock-3.16.0.tar.gz", hash = "sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] -typing = ["typing-extensions (>=4.8)"] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.1.1)", "pytest (>=8.3.2)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.3)"] +typing = ["typing-extensions (>=4.12.2)"] [[package]] name = "filetype" @@ -3410,69 +3410,77 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] [[package]] name = "greenlet" -version = "3.0.3" +version = "3.1.0" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" files = [ - {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d353cadd6083fdb056bb46ed07e4340b0869c305c8ca54ef9da3421acbdf6881"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dca1e2f3ca00b84a396bc1bce13dd21f680f035314d2379c4160c98153b2059b"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ed7fb269f15dc662787f4119ec300ad0702fa1b19d2135a37c2c4de6fadfd4a"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd4f49ae60e10adbc94b45c0b5e6a179acc1736cf7a90160b404076ee283cf83"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:73a411ef564e0e097dbe7e866bb2dda0f027e072b04da387282b02c308807405"}, - {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7f362975f2d179f9e26928c5b517524e89dd48530a0202570d55ad6ca5d8a56f"}, - {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:649dde7de1a5eceb258f9cb00bdf50e978c9db1b996964cd80703614c86495eb"}, - {file = "greenlet-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:68834da854554926fbedd38c76e60c4a2e3198c6fbed520b106a8986445caaf9"}, - {file = "greenlet-3.0.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1b5667cced97081bf57b8fa1d6bfca67814b0afd38208d52538316e9422fc61"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52f59dd9c96ad2fc0d5724107444f76eb20aaccb675bf825df6435acb7703559"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afaff6cf5200befd5cec055b07d1c0a5a06c040fe5ad148abcd11ba6ab9b114e"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2797aa5aedac23af156bbb5a6aa2cd3427ada2972c828244eb7d1b9255846379"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7f009caad047246ed379e1c4dbcb8b020f0a390667ea74d2387be2998f58a22"}, - {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c5e1536de2aad7bf62e27baf79225d0d64360d4168cf2e6becb91baf1ed074f3"}, - {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:894393ce10ceac937e56ec00bb71c4c2f8209ad516e96033e4b3b1de270e200d"}, - {file = "greenlet-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:1ea188d4f49089fc6fb283845ab18a2518d279c7cd9da1065d7a84e991748728"}, - {file = "greenlet-3.0.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:70fb482fdf2c707765ab5f0b6655e9cfcf3780d8d87355a063547b41177599be"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4d1ac74f5c0c0524e4a24335350edad7e5f03b9532da7ea4d3c54d527784f2e"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149e94a2dd82d19838fe4b2259f1b6b9957d5ba1b25640d2380bea9c5df37676"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15d79dd26056573940fcb8c7413d84118086f2ec1a8acdfa854631084393efcc"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b7db1ebff4ba09aaaeae6aa491daeb226c8150fc20e836ad00041bcb11230"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcd2469d6a2cf298f198f0487e0a5b1a47a42ca0fa4dfd1b6862c999f018ebbf"}, - {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f672519db1796ca0d8753f9e78ec02355e862d0998193038c7073045899f305"}, - {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2516a9957eed41dd8f1ec0c604f1cdc86758b587d964668b5b196a9db5bfcde6"}, - {file = "greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2"}, - {file = "greenlet-3.0.3-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:5b51e85cb5ceda94e79d019ed36b35386e8c37d22f07d6a751cb659b180d5274"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf3cb43b7cf2ba96d614252ce1684c1bccee6b2183a01328c98d36fcd7d5cb0"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99bf650dc5d69546e076f413a87481ee1d2d09aaaaaca058c9251b6d8c14783f"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dd6e660effd852586b6a8478a1d244b8dc90ab5b1321751d2ea15deb49ed414"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3391d1e16e2a5a1507d83e4a8b100f4ee626e8eca43cf2cadb543de69827c4c"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1f145462f1fa6e4a4ae3c0f782e580ce44d57c8f2c7aae1b6fa88c0b2efdb41"}, - {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1a7191e42732df52cb5f39d3527217e7ab73cae2cb3694d241e18f53d84ea9a7"}, - {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0448abc479fab28b00cb472d278828b3ccca164531daab4e970a0458786055d6"}, - {file = "greenlet-3.0.3-cp37-cp37m-win32.whl", hash = "sha256:b542be2440edc2d48547b5923c408cbe0fc94afb9f18741faa6ae970dbcb9b6d"}, - {file = "greenlet-3.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:01bc7ea167cf943b4c802068e178bbf70ae2e8c080467070d01bfa02f337ee67"}, - {file = "greenlet-3.0.3-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:1996cb9306c8595335bb157d133daf5cf9f693ef413e7673cb07e3e5871379ca"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc0f794e6ad661e321caa8d2f0a55ce01213c74722587256fb6566049a8b04"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9db1c18f0eaad2f804728c67d6c610778456e3e1cc4ab4bbd5eeb8e6053c6fc"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7170375bcc99f1a2fbd9c306f5be8764eaf3ac6b5cb968862cad4c7057756506"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b66c9c1e7ccabad3a7d037b2bcb740122a7b17a53734b7d72a344ce39882a1b"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:098d86f528c855ead3479afe84b49242e174ed262456c342d70fc7f972bc13c4"}, - {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:81bb9c6d52e8321f09c3d165b2a78c680506d9af285bfccbad9fb7ad5a5da3e5"}, - {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da"}, - {file = "greenlet-3.0.3-cp38-cp38-win32.whl", hash = "sha256:d46677c85c5ba00a9cb6f7a00b2bfa6f812192d2c9f7d9c4f6a55b60216712f3"}, - {file = "greenlet-3.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:419b386f84949bf0e7c73e6032e3457b82a787c1ab4a0e43732898a761cc9dbf"}, - {file = "greenlet-3.0.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:da70d4d51c8b306bb7a031d5cff6cc25ad253affe89b70352af5f1cb68e74b53"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086152f8fbc5955df88382e8a75984e2bb1c892ad2e3c80a2508954e52295257"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d73a9fe764d77f87f8ec26a0c85144d6a951a6c438dfe50487df5595c6373eac"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7dcbe92cc99f08c8dd11f930de4d99ef756c3591a5377d1d9cd7dd5e896da71"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1551a8195c0d4a68fac7a4325efac0d541b48def35feb49d803674ac32582f61"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:64d7675ad83578e3fc149b617a444fab8efdafc9385471f868eb5ff83e446b8b"}, - {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b37eef18ea55f2ffd8f00ff8fe7c8d3818abd3e25fb73fae2ca3b672e333a7a6"}, - {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:77457465d89b8263bca14759d7c1684df840b6811b2499838cc5b040a8b5b113"}, - {file = "greenlet-3.0.3-cp39-cp39-win32.whl", hash = "sha256:57e8974f23e47dac22b83436bdcf23080ade568ce77df33159e019d161ce1d1e"}, - {file = "greenlet-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c5ee858cfe08f34712f548c3c363e807e7186f03ad7a5039ebadb29e8c6be067"}, - {file = "greenlet-3.0.3.tar.gz", hash = "sha256:43374442353259554ce33599da8b692d5aa96f8976d567d4badf263371fbe491"}, + {file = "greenlet-3.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a814dc3100e8a046ff48faeaa909e80cdb358411a3d6dd5293158425c684eda8"}, + {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a771dc64fa44ebe58d65768d869fcfb9060169d203446c1d446e844b62bdfdca"}, + {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0e49a65d25d7350cca2da15aac31b6f67a43d867448babf997fe83c7505f57bc"}, + {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2cd8518eade968bc52262d8c46727cfc0826ff4d552cf0430b8d65aaf50bb91d"}, + {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76dc19e660baea5c38e949455c1181bc018893f25372d10ffe24b3ed7341fb25"}, + {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0a5b1c22c82831f56f2f7ad9bbe4948879762fe0d59833a4a71f16e5fa0f682"}, + {file = "greenlet-3.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:2651dfb006f391bcb240635079a68a261b227a10a08af6349cba834a2141efa1"}, + {file = "greenlet-3.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3e7e6ef1737a819819b1163116ad4b48d06cfdd40352d813bb14436024fcda99"}, + {file = "greenlet-3.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:ffb08f2a1e59d38c7b8b9ac8083c9c8b9875f0955b1e9b9b9a965607a51f8e54"}, + {file = "greenlet-3.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9730929375021ec90f6447bff4f7f5508faef1c02f399a1953870cdb78e0c345"}, + {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:713d450cf8e61854de9420fb7eea8ad228df4e27e7d4ed465de98c955d2b3fa6"}, + {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c3446937be153718250fe421da548f973124189f18fe4575a0510b5c928f0cc"}, + {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ddc7bcedeb47187be74208bc652d63d6b20cb24f4e596bd356092d8000da6d6"}, + {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44151d7b81b9391ed759a2f2865bbe623ef00d648fed59363be2bbbd5154656f"}, + {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cea1cca3be76c9483282dc7760ea1cc08a6ecec1f0b6ca0a94ea0d17432da19"}, + {file = "greenlet-3.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:619935a44f414274a2c08c9e74611965650b730eb4efe4b2270f91df5e4adf9a"}, + {file = "greenlet-3.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:221169d31cada333a0c7fd087b957c8f431c1dba202c3a58cf5a3583ed973e9b"}, + {file = "greenlet-3.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:01059afb9b178606b4b6e92c3e710ea1635597c3537e44da69f4531e111dd5e9"}, + {file = "greenlet-3.1.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:24fc216ec7c8be9becba8b64a98a78f9cd057fd2dc75ae952ca94ed8a893bf27"}, + {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d07c28b85b350564bdff9f51c1c5007dfb2f389385d1bc23288de51134ca303"}, + {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:243a223c96a4246f8a30ea470c440fe9db1f5e444941ee3c3cd79df119b8eebf"}, + {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26811df4dc81271033a7836bc20d12cd30938e6bd2e9437f56fa03da81b0f8fc"}, + {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d86401550b09a55410f32ceb5fe7efcd998bd2dad9e82521713cb148a4a15f"}, + {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:26d9c1c4f1748ccac0bae1dbb465fb1a795a75aba8af8ca871503019f4285e2a"}, + {file = "greenlet-3.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cd468ec62257bb4544989402b19d795d2305eccb06cde5da0eb739b63dc04665"}, + {file = "greenlet-3.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a53dfe8f82b715319e9953330fa5c8708b610d48b5c59f1316337302af5c0811"}, + {file = "greenlet-3.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:28fe80a3eb673b2d5cc3b12eea468a5e5f4603c26aa34d88bf61bba82ceb2f9b"}, + {file = "greenlet-3.1.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:76b3e3976d2a452cba7aa9e453498ac72240d43030fdc6d538a72b87eaff52fd"}, + {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:655b21ffd37a96b1e78cc48bf254f5ea4b5b85efaf9e9e2a526b3c9309d660ca"}, + {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c6f4c2027689093775fd58ca2388d58789009116844432d920e9147f91acbe64"}, + {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76e5064fd8e94c3f74d9fd69b02d99e3cdb8fc286ed49a1f10b256e59d0d3a0b"}, + {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a4bf607f690f7987ab3291406e012cd8591a4f77aa54f29b890f9c331e84989"}, + {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:037d9ac99540ace9424cb9ea89f0accfaff4316f149520b4ae293eebc5bded17"}, + {file = "greenlet-3.1.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:90b5bbf05fe3d3ef697103850c2ce3374558f6fe40fd57c9fac1bf14903f50a5"}, + {file = "greenlet-3.1.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:726377bd60081172685c0ff46afbc600d064f01053190e4450857483c4d44484"}, + {file = "greenlet-3.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:d46d5069e2eeda111d6f71970e341f4bd9aeeee92074e649ae263b834286ecc0"}, + {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81eeec4403a7d7684b5812a8aaa626fa23b7d0848edb3a28d2eb3220daddcbd0"}, + {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a3dae7492d16e85ea6045fd11cb8e782b63eac8c8d520c3a92c02ac4573b0a6"}, + {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b5ea3664eed571779403858d7cd0a9b0ebf50d57d2cdeafc7748e09ef8cd81a"}, + {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a22f4e26400f7f48faef2d69c20dc055a1f3043d330923f9abe08ea0aecc44df"}, + {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13ff8c8e54a10472ce3b2a2da007f915175192f18e6495bad50486e87c7f6637"}, + {file = "greenlet-3.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f9671e7282d8c6fcabc32c0fb8d7c0ea8894ae85cee89c9aadc2d7129e1a9954"}, + {file = "greenlet-3.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:184258372ae9e1e9bddce6f187967f2e08ecd16906557c4320e3ba88a93438c3"}, + {file = "greenlet-3.1.0-cp37-cp37m-win32.whl", hash = "sha256:a0409bc18a9f85321399c29baf93545152d74a49d92f2f55302f122007cfda00"}, + {file = "greenlet-3.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:9eb4a1d7399b9f3c7ac68ae6baa6be5f9195d1d08c9ddc45ad559aa6b556bce6"}, + {file = "greenlet-3.1.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:a8870983af660798dc1b529e1fd6f1cefd94e45135a32e58bd70edd694540f33"}, + {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfcfb73aed40f550a57ea904629bdaf2e562c68fa1164fa4588e752af6efdc3f"}, + {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9482c2ed414781c0af0b35d9d575226da6b728bd1a720668fa05837184965b7"}, + {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d58ec349e0c2c0bc6669bf2cd4982d2f93bf067860d23a0ea1fe677b0f0b1e09"}, + {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd65695a8df1233309b701dec2539cc4b11e97d4fcc0f4185b4a12ce54db0491"}, + {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:665b21e95bc0fce5cab03b2e1d90ba9c66c510f1bb5fdc864f3a377d0f553f6b"}, + {file = "greenlet-3.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d3c59a06c2c28a81a026ff11fbf012081ea34fb9b7052f2ed0366e14896f0a1d"}, + {file = "greenlet-3.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415b9494ff6240b09af06b91a375731febe0090218e2898d2b85f9b92abcda0"}, + {file = "greenlet-3.1.0-cp38-cp38-win32.whl", hash = "sha256:1544b8dd090b494c55e60c4ff46e238be44fdc472d2589e943c241e0169bcea2"}, + {file = "greenlet-3.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:7f346d24d74c00b6730440f5eb8ec3fe5774ca8d1c9574e8e57c8671bb51b910"}, + {file = "greenlet-3.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:db1b3ccb93488328c74e97ff888604a8b95ae4f35f4f56677ca57a4fc3a4220b"}, + {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44cd313629ded43bb3b98737bba2f3e2c2c8679b55ea29ed73daea6b755fe8e7"}, + {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fad7a051e07f64e297e6e8399b4d6a3bdcad3d7297409e9a06ef8cbccff4f501"}, + {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3967dcc1cd2ea61b08b0b276659242cbce5caca39e7cbc02408222fb9e6ff39"}, + {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d45b75b0f3fd8d99f62eb7908cfa6d727b7ed190737dec7fe46d993da550b81a"}, + {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2d004db911ed7b6218ec5c5bfe4cf70ae8aa2223dffbb5b3c69e342bb253cb28"}, + {file = "greenlet-3.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9505a0c8579899057cbefd4ec34d865ab99852baf1ff33a9481eb3924e2da0b"}, + {file = "greenlet-3.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fd6e94593f6f9714dbad1aaba734b5ec04593374fa6638df61592055868f8b8"}, + {file = "greenlet-3.1.0-cp39-cp39-win32.whl", hash = "sha256:d0dd943282231480aad5f50f89bdf26690c995e8ff555f26d8a5b9887b559bcc"}, + {file = "greenlet-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:ac0adfdb3a21dc2a24ed728b61e72440d297d0fd3a577389df566651fcd08f97"}, + {file = "greenlet-3.1.0.tar.gz", hash = "sha256:b395121e9bbe8d02a750886f108d540abe66075e61e22f7353d9acb0b81be0f0"}, ] [package.extras] @@ -4012,13 +4020,13 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs [[package]] name = "importlib-resources" -version = "6.4.4" +version = "6.4.5" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"}, - {file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"}, + {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, + {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, ] [package.extras] @@ -4313,13 +4321,13 @@ files = [ [[package]] name = "kombu" -version = "5.4.0" +version = "5.4.1" description = "Messaging library for Python." optional = false python-versions = ">=3.8" files = [ - {file = "kombu-5.4.0-py3-none-any.whl", hash = "sha256:c8dd99820467610b4febbc7a9e8a0d3d7da2d35116b67184418b51cc520ea6b6"}, - {file = "kombu-5.4.0.tar.gz", hash = "sha256:ad200a8dbdaaa2bbc5f26d2ee7d707d9a1fded353a0f4bd751ce8c7d9f449c60"}, + {file = "kombu-5.4.1-py3-none-any.whl", hash = "sha256:621d365f234e4c089596f3a2510f1ade07026efc28caca426161d8f458786cab"}, + {file = "kombu-5.4.1.tar.gz", hash = "sha256:1c05178826dab811f8cab5b0a154d42a7a33d8bcdde9fa3d7b4582e43c3c03db"}, ] [package.dependencies] @@ -4333,7 +4341,7 @@ confluentkafka = ["confluent-kafka (>=2.2.0)"] consul = ["python-consul2 (==0.1.5)"] librabbitmq = ["librabbitmq (>=2.0.0)"] mongodb = ["pymongo (>=4.1.1)"] -msgpack = ["msgpack (==1.0.8)"] +msgpack = ["msgpack (==1.1.0)"] pyro = ["pyro4 (==4.82)"] qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] redis = ["redis (>=4.5.2,!=4.5.5,!=5.0.2)"] @@ -4385,13 +4393,13 @@ six = "*" [[package]] name = "langfuse" -version = "2.46.3" +version = "2.48.0" description = "A client library for accessing langfuse" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langfuse-2.46.3-py3-none-any.whl", hash = "sha256:59dcca4b13ea5f5c7f5a9344266116c3b8b998ae63274e4e9d0dabb51a47d361"}, - {file = "langfuse-2.46.3.tar.gz", hash = "sha256:a68c2dba630f53ccd473205164082ac1b29a1cbdb73500004daee72b5b522624"}, + {file = "langfuse-2.48.0-py3-none-any.whl", hash = "sha256:475b047e461f8a45e3c7d81b6a87e0b9e389c489d465b838aa69cbdd16eeacce"}, + {file = "langfuse-2.48.0.tar.gz", hash = "sha256:46e7e6e6e97fe03115a9f95d7f29b3fcd1848a9d1bb34608ebb42a3931919e45"}, ] [package.dependencies] @@ -4410,13 +4418,13 @@ openai = ["openai (>=0.27.8)"] [[package]] name = "langsmith" -version = "0.1.115" +version = "0.1.118" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.115-py3-none-any.whl", hash = "sha256:04e35cfd4c2d4ff1ea10bb577ff43957b05ebb3d9eb4e06e200701f4a2b4ac9f"}, - {file = "langsmith-0.1.115.tar.gz", hash = "sha256:3b775377d858d32354f3ee0dd1ed637068cfe9a1f13e7b3bfa82db1615cdffc9"}, + {file = "langsmith-0.1.118-py3-none-any.whl", hash = "sha256:f017127b3efb037da5e46ff4f8583e8192e7955191737240c327f3eadc144d7c"}, + {file = "langsmith-0.1.118.tar.gz", hash = "sha256:ff1ca06c92c6081250244ebbce5d0bb347b9d898d2e9b60a13b11f0f0720f09f"}, ] [package.dependencies] @@ -4886,15 +4894,15 @@ files = [ [[package]] name = "milvus-lite" -version = "2.4.9" +version = "2.4.10" description = "A lightweight version of Milvus wrapped with Python." optional = false python-versions = ">=3.7" files = [ - {file = "milvus_lite-2.4.9-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d3e617b3d68c09ad656d54bc3d8cc4ef6ef56c54015e1563d4fe4bcec6b7c90a"}, - {file = "milvus_lite-2.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6e7029282d6829b277ebb92f64e2370be72b938e34770e1eb649346bda5d1d7f"}, - {file = "milvus_lite-2.4.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9b8e991e4e433596f6a399a165c1a506f823ec9133332e03d7f8a114bff4550d"}, - {file = "milvus_lite-2.4.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:7f53e674602101cfbcf0a4a59d19eaa139dfd5580639f3040ad73d901f24fc0b"}, + {file = "milvus_lite-2.4.10-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fc4246d3ed7d1910847afce0c9ba18212e93a6e9b8406048436940578dfad5cb"}, + {file = "milvus_lite-2.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:74a8e07c5e3b057df17fbb46913388e84df1dc403a200f4e423799a58184c800"}, + {file = "milvus_lite-2.4.10-py3-none-manylinux2014_aarch64.whl", hash = "sha256:240c7386b747bad696ecb5bd1f58d491e86b9d4b92dccee3315ed7256256eddc"}, + {file = "milvus_lite-2.4.10-py3-none-manylinux2014_x86_64.whl", hash = "sha256:211d2e334a043f9282bdd9755f76b9b2d93b23bffa7af240919ffce6a8dfe325"}, ] [package.dependencies] @@ -5038,22 +5046,22 @@ tests = ["pytest (>=4.6)"] [[package]] name = "msal" -version = "1.30.0" +version = "1.31.0" description = "The Microsoft Authentication Library (MSAL) for Python library enables your app to access the Microsoft Cloud by supporting authentication of users with Microsoft Azure Active Directory accounts (AAD) and Microsoft Accounts (MSA) using industry standard OAuth2 and OpenID Connect." optional = false python-versions = ">=3.7" files = [ - {file = "msal-1.30.0-py3-none-any.whl", hash = "sha256:423872177410cb61683566dc3932db7a76f661a5d2f6f52f02a047f101e1c1de"}, - {file = "msal-1.30.0.tar.gz", hash = "sha256:b4bf00850092e465157d814efa24a18f788284c9a479491024d62903085ea2fb"}, + {file = "msal-1.31.0-py3-none-any.whl", hash = "sha256:96bc37cff82ebe4b160d5fc0f1196f6ca8b50e274ecd0ec5bf69c438514086e7"}, + {file = "msal-1.31.0.tar.gz", hash = "sha256:2c4f189cf9cc8f00c80045f66d39b7c0f3ed45873fd3d1f2af9f22db2e12ff4b"}, ] [package.dependencies] -cryptography = ">=2.5,<45" +cryptography = ">=2.5,<46" PyJWT = {version = ">=1.0.0,<3", extras = ["crypto"]} requests = ">=2.0.0,<3" [package.extras] -broker = ["pymsalruntime (>=0.13.2,<0.17)"] +broker = ["pymsalruntime (>=0.14,<0.18)", "pymsalruntime (>=0.17,<0.18)"] [[package]] name = "msal-extensions" @@ -5110,103 +5118,108 @@ async = ["aiodns", "aiohttp (>=3.0)"] [[package]] name = "multidict" -version = "6.0.5" +version = "6.1.0" description = "multidict implementation" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, - {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, - {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, - {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, - {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, - {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, - {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, - {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, - {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, - {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, - {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, - {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, - {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, - {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, - {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, - {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, - {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, - {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a114d03b938376557927ab23f1e950827c3b893ccb94b62fd95d430fd0e5cf53"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1c416351ee6271b2f49b56ad7f308072f6f44b37118d69c2cad94f3fa8a40d5"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b5d83030255983181005e6cfbac1617ce9746b219bc2aad52201ad121226581"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e97b5e938051226dc025ec80980c285b053ffb1e25a3db2a3aa3bc046bf7f56"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d618649d4e70ac6efcbba75be98b26ef5078faad23592f9b51ca492953012429"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10524ebd769727ac77ef2278390fb0068d83f3acb7773792a5080f2b0abf7748"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ff3827aef427c89a25cc96ded1759271a93603aba9fb977a6d264648ebf989db"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:06809f4f0f7ab7ea2cabf9caca7d79c22c0758b58a71f9d32943ae13c7ace056"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f179dee3b863ab1c59580ff60f9d99f632f34ccb38bf67a33ec6b3ecadd0fd76"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:aaed8b0562be4a0876ee3b6946f6869b7bcdb571a5d1496683505944e268b160"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3c8b88a2ccf5493b6c8da9076fb151ba106960a2df90c2633f342f120751a9e7"}, + {file = "multidict-6.1.0-cp310-cp310-win32.whl", hash = "sha256:4a9cb68166a34117d6646c0023c7b759bf197bee5ad4272f420a0141d7eb03a0"}, + {file = "multidict-6.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:20b9b5fbe0b88d0bdef2012ef7dee867f874b72528cf1d08f1d59b0e3850129d"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3efe2c2cb5763f2f1b275ad2bf7a287d3f7ebbef35648a9726e3b69284a4f3d6"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7053d3b0353a8b9de430a4f4b4268ac9a4fb3481af37dfe49825bf45ca24156"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27e5fc84ccef8dfaabb09d82b7d179c7cf1a3fbc8a966f8274fcb4ab2eb4cadb"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2b90b43e696f25c62656389d32236e049568b39320e2735d51f08fd362761b"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d83a047959d38a7ff552ff94be767b7fd79b831ad1cd9920662db05fec24fe72"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1a9dd711d0877a1ece3d2e4fea11a8e75741ca21954c919406b44e7cf971304"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec2abea24d98246b94913b76a125e855eb5c434f7c46546046372fe60f666351"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4867cafcbc6585e4b678876c489b9273b13e9fff9f6d6d66add5e15d11d926cb"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5b48204e8d955c47c55b72779802b219a39acc3ee3d0116d5080c388970b76e3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8fff389528cad1618fb4b26b95550327495462cd745d879a8c7c2115248e399"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a7a9541cd308eed5e30318430a9c74d2132e9a8cb46b901326272d780bf2d423"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:da1758c76f50c39a2efd5e9859ce7d776317eb1dd34317c8152ac9251fc574a3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c943a53e9186688b45b323602298ab727d8865d8c9ee0b17f8d62d14b56f0753"}, + {file = "multidict-6.1.0-cp311-cp311-win32.whl", hash = "sha256:90f8717cb649eea3504091e640a1b8568faad18bd4b9fcd692853a04475a4b80"}, + {file = "multidict-6.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:82176036e65644a6cc5bd619f65f6f19781e8ec2e5330f51aa9ada7504cc1926"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b04772ed465fa3cc947db808fa306d79b43e896beb677a56fb2347ca1a49c1fa"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6180c0ae073bddeb5a97a38c03f30c233e0a4d39cd86166251617d1bbd0af436"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:071120490b47aa997cca00666923a83f02c7fbb44f71cf7f136df753f7fa8761"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50b3a2710631848991d0bf7de077502e8994c804bb805aeb2925a981de58ec2e"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b58c621844d55e71c1b7f7c498ce5aa6985d743a1a59034c57a905b3f153c1ef"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b6d90641869892caa9ca42ff913f7ff1c5ece06474fbd32fb2cf6834726c95"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b820514bfc0b98a30e3d85462084779900347e4d49267f747ff54060cc33925"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10a9b09aba0c5b48c53761b7c720aaaf7cf236d5fe394cd399c7ba662d5f9966"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e16bf3e5fc9f44632affb159d30a437bfe286ce9e02754759be5536b169b305"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:76f364861c3bfc98cbbcbd402d83454ed9e01a5224bb3a28bf70002a230f73e2"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:820c661588bd01a0aa62a1283f20d2be4281b086f80dad9e955e690c75fb54a2"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:0e5f362e895bc5b9e67fe6e4ded2492d8124bdf817827f33c5b46c2fe3ffaca6"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ec660d19bbc671e3a6443325f07263be452c453ac9e512f5eb935e7d4ac28b3"}, + {file = "multidict-6.1.0-cp312-cp312-win32.whl", hash = "sha256:58130ecf8f7b8112cdb841486404f1282b9c86ccb30d3519faf301b2e5659133"}, + {file = "multidict-6.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:188215fc0aafb8e03341995e7c4797860181562380f81ed0a87ff455b70bf1f1"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d569388c381b24671589335a3be6e1d45546c2988c2ebe30fdcada8457a31008"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:052e10d2d37810b99cc170b785945421141bf7bb7d2f8799d431e7db229c385f"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f90c822a402cb865e396a504f9fc8173ef34212a342d92e362ca498cad308e28"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b225d95519a5bf73860323e633a664b0d85ad3d5bede6d30d95b35d4dfe8805b"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:23bfd518810af7de1116313ebd9092cb9aa629beb12f6ed631ad53356ed6b86c"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c09fcfdccdd0b57867577b719c69e347a436b86cd83747f179dbf0cc0d4c1f3"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf6bea52ec97e95560af5ae576bdac3aa3aae0b6758c6efa115236d9e07dae44"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57feec87371dbb3520da6192213c7d6fc892d5589a93db548331954de8248fd2"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0c3f390dc53279cbc8ba976e5f8035eab997829066756d811616b652b00a23a3"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:59bfeae4b25ec05b34f1956eaa1cb38032282cd4dfabc5056d0a1ec4d696d3aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b2f59caeaf7632cc633b5cf6fc449372b83bbdf0da4ae04d5be36118e46cc0aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:37bb93b2178e02b7b618893990941900fd25b6b9ac0fa49931a40aecdf083fe4"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4e9f48f58c2c523d5a06faea47866cd35b32655c46b443f163d08c6d0ddb17d6"}, + {file = "multidict-6.1.0-cp313-cp313-win32.whl", hash = "sha256:3a37ffb35399029b45c6cc33640a92bef403c9fd388acce75cdc88f58bd19a81"}, + {file = "multidict-6.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:e9aa71e15d9d9beaad2c6b9319edcdc0a49a43ef5c0a4c8265ca9ee7d6c67774"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:db7457bac39421addd0c8449933ac32d8042aae84a14911a757ae6ca3eef1392"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d094ddec350a2fb899fec68d8353c78233debde9b7d8b4beeafa70825f1c281a"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5845c1fd4866bb5dd3125d89b90e57ed3138241540897de748cdf19de8a2fca2"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9079dfc6a70abe341f521f78405b8949f96db48da98aeb43f9907f342f627cdc"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3914f5aaa0f36d5d60e8ece6a308ee1c9784cd75ec8151062614657a114c4478"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c08be4f460903e5a9d0f76818db3250f12e9c344e79314d1d570fc69d7f4eae4"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d093be959277cb7dee84b801eb1af388b6ad3ca6a6b6bf1ed7585895789d027d"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3702ea6872c5a2a4eeefa6ffd36b042e9773f05b1f37ae3ef7264b1163c2dcf6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2090f6a85cafc5b2db085124d752757c9d251548cedabe9bd31afe6363e0aff2"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:f67f217af4b1ff66c68a87318012de788dd95fcfeb24cc889011f4e1c7454dfd"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:189f652a87e876098bbc67b4da1049afb5f5dfbaa310dd67c594b01c10388db6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:6bb5992037f7a9eff7991ebe4273ea7f51f1c1c511e6a2ce511d0e7bdb754492"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f4c2b9e770c4e393876e35a7046879d195cd123b4f116d299d442b335bcd"}, + {file = "multidict-6.1.0-cp38-cp38-win32.whl", hash = "sha256:e27bbb6d14416713a8bd7aaa1313c0fc8d44ee48d74497a0ff4c3a1b6ccb5167"}, + {file = "multidict-6.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:22f3105d4fb15c8f57ff3959a58fcab6ce36814486500cd7485651230ad4d4ef"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4e18b656c5e844539d506a0a06432274d7bd52a7487e6828c63a63d69185626c"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a185f876e69897a6f3325c3f19f26a297fa058c5e456bfcff8015e9a27e83ae1"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab7c4ceb38d91570a650dba194e1ca87c2b543488fe9309b4212694174fd539c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e617fb6b0b6953fffd762669610c1c4ffd05632c138d61ac7e14ad187870669c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:16e5f4bf4e603eb1fdd5d8180f1a25f30056f22e55ce51fb3d6ad4ab29f7d96f"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4c035da3f544b1882bac24115f3e2e8760f10a0107614fc9839fd232200b875"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:957cf8e4b6e123a9eea554fa7ebc85674674b713551de587eb318a2df3e00255"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:483a6aea59cb89904e1ceabd2b47368b5600fb7de78a6e4a2c2987b2d256cf30"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:87701f25a2352e5bf7454caa64757642734da9f6b11384c1f9d1a8e699758057"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:682b987361e5fd7a139ed565e30d81fd81e9629acc7d925a205366877d8c8657"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce2186a7df133a9c895dea3331ddc5ddad42cdd0d1ea2f0a51e5d161e4762f28"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9f636b730f7e8cb19feb87094949ba54ee5357440b9658b2a32a5ce4bce53972"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:73eae06aa53af2ea5270cc066dcaf02cc60d2994bbb2c4ef5764949257d10f43"}, + {file = "multidict-6.1.0-cp39-cp39-win32.whl", hash = "sha256:1ca0083e80e791cffc6efce7660ad24af66c8d4079d2a750b29001b53ff59ada"}, + {file = "multidict-6.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:aa466da5b15ccea564bdab9c89175c762bc12825f4659c11227f515cee76fa4a"}, + {file = "multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506"}, + {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, ] +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} + [[package]] name = "multiprocess" version = "0.70.16" @@ -6219,19 +6232,19 @@ xmp = ["defusedxml"] [[package]] name = "platformdirs" -version = "4.2.2" +version = "4.3.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, - {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, + {file = "platformdirs-4.3.2-py3-none-any.whl", hash = "sha256:eb1c8582560b34ed4ba105009a4badf7f6f85768b30126f351328507b2beb617"}, + {file = "platformdirs-4.3.2.tar.gz", hash = "sha256:9e5e27a08aa095dd127b9f2e764d74254f482fef22b0970773bfba79d091ab8c"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] -type = ["mypy (>=1.8)"] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] [[package]] name = "plotly" @@ -6295,13 +6308,13 @@ tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "p [[package]] name = "posthog" -version = "3.6.3" +version = "3.6.5" description = "Integrate PostHog into any python application." optional = false python-versions = "*" files = [ - {file = "posthog-3.6.3-py2.py3-none-any.whl", hash = "sha256:cdd6c5d8919fd6158bbc4103bccc7129c712d8104dc33828be02bada7b6320a4"}, - {file = "posthog-3.6.3.tar.gz", hash = "sha256:6e1104a20638eab2b5d9cde6b6202a2900d67436237b3ac3521614ec17686701"}, + {file = "posthog-3.6.5-py2.py3-none-any.whl", hash = "sha256:f8b7c573826b061a1d22c9495169c38ebe83a1df2729f49c7129a9c23a02acf6"}, + {file = "posthog-3.6.5.tar.gz", hash = "sha256:7fd3ca809e15476c35f75d18cd6bba31395daf0a17b75242965c469fb6292510"}, ] [package.dependencies] @@ -6586,24 +6599,24 @@ test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] [[package]] name = "pyasn1" -version = "0.6.0" +version = "0.6.1" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false python-versions = ">=3.8" files = [ - {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, - {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"}, + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] [[package]] name = "pyasn1-modules" -version = "0.4.0" +version = "0.4.1" description = "A collection of ASN.1-based protocols modules" optional = false python-versions = ">=3.8" files = [ - {file = "pyasn1_modules-0.4.0-py3-none-any.whl", hash = "sha256:be04f15b66c206eed667e0bb5ab27e2b1855ea54a842e5037738099e8ca4ae0b"}, - {file = "pyasn1_modules-0.4.0.tar.gz", hash = "sha256:831dbcea1b177b28c9baddf4c6d1013c24c3accd14a1873fffaa6a2e905f17b6"}, + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, ] [package.dependencies] @@ -7012,24 +7025,24 @@ files = [ [[package]] name = "pyreadline3" -version = "3.4.1" +version = "3.4.3" description = "A python implementation of GNU readline." optional = false python-versions = "*" files = [ - {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, - {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, + {file = "pyreadline3-3.4.3-py3-none-any.whl", hash = "sha256:f832c5898f4f9a0f81d48a8c499b39d0179de1a465ea3def1a7e7231840b4ed6"}, + {file = "pyreadline3-3.4.3.tar.gz", hash = "sha256:ebab0baca37f50e2faa1dd99a6da1c75de60e0d68a3b229c134bbd12786250e2"}, ] [[package]] name = "pytest" -version = "8.3.2" +version = "8.3.3" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.3.2-py3-none-any.whl", hash = "sha256:4ba08f9ae7dcf84ded419494d229b48d0903ea6407b030eaec46df5e6a73bba5"}, - {file = "pytest-8.3.2.tar.gz", hash = "sha256:c132345d12ce551242c87269de812483f5bcc87cdbb4722e48487ba194f9fdce"}, + {file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"}, + {file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"}, ] [package.dependencies] @@ -7065,21 +7078,21 @@ histogram = ["pygal", "pygaljs"] [[package]] name = "pytest-env" -version = "1.1.3" +version = "1.1.4" description = "pytest plugin that allows you to add environment variables." optional = false python-versions = ">=3.8" files = [ - {file = "pytest_env-1.1.3-py3-none-any.whl", hash = "sha256:aada77e6d09fcfb04540a6e462c58533c37df35fa853da78707b17ec04d17dfc"}, - {file = "pytest_env-1.1.3.tar.gz", hash = "sha256:fcd7dc23bb71efd3d35632bde1bbe5ee8c8dc4489d6617fb010674880d96216b"}, + {file = "pytest_env-1.1.4-py3-none-any.whl", hash = "sha256:a4212056d4d440febef311a98fdca56c31256d58fb453d103cba4e8a532b721d"}, + {file = "pytest_env-1.1.4.tar.gz", hash = "sha256:86653658da8f11c6844975db955746c458a9c09f1e64957603161e2ff93f5133"}, ] [package.dependencies] -pytest = ">=7.4.3" +pytest = ">=8.3.2" tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} [package.extras] -test = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "pytest-mock (>=3.12)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] [[package]] name = "pytest-mock" @@ -7293,13 +7306,13 @@ XlsxWriter = ">=0.5.7" [[package]] name = "pytz" -version = "2024.1" +version = "2024.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" files = [ - {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, - {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, + {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, + {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, ] [[package]] @@ -7642,90 +7655,105 @@ rpds-py = ">=0.7.0" [[package]] name = "regex" -version = "2024.7.24" +version = "2024.9.11" description = "Alternative regular expression module, to replace re." optional = false python-versions = ">=3.8" files = [ - {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b0d3f567fafa0633aee87f08b9276c7062da9616931382993c03808bb68ce"}, - {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3426de3b91d1bc73249042742f45c2148803c111d1175b283270177fdf669024"}, - {file = "regex-2024.7.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f273674b445bcb6e4409bf8d1be67bc4b58e8b46fd0d560055d515b8830063cd"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23acc72f0f4e1a9e6e9843d6328177ae3074b4182167e34119ec7233dfeccf53"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65fd3d2e228cae024c411c5ccdffae4c315271eee4a8b839291f84f796b34eca"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c414cbda77dbf13c3bc88b073a1a9f375c7b0cb5e115e15d4b73ec3a2fbc6f59"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7a89eef64b5455835f5ed30254ec19bf41f7541cd94f266ab7cbd463f00c41"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19c65b00d42804e3fbea9708f0937d157e53429a39b7c61253ff15670ff62cb5"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7a5486ca56c8869070a966321d5ab416ff0f83f30e0e2da1ab48815c8d165d46"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f51f9556785e5a203713f5efd9c085b4a45aecd2a42573e2b5041881b588d1f"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a4997716674d36a82eab3e86f8fa77080a5d8d96a389a61ea1d0e3a94a582cf7"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c0abb5e4e8ce71a61d9446040c1e86d4e6d23f9097275c5bd49ed978755ff0fe"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:18300a1d78cf1290fa583cd8b7cde26ecb73e9f5916690cf9d42de569c89b1ce"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:416c0e4f56308f34cdb18c3f59849479dde5b19febdcd6e6fa4d04b6c31c9faa"}, - {file = "regex-2024.7.24-cp310-cp310-win32.whl", hash = "sha256:fb168b5924bef397b5ba13aabd8cf5df7d3d93f10218d7b925e360d436863f66"}, - {file = "regex-2024.7.24-cp310-cp310-win_amd64.whl", hash = "sha256:6b9fc7e9cc983e75e2518496ba1afc524227c163e43d706688a6bb9eca41617e"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:382281306e3adaaa7b8b9ebbb3ffb43358a7bbf585fa93821300a418bb975281"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fdd1384619f406ad9037fe6b6eaa3de2749e2e12084abc80169e8e075377d3b"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3d974d24edb231446f708c455fd08f94c41c1ff4f04bcf06e5f36df5ef50b95a"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec4419a3fe6cf8a4795752596dfe0adb4aea40d3683a132bae9c30b81e8d73"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb563dd3aea54c797adf513eeec819c4213d7dbfc311874eb4fd28d10f2ff0f2"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45104baae8b9f67569f0f1dca5e1f1ed77a54ae1cd8b0b07aba89272710db61e"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:994448ee01864501912abf2bad9203bffc34158e80fe8bfb5b031f4f8e16da51"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fac296f99283ac232d8125be932c5cd7644084a30748fda013028c815ba3364"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7e37e809b9303ec3a179085415cb5f418ecf65ec98cdfe34f6a078b46ef823ee"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:01b689e887f612610c869421241e075c02f2e3d1ae93a037cb14f88ab6a8934c"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f6442f0f0ff81775eaa5b05af8a0ffa1dda36e9cf6ec1e0d3d245e8564b684ce"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:871e3ab2838fbcb4e0865a6e01233975df3a15e6fce93b6f99d75cacbd9862d1"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c918b7a1e26b4ab40409820ddccc5d49871a82329640f5005f73572d5eaa9b5e"}, - {file = "regex-2024.7.24-cp311-cp311-win32.whl", hash = "sha256:2dfbb8baf8ba2c2b9aa2807f44ed272f0913eeeba002478c4577b8d29cde215c"}, - {file = "regex-2024.7.24-cp311-cp311-win_amd64.whl", hash = "sha256:538d30cd96ed7d1416d3956f94d54e426a8daf7c14527f6e0d6d425fcb4cca52"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fe4ebef608553aff8deb845c7f4f1d0740ff76fa672c011cc0bacb2a00fbde86"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:74007a5b25b7a678459f06559504f1eec2f0f17bca218c9d56f6a0a12bfffdad"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7df9ea48641da022c2a3c9c641650cd09f0cd15e8908bf931ad538f5ca7919c9"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1141a1dcc32904c47f6846b040275c6e5de0bf73f17d7a409035d55b76f289"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80c811cfcb5c331237d9bad3bea2c391114588cf4131707e84d9493064d267f9"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7214477bf9bd195894cf24005b1e7b496f46833337b5dedb7b2a6e33f66d962c"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d55588cba7553f0b6ec33130bc3e114b355570b45785cebdc9daed8c637dd440"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558a57cfc32adcf19d3f791f62b5ff564922942e389e3cfdb538a23d65a6b610"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a512eed9dfd4117110b1881ba9a59b31433caed0c4101b361f768e7bcbaf93c5"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:86b17ba823ea76256b1885652e3a141a99a5c4422f4a869189db328321b73799"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5eefee9bfe23f6df09ffb6dfb23809f4d74a78acef004aa904dc7c88b9944b05"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:731fcd76bbdbf225e2eb85b7c38da9633ad3073822f5ab32379381e8c3c12e94"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eaef80eac3b4cfbdd6de53c6e108b4c534c21ae055d1dbea2de6b3b8ff3def38"}, - {file = "regex-2024.7.24-cp312-cp312-win32.whl", hash = "sha256:185e029368d6f89f36e526764cf12bf8d6f0e3a2a7737da625a76f594bdfcbfc"}, - {file = "regex-2024.7.24-cp312-cp312-win_amd64.whl", hash = "sha256:2f1baff13cc2521bea83ab2528e7a80cbe0ebb2c6f0bfad15be7da3aed443908"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:66b4c0731a5c81921e938dcf1a88e978264e26e6ac4ec96a4d21ae0354581ae0"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:88ecc3afd7e776967fa16c80f974cb79399ee8dc6c96423321d6f7d4b881c92b"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64bd50cf16bcc54b274e20235bf8edbb64184a30e1e53873ff8d444e7ac656b2"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb462f0e346fcf41a901a126b50f8781e9a474d3927930f3490f38a6e73b6950"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a82465ebbc9b1c5c50738536fdfa7cab639a261a99b469c9d4c7dcbb2b3f1e57"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68a8f8c046c6466ac61a36b65bb2395c74451df2ffb8458492ef49900efed293"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac8e84fff5d27420f3c1e879ce9929108e873667ec87e0c8eeb413a5311adfe"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba2537ef2163db9e6ccdbeb6f6424282ae4dea43177402152c67ef869cf3978b"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:43affe33137fcd679bdae93fb25924979517e011f9dea99163f80b82eadc7e53"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:c9bb87fdf2ab2370f21e4d5636e5317775e5d51ff32ebff2cf389f71b9b13750"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:945352286a541406f99b2655c973852da7911b3f4264e010218bbc1cc73168f2"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8bc593dcce679206b60a538c302d03c29b18e3d862609317cb560e18b66d10cf"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:3f3b6ca8eae6d6c75a6cff525c8530c60e909a71a15e1b731723233331de4169"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c51edc3541e11fbe83f0c4d9412ef6c79f664a3745fab261457e84465ec9d5a8"}, - {file = "regex-2024.7.24-cp38-cp38-win32.whl", hash = "sha256:d0a07763776188b4db4c9c7fb1b8c494049f84659bb387b71c73bbc07f189e96"}, - {file = "regex-2024.7.24-cp38-cp38-win_amd64.whl", hash = "sha256:8fd5afd101dcf86a270d254364e0e8dddedebe6bd1ab9d5f732f274fa00499a5"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0ffe3f9d430cd37d8fa5632ff6fb36d5b24818c5c986893063b4e5bdb84cdf24"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25419b70ba00a16abc90ee5fce061228206173231f004437730b67ac77323f0d"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:33e2614a7ce627f0cdf2ad104797d1f68342d967de3695678c0cb84f530709f8"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d33a0021893ede5969876052796165bab6006559ab845fd7b515a30abdd990dc"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04ce29e2c5fedf296b1a1b0acc1724ba93a36fb14031f3abfb7abda2806c1535"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b16582783f44fbca6fcf46f61347340c787d7530d88b4d590a397a47583f31dd"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:836d3cc225b3e8a943d0b02633fb2f28a66e281290302a79df0e1eaa984ff7c1"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:438d9f0f4bc64e8dea78274caa5af971ceff0f8771e1a2333620969936ba10be"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:973335b1624859cb0e52f96062a28aa18f3a5fc77a96e4a3d6d76e29811a0e6e"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c5e69fd3eb0b409432b537fe3c6f44ac089c458ab6b78dcec14478422879ec5f"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fbf8c2f00904eaf63ff37718eb13acf8e178cb940520e47b2f05027f5bb34ce3"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2757ace61bc4061b69af19e4689fa4416e1a04840f33b441034202b5cd02d4"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:44fc61b99035fd9b3b9453f1713234e5a7c92a04f3577252b45feefe1b327759"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:84c312cdf839e8b579f504afcd7b65f35d60b6285d892b19adea16355e8343c9"}, - {file = "regex-2024.7.24-cp39-cp39-win32.whl", hash = "sha256:ca5b2028c2f7af4e13fb9fc29b28d0ce767c38c7facdf64f6c2cd040413055f1"}, - {file = "regex-2024.7.24-cp39-cp39-win_amd64.whl", hash = "sha256:7c479f5ae937ec9985ecaf42e2e10631551d909f203e31308c12d703922742f9"}, - {file = "regex-2024.7.24.tar.gz", hash = "sha256:9cfd009eed1a46b27c14039ad5bbc5e71b6367c5b2e6d5f5da0ea91600817506"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1494fa8725c285a81d01dc8c06b55287a1ee5e0e382d8413adc0a9197aac6408"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0e12c481ad92d129c78f13a2a3662317e46ee7ef96c94fd332e1c29131875b7d"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16e13a7929791ac1216afde26f712802e3df7bf0360b32e4914dca3ab8baeea5"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46989629904bad940bbec2106528140a218b4a36bb3042d8406980be1941429c"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a906ed5e47a0ce5f04b2c981af1c9acf9e8696066900bf03b9d7879a6f679fc8"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a091b0550b3b0207784a7d6d0f1a00d1d1c8a11699c1a4d93db3fbefc3ad35"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ddcd9a179c0a6fa8add279a4444015acddcd7f232a49071ae57fa6e278f1f71"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6b41e1adc61fa347662b09398e31ad446afadff932a24807d3ceb955ed865cc8"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ced479f601cd2f8ca1fd7b23925a7e0ad512a56d6e9476f79b8f381d9d37090a"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:635a1d96665f84b292e401c3d62775851aedc31d4f8784117b3c68c4fcd4118d"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c0256beda696edcf7d97ef16b2a33a8e5a875affd6fa6567b54f7c577b30a137"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ce4f1185db3fbde8ed8aa223fc9620f276c58de8b0d4f8cc86fd1360829edb6"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:09d77559e80dcc9d24570da3745ab859a9cf91953062e4ab126ba9d5993688ca"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a22ccefd4db3f12b526eccb129390942fe874a3a9fdbdd24cf55773a1faab1a"}, + {file = "regex-2024.9.11-cp310-cp310-win32.whl", hash = "sha256:f745ec09bc1b0bd15cfc73df6fa4f726dcc26bb16c23a03f9e3367d357eeedd0"}, + {file = "regex-2024.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:01c2acb51f8a7d6494c8c5eafe3d8e06d76563d8a8a4643b37e9b2dd8a2ff623"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2cce2449e5927a0bf084d346da6cd5eb016b2beca10d0013ab50e3c226ffc0df"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b37fa423beefa44919e009745ccbf353d8c981516e807995b2bd11c2c77d268"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:64ce2799bd75039b480cc0360907c4fb2f50022f030bf9e7a8705b636e408fad"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4cc92bb6db56ab0c1cbd17294e14f5e9224f0cc6521167ef388332604e92679"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d05ac6fa06959c4172eccd99a222e1fbf17b5670c4d596cb1e5cde99600674c4"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:040562757795eeea356394a7fb13076ad4f99d3c62ab0f8bdfb21f99a1f85664"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6113c008a7780792efc80f9dfe10ba0cd043cbf8dc9a76ef757850f51b4edc50"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e5fb5f77c8745a60105403a774fe2c1759b71d3e7b4ca237a5e67ad066c7199"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:54d9ff35d4515debf14bc27f1e3b38bfc453eff3220f5bce159642fa762fe5d4"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:df5cbb1fbc74a8305b6065d4ade43b993be03dbe0f8b30032cced0d7740994bd"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7fb89ee5d106e4a7a51bce305ac4efb981536301895f7bdcf93ec92ae0d91c7f"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a738b937d512b30bf75995c0159c0ddf9eec0775c9d72ac0202076c72f24aa96"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e28f9faeb14b6f23ac55bfbbfd3643f5c7c18ede093977f1df249f73fd22c7b1"}, + {file = "regex-2024.9.11-cp311-cp311-win32.whl", hash = "sha256:18e707ce6c92d7282dfce370cd205098384b8ee21544e7cb29b8aab955b66fa9"}, + {file = "regex-2024.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:313ea15e5ff2a8cbbad96ccef6be638393041b0a7863183c2d31e0c6116688cf"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b0d0a6c64fcc4ef9c69bd5b3b3626cc3776520a1637d8abaa62b9edc147a58f7"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:49b0e06786ea663f933f3710a51e9385ce0cba0ea56b67107fd841a55d56a231"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5b513b6997a0b2f10e4fd3a1313568e373926e8c252bd76c960f96fd039cd28d"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee439691d8c23e76f9802c42a95cfeebf9d47cf4ffd06f18489122dbb0a7ad64"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a8f877c89719d759e52783f7fe6e1c67121076b87b40542966c02de5503ace42"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23b30c62d0f16827f2ae9f2bb87619bc4fba2044911e2e6c2eb1af0161cdb766"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85ab7824093d8f10d44330fe1e6493f756f252d145323dd17ab6b48733ff6c0a"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8dee5b4810a89447151999428fe096977346cf2f29f4d5e29609d2e19e0199c9"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98eeee2f2e63edae2181c886d7911ce502e1292794f4c5ee71e60e23e8d26b5d"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:57fdd2e0b2694ce6fc2e5ccf189789c3e2962916fb38779d3e3521ff8fe7a822"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d552c78411f60b1fdaafd117a1fca2f02e562e309223b9d44b7de8be451ec5e0"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a0b2b80321c2ed3fcf0385ec9e51a12253c50f146fddb2abbb10f033fe3d049a"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:18406efb2f5a0e57e3a5881cd9354c1512d3bb4f5c45d96d110a66114d84d23a"}, + {file = "regex-2024.9.11-cp312-cp312-win32.whl", hash = "sha256:e464b467f1588e2c42d26814231edecbcfe77f5ac414d92cbf4e7b55b2c2a776"}, + {file = "regex-2024.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:9e8719792ca63c6b8340380352c24dcb8cd7ec49dae36e963742a275dfae6009"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c157bb447303070f256e084668b702073db99bbb61d44f85d811025fcf38f784"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4db21ece84dfeefc5d8a3863f101995de646c6cb0536952c321a2650aa202c36"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:220e92a30b426daf23bb67a7962900ed4613589bab80382be09b48896d211e92"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb1ae19e64c14c7ec1995f40bd932448713d3c73509e82d8cd7744dc00e29e86"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f47cd43a5bfa48f86925fe26fbdd0a488ff15b62468abb5d2a1e092a4fb10e85"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d4a76b96f398697fe01117093613166e6aa8195d63f1b4ec3f21ab637632963"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ea51dcc0835eea2ea31d66456210a4e01a076d820e9039b04ae8d17ac11dee6"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7aaa315101c6567a9a45d2839322c51c8d6e81f67683d529512f5bcfb99c802"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c57d08ad67aba97af57a7263c2d9006d5c404d721c5f7542f077f109ec2a4a29"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8404bf61298bb6f8224bb9176c1424548ee1181130818fcd2cbffddc768bed8"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:dd4490a33eb909ef5078ab20f5f000087afa2a4daa27b4c072ccb3cb3050ad84"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:eee9130eaad130649fd73e5cd92f60e55708952260ede70da64de420cdcad554"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6a2644a93da36c784e546de579ec1806bfd2763ef47babc1b03d765fe560c9f8"}, + {file = "regex-2024.9.11-cp313-cp313-win32.whl", hash = "sha256:e997fd30430c57138adc06bba4c7c2968fb13d101e57dd5bb9355bf8ce3fa7e8"}, + {file = "regex-2024.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:042c55879cfeb21a8adacc84ea347721d3d83a159da6acdf1116859e2427c43f"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:35f4a6f96aa6cb3f2f7247027b07b15a374f0d5b912c0001418d1d55024d5cb4"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:55b96e7ce3a69a8449a66984c268062fbaa0d8ae437b285428e12797baefce7e"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb130fccd1a37ed894824b8c046321540263013da72745d755f2d35114b81a60"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:323c1f04be6b2968944d730e5c2091c8c89767903ecaa135203eec4565ed2b2b"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be1c8ed48c4c4065ecb19d882a0ce1afe0745dfad8ce48c49586b90a55f02366"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b5b029322e6e7b94fff16cd120ab35a253236a5f99a79fb04fda7ae71ca20ae8"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6fff13ef6b5f29221d6904aa816c34701462956aa72a77f1f151a8ec4f56aeb"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:587d4af3979376652010e400accc30404e6c16b7df574048ab1f581af82065e4"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:079400a8269544b955ffa9e31f186f01d96829110a3bf79dc338e9910f794fca"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f9268774428ec173654985ce55fc6caf4c6d11ade0f6f914d48ef4719eb05ebb"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:23f9985c8784e544d53fc2930fc1ac1a7319f5d5332d228437acc9f418f2f168"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2941333154baff9838e88aa71c1d84f4438189ecc6021a12c7573728b5838e"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:e93f1c331ca8e86fe877a48ad64e77882c0c4da0097f2212873a69bbfea95d0c"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:846bc79ee753acf93aef4184c040d709940c9d001029ceb7b7a52747b80ed2dd"}, + {file = "regex-2024.9.11-cp38-cp38-win32.whl", hash = "sha256:c94bb0a9f1db10a1d16c00880bdebd5f9faf267273b8f5bd1878126e0fbde771"}, + {file = "regex-2024.9.11-cp38-cp38-win_amd64.whl", hash = "sha256:2b08fce89fbd45664d3df6ad93e554b6c16933ffa9d55cb7e01182baaf971508"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:07f45f287469039ffc2c53caf6803cd506eb5f5f637f1d4acb37a738f71dd066"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4838e24ee015101d9f901988001038f7f0d90dc0c3b115541a1365fb439add62"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6edd623bae6a737f10ce853ea076f56f507fd7726bee96a41ee3d68d347e4d16"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c69ada171c2d0e97a4b5aa78fbb835e0ffbb6b13fc5da968c09811346564f0d3"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02087ea0a03b4af1ed6ebab2c54d7118127fee8d71b26398e8e4b05b78963199"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69dee6a020693d12a3cf892aba4808fe168d2a4cef368eb9bf74f5398bfd4ee8"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297f54910247508e6e5cae669f2bc308985c60540a4edd1c77203ef19bfa63ca"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ecea58b43a67b1b79805f1a0255730edaf5191ecef84dbc4cc85eb30bc8b63b9"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eab4bb380f15e189d1313195b062a6aa908f5bd687a0ceccd47c8211e9cf0d4a"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0cbff728659ce4bbf4c30b2a1be040faafaa9eca6ecde40aaff86f7889f4ab39"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:54c4a097b8bc5bb0dfc83ae498061d53ad7b5762e00f4adaa23bee22b012e6ba"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:73d6d2f64f4d894c96626a75578b0bf7d9e56dcda8c3d037a2118fdfe9b1c664"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:e53b5fbab5d675aec9f0c501274c467c0f9a5d23696cfc94247e1fb56501ed89"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ffbcf9221e04502fc35e54d1ce9567541979c3fdfb93d2c554f0ca583a19b35"}, + {file = "regex-2024.9.11-cp39-cp39-win32.whl", hash = "sha256:e4c22e1ac1f1ec1e09f72e6c44d8f2244173db7eb9629cc3a346a8d7ccc31142"}, + {file = "regex-2024.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:faa3c142464efec496967359ca99696c896c591c56c53506bac1ad465f66e919"}, + {file = "regex-2024.9.11.tar.gz", hash = "sha256:6c188c307e8433bcb63dc1915022deb553b4203a70722fc542c363bf120a01fd"}, ] [[package]] @@ -7831,13 +7859,13 @@ requests = "2.31.0" [[package]] name = "rich" -version = "13.8.0" +version = "13.8.1" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.7.0" files = [ - {file = "rich-13.8.0-py3-none-any.whl", hash = "sha256:2e85306a063b9492dffc86278197a60cbece75bcb766022f3436f567cae11bdc"}, - {file = "rich-13.8.0.tar.gz", hash = "sha256:a5ac1f1cd448ade0d59cc3356f7db7a7ccda2c8cbae9c7a90c28ff463d3e91f4"}, + {file = "rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"}, + {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"}, ] [package.dependencies] @@ -7975,29 +8003,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.6.4" +version = "0.6.5" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.4-py3-none-linux_armv6l.whl", hash = "sha256:c4b153fc152af51855458e79e835fb6b933032921756cec9af7d0ba2aa01a258"}, - {file = "ruff-0.6.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bedff9e4f004dad5f7f76a9d39c4ca98af526c9b1695068198b3bda8c085ef60"}, - {file = "ruff-0.6.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d02a4127a86de23002e694d7ff19f905c51e338c72d8e09b56bfb60e1681724f"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7862f42fc1a4aca1ea3ffe8a11f67819d183a5693b228f0bb3a531f5e40336fc"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eebe4ff1967c838a1a9618a5a59a3b0a00406f8d7eefee97c70411fefc353617"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:932063a03bac394866683e15710c25b8690ccdca1cf192b9a98260332ca93408"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:50e30b437cebef547bd5c3edf9ce81343e5dd7c737cb36ccb4fe83573f3d392e"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c44536df7b93a587de690e124b89bd47306fddd59398a0fb12afd6133c7b3818"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ea086601b22dc5e7693a78f3fcfc460cceabfdf3bdc36dc898792aba48fbad6"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b52387d3289ccd227b62102c24714ed75fbba0b16ecc69a923a37e3b5e0aaaa"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0308610470fcc82969082fc83c76c0d362f562e2f0cdab0586516f03a4e06ec6"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:803b96dea21795a6c9d5bfa9e96127cc9c31a1987802ca68f35e5c95aed3fc0d"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:66dbfea86b663baab8fcae56c59f190caba9398df1488164e2df53e216248baa"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:34d5efad480193c046c86608dbba2bccdc1c5fd11950fb271f8086e0c763a5d1"}, - {file = "ruff-0.6.4-py3-none-win32.whl", hash = "sha256:f0f8968feea5ce3777c0d8365653d5e91c40c31a81d95824ba61d871a11b8523"}, - {file = "ruff-0.6.4-py3-none-win_amd64.whl", hash = "sha256:549daccee5227282289390b0222d0fbee0275d1db6d514550d65420053021a58"}, - {file = "ruff-0.6.4-py3-none-win_arm64.whl", hash = "sha256:ac4b75e898ed189b3708c9ab3fc70b79a433219e1e87193b4f2b77251d058d14"}, - {file = "ruff-0.6.4.tar.gz", hash = "sha256:ac3b5bfbee99973f80aa1b7cbd1c9cbce200883bdd067300c22a6cc1c7fba212"}, + {file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"}, + {file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"}, + {file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"}, + {file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"}, + {file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"}, + {file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"}, + {file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"}, ] [[package]] @@ -8194,13 +8222,13 @@ test = ["accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.3)", "apache-ai [[package]] name = "sagemaker-core" -version = "1.0.2" +version = "1.0.4" description = "An python package for sagemaker core functionalities" optional = false python-versions = ">=3.8" files = [ - {file = "sagemaker_core-1.0.2-py3-none-any.whl", hash = "sha256:ce8d38a4a32efa83e4bc037a8befc7e29f87cd3eaf99acc4472b607f75a0f45a"}, - {file = "sagemaker_core-1.0.2.tar.gz", hash = "sha256:8fb942aac5e7ed928dab512ffe6facf8c6bdd4595df63c59c0bd0795ea434f8d"}, + {file = "sagemaker_core-1.0.4-py3-none-any.whl", hash = "sha256:bf71d988dbda03a3cd1557524f2fab4f19d89e54bd38fc7f05bbbcf580715f95"}, + {file = "sagemaker_core-1.0.4.tar.gz", hash = "sha256:203f4eb9d0d2a0e6ba80d79ba8c28b8ea27c94d04f6d9ff01c2fd55b95615c78"}, ] [package.dependencies] @@ -8229,32 +8257,32 @@ files = [ [[package]] name = "scikit-learn" -version = "1.5.1" +version = "1.5.2" description = "A set of python modules for machine learning and data mining" optional = false python-versions = ">=3.9" files = [ - {file = "scikit_learn-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745"}, - {file = "scikit_learn-1.5.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7"}, - {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac"}, - {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21"}, - {file = "scikit_learn-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1"}, - {file = "scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2"}, - {file = "scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe"}, - {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4"}, - {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf"}, - {file = "scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b"}, - {file = "scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395"}, - {file = "scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1"}, - {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915"}, - {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b"}, - {file = "scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74"}, - {file = "scikit_learn-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956"}, - {file = "scikit_learn-1.5.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855"}, - {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1"}, - {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d"}, - {file = "scikit_learn-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d"}, - {file = "scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, ] [package.dependencies] @@ -8266,11 +8294,11 @@ threadpoolctl = ">=3.1.0" [package.extras] benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] -docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] maintenance = ["conda-lock (==2.5.6)"] -tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] [[package]] name = "scipy" @@ -8647,13 +8675,13 @@ doc = ["sphinx"] [[package]] name = "starlette" -version = "0.38.4" +version = "0.38.5" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.38.4-py3-none-any.whl", hash = "sha256:526f53a77f0e43b85f583438aee1a940fd84f8fd610353e8b0c1a77ad8a87e76"}, - {file = "starlette-0.38.4.tar.gz", hash = "sha256:53a7439060304a208fea17ed407e998f46da5e5d9b1addfea3040094512a6379"}, + {file = "starlette-0.38.5-py3-none-any.whl", hash = "sha256:632f420a9d13e3ee2a6f18f437b0a9f1faecb0bc42e1942aa2ea0e379a4c4206"}, + {file = "starlette-0.38.5.tar.gz", hash = "sha256:04a92830a9b6eb1442c766199d62260c3d4dc9c4f9188360626b1e0273cb7077"}, ] [package.dependencies] @@ -8750,13 +8778,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1226" +version = "3.0.1230" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1226.tar.gz", hash = "sha256:8e126cdce6adffce6fa5a3b464f0a6e483af7c7f78939883823393c2c5e8fc62"}, - {file = "tencentcloud_sdk_python_common-3.0.1226-py2.py3-none-any.whl", hash = "sha256:6165481280147afa226c6bb91df4cd0c43c5230f566be3d3f9c45a826b1105c5"}, + {file = "tencentcloud-sdk-python-common-3.0.1230.tar.gz", hash = "sha256:1e0f3bab80026fcb0083820869239b3f8cf30beb8e00e12c213bdecc75eb7577"}, + {file = "tencentcloud_sdk_python_common-3.0.1230-py2.py3-none-any.whl", hash = "sha256:03616c79685c154c689536a9c823d52b855cf49eada70679826a92aff5afd596"}, ] [package.dependencies] @@ -8764,17 +8792,17 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1226" +version = "3.0.1230" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1226.tar.gz", hash = "sha256:c9b9c3a373d967b691444bd590e3be1424aaab9f1ab30c57d98777113e2b7882"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1226-py2.py3-none-any.whl", hash = "sha256:87a1d63f85c25b5ec6c07f16d813091411ea6f296a1bf7fb608a529852b38bbe"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1230.tar.gz", hash = "sha256:900d15cb9dc2217b1282d985898ec7ecf97859351c86c6f7efc74685f08a5f85"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1230-py2.py3-none-any.whl", hash = "sha256:604dab0d4d66ea942f23d7980c76b5f0f6af3d68a8374e619331a4dd2910991e"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1226" +tencentcloud-sdk-python-common = "3.0.1230" [[package]] name = "threadpoolctl" @@ -9177,13 +9205,13 @@ typing-extensions = ">=3.7.4.3" [[package]] name = "types-requests" -version = "2.32.0.20240905" +version = "2.32.0.20240907" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.32.0.20240905.tar.gz", hash = "sha256:e97fd015a5ed982c9ddcd14cc4afba9d111e0e06b797c8f776d14602735e9bd6"}, - {file = "types_requests-2.32.0.20240905-py3-none-any.whl", hash = "sha256:f46ecb55f5e1a37a58be684cf3f013f166da27552732ef2469a0cc8e62a72881"}, + {file = "types-requests-2.32.0.20240907.tar.gz", hash = "sha256:ff33935f061b5e81ec87997e91050f7b4af4f82027a7a7a9d9aaea04a963fdf8"}, + {file = "types_requests-2.32.0.20240907-py3-none-any.whl", hash = "sha256:1d1e79faeaf9d42def77f3c304893dea17a97cae98168ac69f3cb465516ee8da"}, ] [package.dependencies] @@ -10388,4 +10416,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2dbff415c3c9ca95c8dcfb59fc088ce2c0d00037c44f386a34c87c98e1d8b942" +content-hash = "9173a56b2efea12804c980511e1465fba43c7a3d83b1ad284ee149851ed67fc5" diff --git a/api/pyproject.toml b/api/pyproject.toml index 69d1fc4ee0..166ddcec50 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,6 +6,9 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] +exclude=[ + "migrations/*", +] line-length = 120 [tool.ruff.lint] @@ -13,43 +16,61 @@ preview = true select = [ "B", # flake8-bugbear rules "C4", # flake8-comprehensions + "E", # pycodestyle E rules "F", # pyflakes rules + "FURB", # refurb rules "I", # isort rules - "UP", # pyupgrade rules - "B035", # static-key-dict-comprehension - "E101", # mixed-spaces-and-tabs - "E111", # indentation-with-invalid-multiple - "E112", # no-indented-block - "E113", # unexpected-indentation - "E115", # no-indented-block-comment - "E116", # unexpected-indentation-comment - "E117", # over-indented + "N", # pep8-naming + "PT", # flake8-pytest-style rules + "PLC0208", # iteration-over-set + "PLC2801", # unnecessary-dunder-call + "PLC0414", # useless-import-alias + "PLR0402", # manual-from-import + "PLR1711", # useless-return + "PLR1714", # repeated-equality-comparison + "PLR6201", # literal-membership "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa "S506", # unsafe-yaml-load - "SIM116", # if-else-block-instead-of-dict-lookup - "SIM401", # if-else-block-instead-of-dict-get - "SIM910", # dict-get-with-none-default + "SIM", # flake8-simplify rules + "UP", # pyupgrade rules "W191", # tab-indentation "W605", # invalid-escape-sequence - "F601", # multi-value-repeated-key-literal - "F602", # multi-value-repeated-key-variable ] ignore = [ + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "E731", # lambda-assignment "F403", # undefined-local-with-import-star "F405", # undefined-local-with-import-star-usage "F821", # undefined-name "F841", # unused-variable + "FURB113", # repeated-append + "FURB152", # math-constant "UP007", # non-pep604-annotation "UP032", # f-string "B005", # strip-with-multi-characters "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg -# "B901", # return-in-generator "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # eumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "SIM300", # yoda-conditions, ] [tool.ruff.lint.per-file-ignores] @@ -61,6 +82,12 @@ ignore = [ "F401", # unused-import "F811", # redefined-while-unused ] +"configs/*" = [ + "N802", # invalid-function-name +] +"libs/gmpy2_pkcs10aep_cipher.py" = [ + "N803", # invalid-argument-name +] "tests/*" = [ "F401", # unused-import "F811", # redefined-while-unused @@ -68,9 +95,6 @@ ignore = [ [tool.ruff.format] exclude = [ - "core/**/*.py", - "models/**/*.py", - "migrations/**/*", ] [tool.pytest_env] @@ -112,7 +136,7 @@ authlib = "1.3.1" azure-identity = "1.16.1" azure-storage-blob = "12.13.0" beautifulsoup4 = "4.12.2" -boto3 = "1.34.148" +boto3 = "1.35.17" sagemaker = "2.231.0" bs4 = "~0.0.1" cachetools = "~5.3.0" @@ -259,4 +283,4 @@ optional = true [tool.poetry.group.lint.dependencies] dotenv-linter = "~0.5.0" -ruff = "~0.6.1" +ruff = "~0.6.5" diff --git a/api/services/account_service.py b/api/services/account_service.py index e1b70fc9ed..66ff5d2b7c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -32,7 +32,7 @@ from services.errors.account import ( NoPermissionError, RateLimitExceededError, RoleAlreadyAssignedError, - TenantNotFound, + TenantNotFoundError, ) from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -47,7 +47,7 @@ class AccountService: if not account: return None - if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise Unauthorized("Account is banned or closed.") current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( @@ -92,7 +92,7 @@ class AccountService: if not account: raise AccountLoginError("Invalid email or password.") - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: @@ -311,13 +311,13 @@ class TenantService: """Get tenant by account and add the role""" tenant = account.current_tenant if not tenant: - raise TenantNotFound("Tenant not found.") + raise TenantNotFoundError("Tenant not found.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: tenant.role = ta.role else: - raise TenantNotFound("Tenant not found for the account.") + raise TenantNotFoundError("Tenant not found for the account.") return tenant @staticmethod @@ -427,7 +427,7 @@ class TenantService: "remove": [TenantAccountRole.OWNER], "update": [TenantAccountRole.OWNER], } - if action not in ["add", "remove", "update"]: + if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") if member: @@ -544,7 +544,7 @@ class RegisterService: """Register account""" try: account = AccountService.create_account( - email=email, name=name, interface_language=language if language else languages[0], password=password + email=email, name=name, interface_language=language or languages[0], password=password ) account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -614,8 +614,8 @@ class RegisterService: "email": account.email, "workspace_id": tenant.id, } - expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) + expiry_hours = dify_config.INVITE_EXPIRY_HOURS + redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) return token @classmethod diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 73c446b83b..54594e1175 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -81,18 +81,16 @@ class AppDslService: raise ValueError("Missing app in data argument") # get app basic info - name = args.get("name") if args.get("name") else app_data.get("name") - description = args.get("description") if args.get("description") else app_data.get("description", "") - icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") - icon = args.get("icon") if args.get("icon") else app_data.get("icon") - icon_background = ( - args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") - ) + name = args.get("name") or app_data.get("name") + description = args.get("description") or app_data.get("description", "") + icon_type = args.get("icon_type") or app_data.get("icon_type") + icon = args.get("icon") or app_data.get("icon") + icon_background = args.get("icon_background") or app_data.get("icon_background") use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) # import dsl and create app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -105,7 +103,7 @@ class AppDslService: icon_background=icon_background, use_icon_as_answer_icon=use_icon_as_answer_icon, ) - elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -145,7 +143,7 @@ class AppDslService: # import dsl and overwrite app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: raise ValueError("Only support import workflow in advanced-chat or workflow app.") if app_data.get("mode") != app_model.mode: @@ -179,7 +177,7 @@ class AppDslService: }, } - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: cls._append_workflow_export_data( export_data=export_data, app_model=app_model, include_secret=include_secret ) @@ -240,7 +238,7 @@ class AppDslService: :param use_icon_as_answer_icon: use app icon as answer icon """ if not workflow_data: - raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") app = cls._create_app( tenant_id=tenant_id, @@ -285,7 +283,7 @@ class AppDslService: :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -339,7 +337,7 @@ class AppDslService: :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion") + raise ValueError("Missing model_config in data argument when app mode is chat, agent-chat or completion") app = cls._create_app( tenant_id=tenant_id, diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 83be6317dc..b13ed9718d 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -73,12 +73,12 @@ class AppGenerateService: return rate_limit.generate( AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming, + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, ), ), request_id, @@ -124,7 +124,7 @@ class AppGenerateService: ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( + return AdvancedChatAppGenerator.convert_to_event_stream( WorkflowAppGenerator().single_iteration_generate( app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) diff --git a/api/services/app_service.py b/api/services/app_service.py index 1dacfea246..ac45d623e8 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -316,7 +316,7 @@ class AppService: meta = {"tool_icons": {}} - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 05cd1c96a1..7a0cd5725b 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") @@ -83,7 +83,7 @@ class AudioService: def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): with app.app_context(): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("TTS is not enabled") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 30e4ee57c0..afc491398f 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -37,7 +37,7 @@ class FirecrawlAuth(ApiKeyAuthBase): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cce0874cf4..30c010ef29 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -155,7 +155,7 @@ class DatasetService: dataset.tenant_id = tenant_id dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model if embedding_model else None - dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME + dataset.permission = permission or DatasetPermissionEnum.ONLY_ME db.session.add(dataset) db.session.commit() return dataset @@ -181,7 +181,7 @@ class DatasetService: "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") + raise ValueError(f"The dataset in unavailable, due to: {ex.description}") @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): @@ -195,10 +195,10 @@ class DatasetService: ) except LLMBadRequestError: raise ValueError( - "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") + raise ValueError(f"The dataset in unavailable, due to: {ex.description}") @staticmethod def update_dataset(dataset_id, data, user): @@ -544,7 +544,7 @@ class DocumentService: @staticmethod def pause_document(document): - if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]: + if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: raise DocumentIndexingError() # update document to be paused document.is_paused = True @@ -681,11 +681,7 @@ class DocumentService: "score_threshold_enabled": False, } - dataset.retrieval_model = ( - document_data.get("retrieval_model") - if document_data.get("retrieval_model") - else default_retrieval_model - ) + dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model documents = [] batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) diff --git a/api/services/errors/account.py b/api/services/errors/account.py index cae31c5066..82dd9f944a 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -1,7 +1,7 @@ from services.errors.base import BaseServiceError -class AccountNotFound(BaseServiceError): +class AccountNotFoundError(BaseServiceError): pass @@ -25,7 +25,7 @@ class LinkAccountIntegrateError(BaseServiceError): pass -class TenantNotFound(BaseServiceError): +class TenantNotFoundError(BaseServiceError): pass diff --git a/api/services/file_service.py b/api/services/file_service.py index 5780abb2be..bedec76334 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -56,9 +56,7 @@ class FileService: if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS ) - if extension.lower() not in allowed_extensions: - raise UnsupportedFileTypeError() - elif only_image and extension.lower() not in IMAGE_EXTENSIONS: + if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS: raise UnsupportedFileTypeError() # read file content diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 2f911f5036..3dafafd5b4 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -33,7 +33,7 @@ class HitTestingService: # get retrieval model , if the model is not setting , using default if not retrieval_model: - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model all_documents = RetrievalService.retrieve( retrieval_method=retrieval_model.get("search_method", "semantic_search"), @@ -42,13 +42,11 @@ class HitTestingService: top_k=retrieval_model.get("top_k", 2), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] - else None, + else 0.0, reranking_model=retrieval_model.get("reranking_model", None) if retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") - if retrieval_model.get("reranking_mode") - else "reranking_model", + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), ) diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index c0f3c40762..384a072b37 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,6 +1,7 @@ import logging import mimetypes import os +from pathlib import Path from typing import Optional, cast import requests @@ -453,9 +454,8 @@ class ModelProviderService: mimetype = mimetype or "application/octet-stream" # read binary from file - with open(file_path, "rb") as f: - byte_data = f.read() - return byte_data, mimetype + byte_data = Path(file_path).read_bytes() + return byte_data, mimetype def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None: """ diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 1e7935d299..1160a1f275 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -31,16 +31,28 @@ class OpsService: if tracing_provider == "langfuse" and ( "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") ): - project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) - new_decrypt_tracing_config.update( - {"project_url": "{host}/project/{key}".format(host=decrypt_tracing_config.get("host"), key=project_key)} - ) + try: + project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update( + { + "project_url": "{host}/project/{key}".format( + host=decrypt_tracing_config.get("host"), key=project_key + ) + } + ) + except Exception: + new_decrypt_tracing_config.update( + {"project_url": "{host}/".format(host=decrypt_tracing_config.get("host"))} + ) if tracing_provider == "langsmith" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") ): - project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) - new_decrypt_tracing_config.update({"project_url": project_url}) + try: + project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_url": project_url}) + except Exception: + new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"}) trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() @@ -54,7 +66,7 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys() and tracing_provider: + if tracing_provider not in provider_config_map and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} config_class, other_keys = ( @@ -113,7 +125,7 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys(): + if tracing_provider not in provider_config_map: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 10abf0a764..daec8393d0 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,6 +1,7 @@ import json import logging from os import path +from pathlib import Path from typing import Optional import requests @@ -218,10 +219,9 @@ class RecommendedAppService: return cls.builtin_data root_path = current_app.root_path - with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f: - json_data = f.read() - data = json.loads(json_data) - cls.builtin_data = data + cls.builtin_data = json.loads( + Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") + ) return cls.builtin_data diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 0c17485a9f..5e2851cd8f 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -57,7 +57,7 @@ class TagService: .all() ) - return tags if tags else [] + return tags or [] @staticmethod def save_tags(args: dict) -> Tag: diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 3e6e62a415..e1962305b9 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -180,7 +180,8 @@ class ApiToolManageService: get api tool provider remote schema """ headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", "Accept": "*/*", } diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index e202eeb15b..af1cdbacac 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,5 +1,6 @@ import json import logging +from pathlib import Path from configs import dify_config from core.helper.position_helper import is_filtered @@ -202,8 +203,7 @@ class BuiltinToolManageService: get tool provider icon and it's mimetype """ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, "rb") as f: - icon_bytes = f.read() + icon_bytes = Path(icon_path).read_bytes() return icon_bytes, mime_type diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 1e069c704b..552e2a0fbc 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -32,7 +32,7 @@ class ToolTransformService: if provider_type == ToolProviderType.BUILT_IN.value: return url_prefix + "builtin/" + provider_name + "/icon" - elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: + elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: if isinstance(icon, str): return json.loads(icon) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 9586e743d0..1544b39c23 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -265,7 +265,7 @@ class WorkflowToolManageService: .first() ) return cls._get_workflow_tool(db_tool) - + @classmethod def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None): """ @@ -276,11 +276,13 @@ class WorkflowToolManageService: if db_tool is None: raise ValueError("Tool not found") - workflow_app: App | None = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() + workflow_app: App | None = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") - + workflow = workflow_app.workflow if not workflow: raise ValueError("Workflow not found") @@ -324,7 +326,6 @@ class WorkflowToolManageService: return [ ToolTransformService.tool_to_user_tool( - tool=tool.get_tools(db_tool.tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + tool=tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ) ] diff --git a/api/services/website_service.py b/api/services/website_service.py index 6dff35d63f..fea605cf30 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -50,8 +50,8 @@ class WebsiteService: excludes = options.get("excludes").split(",") if options.get("excludes") else [] params = { "crawlerOptions": { - "includes": includes if includes else [], - "excludes": excludes if excludes else [], + "includes": includes or [], + "excludes": excludes or [], "generateImgAltText": True, "limit": options.get("limit", 1), "returnOnlyUrls": False, diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 4b845be2f4..db1a036e68 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -63,11 +63,11 @@ class WorkflowConverter: # create new app new_app = App() new_app.tenant_id = app_model.tenant_id - new_app.name = name if name else app_model.name + "(workflow)" + new_app.name = name or app_model.name + "(workflow)" new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value - new_app.icon_type = icon_type if icon_type else app_model.icon_type - new_app.icon = icon if icon else app_model.icon - new_app.icon_background = icon_background if icon_background else app_model.icon_background + new_app.icon_type = icon_type or app_model.icon_type + new_app.icon = icon or app_model.icon + new_app.icon_background = icon_background or app_model.icon_background new_app.enable_site = app_model.enable_site new_app.enable_api = app_model.enable_api new_app.api_rpm = app_model.api_rpm diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 357ffd41c1..0ff81f1f7e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -295,7 +295,7 @@ class WorkflowService: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: + if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 9ea4c99649..6dd755ab03 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logging.info( click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e0da5f9ed0..72c4674e0f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document from services.feature_service import FeatureService @@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 6e681bcf4f..cb38bc668d 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner.run([document]) end_at = time.perf_counter() logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 0a7568c385..f4c3dbd2e2 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 18bae14ffa..934eb7430c 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Document @@ -29,7 +29,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): try: indexing_runner = IndexingRunner() - if document.indexing_status in ["waiting", "parsing", "cleaning"]: + if document.indexing_status in {"waiting", "parsing", "cleaning"}: indexing_runner.run([document]) elif document.indexing_status == "splitting": indexing_runner.run_in_splitting_status(document) @@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): logging.info( click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index bc0684086f..402bd9c2c2 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,15 +1,13 @@ from collections.abc import Generator -import google.generativeai.types.content_types as content_types import google.generativeai.types.generation_types as generation_config_types -import google.generativeai.types.safety_types as safety_types import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm from google.ai.generativelanguage_v1beta.types import content as gag_content from google.generativeai import GenerativeModel from google.generativeai.client import _ClientManager, configure -from google.generativeai.types import GenerateContentResponse +from google.generativeai.types import GenerateContentResponse, content_types, safety_types from google.generativeai.types.generation_types import BaseGenerateContentResponse current_api_key = "" diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index b37b109eba..b9a721c803 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -51,7 +51,7 @@ class MockTEIClass: # } # } embeddings = [] - for idx, text in enumerate(texts): + for idx in range(len(texts)): embedding = [0.1] * 768 embeddings.append( { @@ -70,6 +70,7 @@ class MockTEIClass: }, } + @staticmethod def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: # Example response: # [ diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index d9cd7b046e..439f7d56e9 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -6,7 +6,6 @@ from time import time # import monkeypatch from typing import Any, Literal, Optional, Union -import openai.types.chat.completion_create_params as completion_create_params from openai import AzureOpenAI, OpenAI from openai._types import NOT_GIVEN, NotGiven from openai.resources.chat.completions import Completions @@ -18,6 +17,7 @@ from openai.types.chat import ( ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, + completion_create_params, ) from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice @@ -254,7 +254,7 @@ class MockChatClass: "gpt-3.5-turbo-16k-0613", ] azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index c27e89248f..14223668e0 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -112,7 +112,7 @@ class MockCompletionsClass: ] azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index 4138cdd40d..e27b9891f5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -22,7 +22,7 @@ class MockEmbeddingsClass: if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: @@ -40,7 +40,7 @@ class MockEmbeddingsClass: usage=Usage(prompt_tokens=2, total_tokens=2), ) - embeddings = "VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7" + embeddings = "VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7" # noqa: E501 data = [] for i, text in enumerate(input): diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 270a88e85f..4262d40f3e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -20,7 +20,7 @@ class MockModerationClass: if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index ef361e8613..a51dcab4be 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -20,7 +20,7 @@ class MockSpeech2TextClass: temperature: float | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> Transcription: - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 777737187e..299523f4f5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -42,7 +42,7 @@ class MockXinferenceClass: model_uid = url.split("/")[-1] or "" if not re.match( r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid - ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: + ) and model_uid not in {"generate", "chat", "embedding", "rerank"}: response.status_code = 404 response._content = b"{}" return response @@ -53,7 +53,7 @@ class MockXinferenceClass: response._content = b"{}" return response - if model_uid in ["generate", "chat"]: + if model_uid in {"generate", "chat"}: response.status_code = 200 response._content = b"""{ "model_type": "LLM", diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index 7db59fddef..fb5e03855d 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -20,7 +20,7 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock -@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("chat", "none")], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -45,7 +45,7 @@ def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference ) -@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("chat", "none")], indirect=True) def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -75,7 +75,7 @@ def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): assert response.usage.total_tokens > 0 -@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("chat", "none")], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -236,7 +236,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # assert response.message.tool_calls[0].function.name == 'get_current_weather' -@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("completion", "none")], indirect=True) def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -261,7 +261,7 @@ def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinf ) -@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("completion", "none")], indirect=True) def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -286,7 +286,7 @@ def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): assert response.usage.total_tokens > 0 -@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("completion", "none")], indirect=True) def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 4dfc530010..d3c1f3101c 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index 7d32f5ae66..50725415e4 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -9,7 +9,8 @@ def test_loading_subclass_from_source(): module = load_single_subclass_from_source( module_name="ChildClass", script_path=os.path.join(current_path, "child_class.py"), parent_type=ParentClass ) - assert module and module.__name__ == "ChildClass" + assert module + assert module.__name__ == "ChildClass" def test_load_import_module_from_source(): @@ -17,7 +18,8 @@ def test_load_import_module_from_source(): module = import_module_from_source( module_name="ChildClass", py_file_path=os.path.join(current_path, "child_class.py") ) - assert module and module.__name__ == "ChildClass" + assert module + assert module.__name__ == "ChildClass" def test_lazy_loading_subclass_from_source(): diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 571c1e3d44..53c9b3cae3 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,7 +13,7 @@ from xinference_client.types import Embedding class MockTcvectordbClass: - def VectorDBClient( + def mock_vector_db_client( self, url=None, username="", @@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index a99b81d41e..2666ce2e1e 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -34,7 +34,7 @@ class TestOpenSearchVector: self.vector._client = MagicMock() @pytest.mark.parametrize( - "search_response, expected_length, expected_doc_id", + ("search_response", "expected_length", "expected_doc_id"), [ ( { diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index cfc47bcad4..f1ab23b002 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index 44dcf9a10f..487178ff58 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -1,11 +1,11 @@ import pytest -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): - with pytest.raises(CodeExecutionException) as e: + with pytest.raises(CodeExecutionError) as e: CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index cbe4a5d335..25af312afa 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -1,4 +1,3 @@ -import json from textwrap import dedent from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index cbe9c5914f..88435c4022 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -411,5 +411,5 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): if latest_role is not None: assert latest_role != prompt.get("role") - if prompt.get("role") in ["user", "assistant"]: + if prompt.get("role") in {"user", "assistant"}: latest_role = prompt.get("role") diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index ca3082953a..621c995a4b 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -13,7 +13,7 @@ CACHED_APP = Flask(__name__) CACHED_APP.config.update({"TESTING": True}) -@pytest.fixture() +@pytest.fixture def app() -> Flask: return CACHED_APP diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 7cc339d212..73002623f0 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -21,9 +21,9 @@ def test_segment_group_to_text(): segments_group = parser.convert_template(template=template, variable_pool=variable_pool) assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key." - assert ( - segments_group.log - == f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}." + assert segments_group.log == ( + f"Hello, fake-user-id! Your query is fake-user-query." + f" And your key is {encrypter.obfuscated_token('fake-secret-key')}." ) diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 7a0bc70c63..d6e6b0b79c 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,6 +1,8 @@ import random from unittest.mock import MagicMock, patch +import pytest + from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request @@ -22,11 +24,9 @@ def test_retry_exceed_max_retries(mock_request): side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES mock_request.side_effect = side_effects - try: + with pytest.raises(Exception) as e: make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) - raise AssertionError("Expected Exception not raised") - except Exception as e: - assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" + assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" @patch("httpx.request") diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 24bbde6d4e..24b338601d 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -53,7 +53,7 @@ def test__get_completion_model_prompt_messages(): "#context#": context, "#histories#": "\n".join( [ - f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}" + f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: {prompt.content}" for prompt in history_prompt_messages ] ), diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py index 65757cd604..13ba11016a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -247,9 +247,9 @@ def test_parallels_graph(): for i in range(3): start_edges = graph.edge_mapping.get("start") assert start_edges is not None - assert start_edges[i].target_node_id == f"llm{i+1}" + assert start_edges[i].target_node_id == f"llm{i + 1}" - llm_edges = graph.edge_mapping.get(f"llm{i+1}") + llm_edges = graph.edge_mapping.get(f"llm{i + 1}") assert llm_edges is not None assert llm_edges[0].target_node_id == "answer" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index a2d71d61fc..197288adba 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -210,7 +210,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in {"llm2", "llm3", "end1", "end2"}: assert item.parallel_id is not None assert len(items) == 18 @@ -315,12 +315,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in { "answer2", "answer3", "answer4", "answer5", - ]: + }: assert item.parallel_id is not None assert len(items) == 23 diff --git a/api/tests/unit_tests/libs/test_email.py b/api/tests/unit_tests/libs/test_email.py index f8234f3f3b..ae0177791b 100644 --- a/api/tests/unit_tests/libs/test_email.py +++ b/api/tests/unit_tests/libs/test_email.py @@ -1,3 +1,5 @@ +import pytest + from libs.helper import email @@ -9,17 +11,11 @@ def test_email_with_valid_email(): def test_email_with_invalid_email(): - try: + with pytest.raises(ValueError, match="invalid_email is not a valid email."): email("invalid_email") - except ValueError as e: - assert str(e) == "invalid_email is not a valid email." - try: + with pytest.raises(ValueError, match="@example.com is not a valid email."): email("@example.com") - except ValueError as e: - assert str(e) == "@example.com is not a valid email." - try: + with pytest.raises(ValueError, match="()@example.com is not a valid email."): email("()@example.com") - except ValueError as e: - assert str(e) == "()@example.com is not a valid email." diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 7075a31f2b..f8c5700cd9 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.8.0 + image: langgenius/dify-api:0.8.2 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.8.0 + image: langgenius/dify-api:0.8.2 restart: always environment: CONSOLE_WEB_URL: '' @@ -396,7 +396,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.8.0 + image: langgenius/dify-web:0.8.2 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index ca24c667f6..7e4430a37d 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -75,7 +75,7 @@ INIT_PASSWORD= DEPLOY_ENV=PRODUCTION # Whether to enable the version check policy. -# If set to empty, https://updates.dify.ai will not be called for version check. +# If set to empty, https://updates.dify.ai will be called for version check. CHECK_UPDATE_URL=https://updates.dify.ai # Used to change the OpenAI base address, default is https://api.openai.com/v1. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index dbfc1ea531..00faa2960a 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -41,7 +41,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.7 + image: langgenius/dify-sandbox:0.2.9 restart: always environment: # The DifySandbox configurations diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 68f897ddb9..d080731a28 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -47,12 +47,12 @@ x-shared-env: &shared-api-worker-env REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} - REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-} + REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} - CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-} + CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} STORAGE_TYPE: ${STORAGE_TYPE:-local} @@ -208,7 +208,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.8.0 + image: langgenius/dify-api:0.8.2 restart: always environment: # Use the shared environment variables. @@ -228,7 +228,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.8.0 + image: langgenius/dify-api:0.8.2 restart: always environment: # Use the shared environment variables. @@ -247,7 +247,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.8.0 + image: langgenius/dify-web:0.8.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -292,7 +292,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.7 + image: langgenius/dify-sandbox:0.2.9 restart: always environment: # The DifySandbox configurations diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 2bcc74ec01..0558e29956 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -63,7 +63,6 @@ const AppPublisher = ({ const [published, setPublished] = useState(false) const [open, setOpen] = useState(false) const appDetail = useAppStore(state => state.appDetail) - const [publishedTime, setPublishedTime] = useState(publishedAt) const { app_base_url: appBaseURL = '', access_token: accessToken = '' } = appDetail?.site ?? {} const appMode = (appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow') ? 'chat' : appDetail.mode const appURL = `${appBaseURL}/${appMode}/${accessToken}` @@ -77,7 +76,6 @@ const AppPublisher = ({ try { await onPublish?.(modelAndParameter) setPublished(true) - setPublishedTime(Date.now()) } catch (e) { setPublished(false) @@ -133,13 +131,13 @@ const AppPublisher = ({
- {publishedTime ? t('workflow.common.latestPublished') : t('workflow.common.currentDraftUnpublished')} + {publishedAt ? t('workflow.common.latestPublished') : t('workflow.common.currentDraftUnpublished')}
- {publishedTime + {publishedAt ? (
- {t('workflow.common.publishedAt')} {formatTimeFromNow(publishedTime)} + {t('workflow.common.publishedAt')} {formatTimeFromNow(publishedAt)}
- }>{t('workflow.common.runApp')} + }>{t('workflow.common.runApp')} {appDetail?.mode === 'workflow' ? ( } > @@ -201,16 +199,16 @@ const AppPublisher = ({ setEmbeddingModalOpen(true) handleTrigger() }} - disabled={!publishedTime} + disabled={!publishedAt} icon={} > {t('workflow.common.embedIntoSite')} )} - }>{t('workflow.common.accessAPIReference')} + }>{t('workflow.common.accessAPIReference')} {appDetail?.mode === 'workflow' && ( = ({ questionIcon={} allToolIcons={allToolIcons} hideLogModal + noSpacing /> ) } diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx index 892d0cfe8b..3d2f3bca59 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx @@ -128,6 +128,7 @@ const DebugWithMultipleModel = () => { onSend={handleSend} speechToTextConfig={speechToTextConfig} visionConfig={visionConfig} + noSpacing />
) diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx index 0f09a23242..d93ad00659 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx @@ -130,6 +130,7 @@ const DebugWithSingleModel = forwardRef ) }) diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index c4578fab62..fdb09dc3ae 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ -32,18 +32,21 @@ import { useDraggableUploader, useImageFiles, } from '@/app/components/base/image-uploader/hooks' +import cn from '@/utils/classnames' type ChatInputProps = { visionConfig?: VisionConfig speechToTextConfig?: EnableType onSend?: OnSend theme?: Theme | null + noSpacing?: boolean } const ChatInput: FC = ({ visionConfig, speechToTextConfig, onSend, theme, + noSpacing, }) => { const { appData } = useChatWithHistoryContext() const { t } = useTranslation() @@ -146,7 +149,7 @@ const ChatInput: FC = ({ return ( <> -
+
= ({ { visionConfig?.enabled && ( <> -
+
= ({ onDrop={onDrop} autoSize /> -
+
{query.trim().length}
diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 7540cd873b..65e49eff67 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -60,6 +60,7 @@ export type ChatProps = { hideProcessDetail?: boolean hideLogModal?: boolean themeBuilder?: ThemeBuilder + noSpacing?: boolean } const Chat: FC = ({ @@ -89,6 +90,7 @@ const Chat: FC = ({ hideProcessDetail, hideLogModal, themeBuilder, + noSpacing, }) => { const { t } = useTranslation() const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({ @@ -197,7 +199,7 @@ const Chat: FC = ({ {chatNode}
{ chatList.map((item, index) => { @@ -268,6 +270,7 @@ const Chat: FC = ({ speechToTextConfig={config?.speech_to_text} onSend={onSend} theme={themeBuilder?.theme} + noSpacing={noSpacing} /> ) } diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index 11bcd84e18..d4e7dac4ae 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -198,11 +198,11 @@ const Paragraph = (paragraph: any) => { return ( <> -
{paragraph.children.slice(1)}
+

{paragraph.children.slice(1)}

) } - return
{paragraph.children}
+ return

{paragraph.children}

} const Img = ({ src }: any) => { diff --git a/web/app/components/develop/secret-key/secret-key-modal.tsx b/web/app/components/develop/secret-key/secret-key-modal.tsx index fd28a67e7e..dbb5cc37c7 100644 --- a/web/app/components/develop/secret-key/secret-key-modal.tsx +++ b/web/app/components/develop/secret-key/secret-key-modal.tsx @@ -41,7 +41,7 @@ const SecretKeyModal = ({ }: ISecretKeyModalProps) => { const { t } = useTranslation() const { formatTime } = useTimestamp() - const { currentWorkspace, isCurrentWorkspaceManager } = useAppContext() + const { currentWorkspace, isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() const [showConfirmDelete, setShowConfirmDelete] = useState(false) const [isVisible, setVisible] = useState(false) const [newKey, setNewKey] = useState(undefined) @@ -142,7 +142,7 @@ const SecretKeyModal = ({ ) }
- diff --git a/web/app/components/develop/template/template_workflow.en.mdx b/web/app/components/develop/template/template_workflow.en.mdx index 495b051bd0..2bd0fe9daf 100644 --- a/web/app/components/develop/template/template_workflow.en.mdx +++ b/web/app/components/develop/template/template_workflow.en.mdx @@ -413,3 +413,109 @@ Workflow applications offers non-session support and is ideal for translation, a + +--- + + + + + Returns worklfow logs, with the first page returning the latest `{limit}` messages, i.e., in reverse order. + + ### Query + + + + Keyword to search + + + succeeded/failed/stopped + + + current page, default is 1. + + + How many chat history messages to return in one request, default is 20. + + + + ### Response + - `page` (int) Current page + - `limit` (int) Number of returned items, if input exceeds system limit, returns system limit amount + - `total` (int) Number of total items + - `has_more` (bool) Whether there is a next page + - `data` (array[object]) Log list + - `id` (string) ID + - `workflow_run` (object) Workflow run + - `id` (string) ID + - `version` (string) Version + - `status` (string) status of execution, `running` / `succeeded` / `failed` / `stopped` + - `error` (string) Optional reason of error + - `elapsed_time` (float) total seconds to be used + - `total_tokens` (int) tokens to be used + - `total_steps` (int) default 0 + - `created_at` (timestamp) start time + - `finished_at` (timestamp) end time + - `created_from` (string) Created from + - `created_by_role` (string) Created by role + - `created_by_account` (string) Optional Created by account + - `created_by_end_user` (object) Created by end user + - `id` (string) ID + - `type` (string) Type + - `is_anonymous` (bool) Is anonymous + - `session_id` (string) Session ID + - `created_at` (timestamp) create time + + + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/workflows/logs?limit=1' + --header 'Authorization: Bearer {api_key}' + ``` + + + ### Response Example + + ```json {{ title: 'Response' }} + { + "page": 1, + "limit": 1, + "total": 7, + "has_more": true, + "data": [ + { + "id": "e41b93f1-7ca2-40fd-b3a8-999aeb499cc0", + "workflow_run": { + "id": "c0640fc8-03ef-4481-a96c-8a13b732a36e", + "version": "2024-08-01 12:17:09.771832", + "status": "succeeded", + "error": null, + "elapsed_time": 1.3588523610014818, + "total_tokens": 0, + "total_steps": 3, + "created_at": 1726139643, + "finished_at": 1726139644 + }, + "created_from": "service-api", + "created_by_role": "end_user", + "created_by_account": null, + "created_by_end_user": { + "id": "7f7d9117-dd9d-441d-8970-87e5e7e687a3", + "type": "service_api", + "is_anonymous": false, + "session_id": "abc-123" + }, + "created_at": 1726139644 + } + ] + } + ``` + + + diff --git a/web/app/components/develop/template/template_workflow.zh.mdx b/web/app/components/develop/template/template_workflow.zh.mdx index 640a4b3f92..d7d672fbd0 100644 --- a/web/app/components/develop/template/template_workflow.zh.mdx +++ b/web/app/components/develop/template/template_workflow.zh.mdx @@ -409,3 +409,109 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 + +--- + + + + + 倒序返回workflow日志 + + ### Query + + + + 关键字 + + + 执行状态 succeeded/failed/stopped + + + 当前页码, 默认1. + + + 每页条数, 默认20. + + + + ### Response + - `page` (int) 当前页码 + - `limit` (int) 每页条数 + - `total` (int) 总条数 + - `has_more` (bool) 是否还有更多数据 + - `data` (array[object]) 当前页码的数据 + - `id` (string) 标识 + - `workflow_run` (object) Workflow 执行日志 + - `id` (string) 标识 + - `version` (string) 版本 + - `status` (string) 执行状态, `running` / `succeeded` / `failed` / `stopped` + - `error` (string) (可选) 错误 + - `elapsed_time` (float) 耗时,单位秒 + - `total_tokens` (int) 消耗的token数量 + - `total_steps` (int) 执行步骤长度 + - `created_at` (timestamp) 开始时间 + - `finished_at` (timestamp) 结束时间 + - `created_from` (string) 来源 + - `created_by_role` (string) 角色 + - `created_by_account` (string) (可选) 帐号 + - `created_by_end_user` (object) 用户 + - `id` (string) 标识 + - `type` (string) 类型 + - `is_anonymous` (bool) 是否匿名 + - `session_id` (string) 会话标识 + - `created_at` (timestamp) 创建时间 + + + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/workflows/logs?limit=1' + --header 'Authorization: Bearer {api_key}' + ``` + + + ### Response Example + + ```json {{ title: 'Response' }} + { + "page": 1, + "limit": 1, + "total": 7, + "has_more": true, + "data": [ + { + "id": "e41b93f1-7ca2-40fd-b3a8-999aeb499cc0", + "workflow_run": { + "id": "c0640fc8-03ef-4481-a96c-8a13b732a36e", + "version": "2024-08-01 12:17:09.771832", + "status": "succeeded", + "error": null, + "elapsed_time": 1.3588523610014818, + "total_tokens": 0, + "total_steps": 3, + "created_at": 1726139643, + "finished_at": 1726139644 + }, + "created_from": "service-api", + "created_by_role": "end_user", + "created_by_account": null, + "created_by_end_user": { + "id": "7f7d9117-dd9d-441d-8970-87e5e7e687a3", + "type": "service_api", + "is_anonymous": false, + "session_id": "abc-123" + }, + "created_at": 1726139644 + } + ] + } + ``` + + + diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index e0b28a6d73..c0a7be68a6 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -70,15 +70,15 @@ const Form: FC = ({ const renderField = (formSchema: CredentialFormSchema) => { const tooltip = formSchema.tooltip const tooltipContent = (tooltip && ( - - - {tooltip[language] || tooltip.en_US} -
} - triggerClassName='w-4 h-4' - /> - )) + + {tooltip[language] || tooltip.en_US} +
} + triggerClassName='ml-1 w-4 h-4' + asChild={false} + /> + )) if (formSchema.type === FormTypeEnum.textInput || formSchema.type === FormTypeEnum.secretInput || formSchema.type === FormTypeEnum.textNumber) { const { variable, diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx index 86857f1ab2..1c318b9baf 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx @@ -192,12 +192,12 @@ const ModelLoadBalancingEntryModal: FC = ({ }) const getSecretValues = useCallback((v: FormValue) => { return secretFormSchemas.reduce((prev, next) => { - if (v[next.variable] === initialFormSchemasValue[next.variable]) + if (isEditMode && v[next.variable] && v[next.variable] === initialFormSchemasValue[next.variable]) prev[next.variable] = '[__HIDDEN__]' return prev }, {} as Record) - }, [initialFormSchemasValue, secretFormSchemas]) + }, [initialFormSchemasValue, isEditMode, secretFormSchemas]) // const handleValueChange = ({ __model_type, __model_name, ...v }: FormValue) => { const handleValueChange = (v: FormValue) => { @@ -214,6 +214,7 @@ const ModelLoadBalancingEntryModal: FC = ({ ...value, ...getSecretValues(value), }, + entry?.id, ) if (res.status === ValidatedStatus.Success) { // notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) diff --git a/web/app/components/header/account-setting/model-provider-page/utils.ts b/web/app/components/header/account-setting/model-provider-page/utils.ts index 8cad399763..165926b2bb 100644 --- a/web/app/components/header/account-setting/model-provider-page/utils.ts +++ b/web/app/components/header/account-setting/model-provider-page/utils.ts @@ -56,14 +56,14 @@ export const validateCredentials = async (predefined: boolean, provider: string, } } -export const validateLoadBalancingCredentials = async (predefined: boolean, provider: string, v: FormValue): Promise<{ +export const validateLoadBalancingCredentials = async (predefined: boolean, provider: string, v: FormValue, id?: string): Promise<{ status: ValidatedStatus message?: string }> => { const { __model_name, __model_type, ...credentials } = v try { const res = await validateModelLoadBalancingCredentials({ - url: `/workspaces/current/model-providers/${provider}/models/load-balancing-configs/credentials-validate`, + url: `/workspaces/current/model-providers/${provider}/models/load-balancing-configs/${id ? `${id}/` : ''}credentials-validate`, body: { model: __model_name, model_type: __model_type, diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 96fe9f01ef..2d5546f9b4 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -219,7 +219,7 @@ const Result: FC = ({ })) }, onIterationNext: () => { - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.expand = true const iterations = draft.tracing.find(item => item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! diff --git a/web/app/components/workflow/hooks/use-shortcuts.ts b/web/app/components/workflow/hooks/use-shortcuts.ts index 666c3a45ba..8b1003e89c 100644 --- a/web/app/components/workflow/hooks/use-shortcuts.ts +++ b/web/app/components/workflow/hooks/use-shortcuts.ts @@ -70,14 +70,16 @@ export const useShortcuts = (): void => { }) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.c`, (e) => { - if (shouldHandleShortcut(e)) { + const { showDebugAndPreviewPanel } = workflowStore.getState() + if (shouldHandleShortcut(e) && !showDebugAndPreviewPanel) { e.preventDefault() handleNodesCopy() } }, { exactMatch: true, useCapture: true }) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.v`, (e) => { - if (shouldHandleShortcut(e)) { + const { showDebugAndPreviewPanel } = workflowStore.getState() + if (shouldHandleShortcut(e) && !showDebugAndPreviewPanel) { e.preventDefault() handleNodesPaste() } @@ -98,7 +100,8 @@ export const useShortcuts = (): void => { }, { exactMatch: true, useCapture: true }) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.z`, (e) => { - if (shouldHandleShortcut(e)) { + const { showDebugAndPreviewPanel } = workflowStore.getState() + if (shouldHandleShortcut(e) && !showDebugAndPreviewPanel) { e.preventDefault() workflowHistoryShortcutsEnabled && handleHistoryBack() } diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 460e36ae60..b201b28b88 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -283,15 +283,12 @@ export const useWorkflow = () => { return isUsed }, [isVarUsedInNodes]) - const checkParallelLimit = useCallback((nodeId: string) => { + const checkParallelLimit = useCallback((nodeId: string, nodeHandle = 'source') => { const { - getNodes, edges, } = store.getState() - const nodes = getNodes() - const currentNode = nodes.find(node => node.id === nodeId)! - const sourceNodeOutgoers = getOutgoers(currentNode, nodes, edges) - if (sourceNodeOutgoers.length > PARALLEL_LIMIT - 1) { + const connectedEdges = edges.filter(edge => edge.source === nodeId && edge.sourceHandle === nodeHandle) + if (connectedEdges.length > PARALLEL_LIMIT - 1) { const { setShowTips } = workflowStore.getState() setShowTips(t('workflow.common.parallelTip.limit', { num: PARALLEL_LIMIT })) return false @@ -322,7 +319,7 @@ export const useWorkflow = () => { return true }, [t, workflowStore]) - const isValidConnection = useCallback(({ source, target }: Connection) => { + const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { edges, getNodes, @@ -331,7 +328,7 @@ export const useWorkflow = () => { const sourceNode: Node = nodes.find(node => node.id === source)! const targetNode: Node = nodes.find(node => node.id === target)! - if (!checkParallelLimit(source!)) + if (!checkParallelLimit(source!, sourceHandle || 'source')) return false if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx index 6e3988eecb..75694983cd 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx @@ -1,6 +1,7 @@ import { memo, useCallback, + useState, } from 'react' import { useTranslation } from 'react-i18next' import { @@ -10,6 +11,7 @@ import { useAvailableBlocks, useNodesInteractions, useNodesReadOnly, + useWorkflow, } from '@/app/components/workflow/hooks' import BlockSelector from '@/app/components/workflow/block-selector' import type { @@ -30,9 +32,11 @@ const Add = ({ isParallel, }: AddProps) => { const { t } = useTranslation() + const [open, setOpen] = useState(false) const { handleNodeAdd } = useNodesInteractions() const { nodesReadOnly } = useNodesReadOnly() const { availableNextBlocks } = useAvailableBlocks(nodeData.type, nodeData.isInIteration) + const { checkParallelLimit } = useWorkflow() const handleSelect = useCallback((type, toolDefaultValue) => { handleNodeAdd( @@ -47,6 +51,13 @@ const Add = ({ ) }, [nodeId, sourceHandle, handleNodeAdd]) + const handleOpenChange = useCallback((newOpen: boolean) => { + if (newOpen && !checkParallelLimit(nodeId, sourceHandle)) + return + + setOpen(newOpen) + }, [checkParallelLimit, nodeId, sourceHandle]) + const renderTrigger = useCallback((open: boolean) => { return (
{ e.stopPropagation() - if (checkParallelLimit(id)) + if (checkParallelLimit(id, handleId)) setOpen(v => !v) - }, [checkParallelLimit, id]) + }, [checkParallelLimit, id, handleId]) const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => { handleNodeAdd( { diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx index e2b1f0a31c..7fb4ad68d8 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx @@ -8,6 +8,7 @@ import { } from '@remixicon/react' import produce from 'immer' import { useStoreApi } from 'reactflow' +import useAvailableVarList from '../../hooks/use-available-var-list' import VarReferencePopup from './var-reference-popup' import { getNodeInfoById, isConversationVar, isENV, isSystemVar } from './utils' import ConstantField from './constant-field' @@ -26,7 +27,6 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { useIsChatMode, - useWorkflow, useWorkflowVariables, } from '@/app/components/workflow/hooks' import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' @@ -67,7 +67,7 @@ const VarReferencePicker: FC = ({ onlyLeafNodeVar, filterVar = () => true, availableNodes: passedInAvailableNodes, - availableVars, + availableVars: passedInAvailableVars, isAddBtnTrigger, schema, valueTypePlaceHolder, @@ -79,11 +79,12 @@ const VarReferencePicker: FC = ({ } = store.getState() const isChatMode = useIsChatMode() - const { getTreeLeafNodes, getBeforeNodesInSameBranch } = useWorkflow() - const { getCurrentVariableType, getNodeAvailableVars } = useWorkflowVariables() - const availableNodes = useMemo(() => { - return passedInAvailableNodes || (onlyLeafNodeVar ? getTreeLeafNodes(nodeId) : getBeforeNodesInSameBranch(nodeId)) - }, [getBeforeNodesInSameBranch, getTreeLeafNodes, nodeId, onlyLeafNodeVar, passedInAvailableNodes]) + const { getCurrentVariableType } = useWorkflowVariables() + const { availableNodes, availableVars } = useAvailableVarList(nodeId, { + onlyLeafNodeVar, + passedInAvailableNodes, + filterVar, + }) const startNode = availableNodes.find((node: any) => { return node.data.type === BlockEnum.Start }) @@ -102,19 +103,8 @@ const VarReferencePicker: FC = ({ const [varKindType, setVarKindType] = useState(defaultVarKindType) const isConstant = isSupportConstantValue && varKindType === VarKindType.constant - const outputVars = useMemo(() => { - if (availableVars) - return availableVars - const vars = getNodeAvailableVars({ - parentNode: iterationNode, - beforeNodes: availableNodes, - isChatMode, - filterVar, - }) - - return vars - }, [iterationNode, availableNodes, isChatMode, filterVar, availableVars, getNodeAvailableVars]) + const outputVars = useMemo(() => (passedInAvailableVars || availableVars), [passedInAvailableVars, availableVars]) const [open, setOpen] = useState(false) useEffect(() => { diff --git a/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts b/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts index b81feab805..bd17bb1de8 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts @@ -4,12 +4,13 @@ import { useWorkflow, useWorkflowVariables, } from '@/app/components/workflow/hooks' -import type { ValueSelector, Var } from '@/app/components/workflow/types' +import type { Node, ValueSelector, Var } from '@/app/components/workflow/types' type Params = { onlyLeafNodeVar?: boolean hideEnv?: boolean hideChatVar?: boolean filterVar: (payload: Var, selector: ValueSelector) => boolean + passedInAvailableNodes?: Node[] } const useAvailableVarList = (nodeId: string, { @@ -17,6 +18,7 @@ const useAvailableVarList = (nodeId: string, { filterVar, hideEnv, hideChatVar, + passedInAvailableNodes, }: Params = { onlyLeafNodeVar: false, filterVar: () => true, @@ -25,7 +27,7 @@ const useAvailableVarList = (nodeId: string, { const { getNodeAvailableVars } = useWorkflowVariables() const isChatMode = useIsChatMode() - const availableNodes = onlyLeafNodeVar ? getTreeLeafNodes(nodeId) : getBeforeNodesInSameBranch(nodeId) + const availableNodes = passedInAvailableNodes || (onlyLeafNodeVar ? getTreeLeafNodes(nodeId) : getBeforeNodesInSameBranch(nodeId)) const { parentNode: iterationNode, diff --git a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx index e7639ddc79..a7dd607e22 100644 --- a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx @@ -118,6 +118,7 @@ const ChatWrapper = forwardRef(({ showConv } )} + noSpacing suggestedQuestions={suggestedQuestions} showPromptLog chatAnswerContainerInner='!pr-2' diff --git a/web/app/components/workflow/run/meta.tsx b/web/app/components/workflow/run/meta.tsx index 86eb221ad9..b2d7269a51 100644 --- a/web/app/components/workflow/run/meta.tsx +++ b/web/app/components/workflow/run/meta.tsx @@ -16,7 +16,7 @@ type Props = { const MetaData: FC = ({ status, executor, - startTime = 0, + startTime, time, tokens, steps = 1, @@ -64,7 +64,7 @@ const MetaData: FC = ({
)} {status !== 'running' && ( - {formatTime(startTime, t('appLog.dateTimeFormat') as string)} + {startTime ? formatTime(startTime, t('appLog.dateTimeFormat') as string) : '-'} )}
@@ -75,7 +75,7 @@ const MetaData: FC = ({
)} {status !== 'running' && ( - {`${time?.toFixed(3)}s`} + {time ? `${time.toFixed(3)}s` : '-'} )}
diff --git a/web/app/components/workflow/run/status.tsx b/web/app/components/workflow/run/status.tsx index 2eeafca95d..0677e43401 100644 --- a/web/app/components/workflow/run/status.tsx +++ b/web/app/components/workflow/run/status.tsx @@ -71,7 +71,7 @@ const StatusPanel: FC = ({
)} {status !== 'running' && ( - {`${time?.toFixed(3)}s`} + {time ? `${time?.toFixed(3)}s` : '-'} )}
diff --git a/web/app/styles/markdown.scss b/web/app/styles/markdown.scss index b0a1f60cd2..214d8d2782 100644 --- a/web/app/styles/markdown.scss +++ b/web/app/styles/markdown.scss @@ -321,18 +321,12 @@ .markdown-body h4, .markdown-body h5, .markdown-body h6 { - margin-top: 24px; - margin-bottom: 16px; + padding-top: 12px; + margin-bottom: 12px; font-weight: var(--base-text-weight-semibold, 600); line-height: 1.25; } - -.markdown-body p { - margin-top: 0; - margin-bottom: 10px; -} - .markdown-body blockquote { margin: 0; padding: 0 8px; @@ -449,7 +443,7 @@ .markdown-body pre, .markdown-body details { margin-top: 0; - margin-bottom: 16px; + margin-bottom: 12px; } .markdown-body blockquote> :first-child { diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index 2762c01a8d..5e154aeca5 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -77,6 +77,22 @@ const translation = { importDSLTip: 'Der aktuelle Entwurf wird überschrieben. Exportieren Sie den Workflow vor dem Import als Backup.', overwriteAndImport: 'Überschreiben und Importieren', backupCurrentDraft: 'Aktuellen Entwurf sichern', + parallelTip: { + click: { + title: 'Klicken', + desc: 'hinzuzufügen', + }, + drag: { + title: 'Ziehen', + desc: 'um eine Verbindung herzustellen', + }, + limit: 'Die Parallelität ist auf {{num}} Zweige beschränkt.', + depthLimit: 'Begrenzung der parallelen Verschachtelungsschicht von {{num}} Schichten', + }, + parallelRun: 'Paralleler Lauf', + disconnect: 'Trennen', + jumpToNode: 'Zu diesem Knoten springen', + addParallelNode: 'Parallelen Knoten hinzufügen', }, env: { envPanelTitle: 'Umgebungsvariablen', diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index 7db2b48687..0efc996f91 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Sobrescribir e importar', importFailure: 'Error al importar', importSuccess: 'Importación exitosa', + parallelTip: { + click: { + title: 'Clic', + desc: 'Para agregar', + }, + drag: { + title: 'Arrastrar', + desc: 'Para conectarse', + }, + limit: 'El paralelismo se limita a {{num}} ramas.', + depthLimit: 'Límite de capa de anidamiento paralelo de capas {{num}}', + }, + parallelRun: 'Ejecución paralela', + disconnect: 'Desconectar', + jumpToNode: 'Saltar a este nodo', + addParallelNode: 'Agregar nodo paralelo', }, env: { envPanelTitle: 'Variables de Entorno', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index da2d439a36..67020f5025 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'بازنویسی و وارد کردن', importFailure: 'خطا در وارد کردن', importSuccess: 'وارد کردن موفقیت‌آمیز', + parallelTip: { + click: { + title: 'کلیک کنید', + desc: 'اضافه کردن', + }, + drag: { + desc: 'برای اتصال', + title: 'کشیدن', + }, + depthLimit: 'حد لایه تودرتو موازی لایه های {{num}}', + limit: 'موازی سازی به شاخه های {{num}} محدود می شود.', + }, + disconnect: 'قطع', + jumpToNode: 'پرش به این گره', + parallelRun: 'اجرای موازی', + addParallelNode: 'افزودن گره موازی', }, env: { envPanelTitle: 'متغیرهای محیطی', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index 8691f6f0d5..3e56246e0c 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Écraser et importer', importFailure: 'Echec de l\'importation', importSuccess: 'Import avec succès', + parallelTip: { + click: { + title: 'Cliquer', + desc: 'à ajouter', + }, + drag: { + title: 'Traîner', + desc: 'pour se connecter', + }, + limit: 'Le parallélisme est limité aux branches {{num}}.', + depthLimit: 'Limite de couches d’imbrication parallèle de {{num}} couches', + }, + parallelRun: 'Exécution parallèle', + disconnect: 'Déconnecter', + jumpToNode: 'Aller à ce nœud', + addParallelNode: 'Ajouter un nœud parallèle', }, env: { envPanelTitle: 'Variables d\'Environnement', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index 96c232d72f..072e4874e3 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -80,6 +80,22 @@ const translation = { backupCurrentDraft: 'बैकअप वर्तमान ड्राफ्ट', importFailure: 'आयात विफलता', importDSLTip: 'वर्तमान ड्राफ्ट ओवरराइट हो जाएगा। आयात करने से पहले वर्कफ़्लो को बैकअप के रूप में निर्यात करें.', + parallelTip: { + click: { + title: 'क्लिक करना', + desc: 'जोड़ने के लिए', + }, + drag: { + title: 'खींचना', + desc: 'कनेक्ट करने के लिए', + }, + limit: 'समांतरता {{num}} शाखाओं तक सीमित है।', + depthLimit: '{{num}} परतों की समानांतर नेस्टिंग परत सीमा', + }, + disconnect: 'अलग करना', + parallelRun: 'समानांतर रन', + jumpToNode: 'इस नोड पर जाएं', + addParallelNode: 'समानांतर नोड जोड़ें', }, env: { envPanelTitle: 'पर्यावरण चर', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index 62ce0c5677..f5d6fc8bf5 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -81,6 +81,22 @@ const translation = { overwriteAndImport: 'Sovrascrivi e Importa', importFailure: 'Importazione fallita', importSuccess: 'Importazione riuscita', + parallelTip: { + click: { + title: 'Clic', + desc: 'per aggiungere', + }, + drag: { + title: 'Trascinare', + desc: 'per collegare', + }, + depthLimit: 'Limite di livelli di annidamento parallelo di {{num}} livelli', + limit: 'Il parallelismo è limitato ai rami {{num}}.', + }, + parallelRun: 'Corsa parallela', + disconnect: 'Disconnettere', + jumpToNode: 'Vai a questo nodo', + addParallelNode: 'Aggiungi nodo parallelo', }, env: { envPanelTitle: 'Variabili d\'Ambiente', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index d244e36bcd..755061e8f6 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'オーバライトとインポート', importFailure: 'インポート失敗', importSuccess: 'インポート成功', + parallelTip: { + click: { + title: 'クリック', + desc: '追加する', + }, + drag: { + title: 'ドラッグ', + desc: '接続するには', + }, + limit: '並列処理は {{num}} ブランチに制限されています。', + depthLimit: '{{num}}レイヤーの平行ネストレイヤーの制限', + }, + parallelRun: 'パラレルラン', + disconnect: '切る', + jumpToNode: 'このノードにジャンプします', + addParallelNode: '並列ノードを追加', }, env: { envPanelTitle: '環境変数', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index 7c85d02c3d..8fed0e0417 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -77,6 +77,22 @@ const translation = { importSuccess: '가져오기 성공', syncingData: '단 몇 초 만에 데이터를 동기화할 수 있습니다.', importDSLTip: '현재 초안을 덮어씁니다. 가져오기 전에 워크플로를 백업으로 내보냅니다.', + parallelTip: { + click: { + title: '클릭', + desc: '추가', + }, + drag: { + title: '드래그', + desc: '연결 방법', + }, + depthLimit: '평행 중첩 레이어 {{num}}개 레이어의 제한', + limit: '병렬 처리는 {{num}}개의 분기로 제한됩니다.', + }, + parallelRun: '병렬 실행', + disconnect: '분리하다', + jumpToNode: '이 노드로 이동', + addParallelNode: '병렬 노드 추가', }, env: { envPanelTitle: '환경 변수', diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index 09b9cea1be..de05ee7169 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -77,6 +77,22 @@ const translation = { chooseDSL: 'Wybierz plik DSL(yml)', backupCurrentDraft: 'Utwórz kopię zapasową bieżącej wersji roboczej', importFailure: 'Niepowodzenie importu', + parallelTip: { + click: { + title: 'Klikać', + desc: ', aby dodać', + }, + drag: { + title: 'Przeciągnąć', + desc: 'aby się połączyć', + }, + limit: 'Równoległość jest ograniczona do gałęzi {{num}}.', + depthLimit: 'Limit warstw zagnieżdżania równoległego dla warstw {{num}}', + }, + parallelRun: 'Bieg równoległy', + jumpToNode: 'Przejdź do tego węzła', + disconnect: 'Odłączyć', + addParallelNode: 'Dodaj węzeł równoległy', }, env: { envPanelTitle: 'Zmienne Środowiskowe', diff --git a/web/i18n/pt-BR/billing.ts b/web/i18n/pt-BR/billing.ts index 6eaeaa0fb2..0a7a964376 100644 --- a/web/i18n/pt-BR/billing.ts +++ b/web/i18n/pt-BR/billing.ts @@ -17,7 +17,7 @@ const translation = { }, month: 'mês', year: 'ano', - save: 'Salvar', + save: 'Economize ', free: 'Grátis', currentPlan: 'Plano Atual', contractOwner: 'Entre em contato com o gerente da equipe', @@ -57,7 +57,7 @@ const translation = { workflow: 'Fluxo de trabalho', llmLoadingBalancing: 'Balanceamento de carga LLM', bulkUpload: 'Upload em massa de documentos', - llmLoadingBalancingTooltip: 'Adicione várias chaves de API aos modelos, efetivamente ignorando os limites de taxa da API.', + llmLoadingBalancingTooltip: 'Adicione várias chaves de API aos modelos, efetivamente ignorando os limites de taxa da API. ', agentMode: 'Modo Agente', }, comingSoon: 'Em breve', diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index da41268302..ef589c0bde 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -77,6 +77,22 @@ const translation = { importDSLTip: 'O rascunho atual será substituído. Exporte o fluxo de trabalho como backup antes de importar.', backupCurrentDraft: 'Fazer backup do rascunho atual', importDSL: 'Importar DSL', + parallelTip: { + click: { + title: 'Clique', + desc: 'para adicionar', + }, + drag: { + title: 'Arrastar', + desc: 'para conectar', + }, + limit: 'O paralelismo é limitado a {{num}} ramificações.', + depthLimit: 'Limite de camada de aninhamento paralelo de {{num}} camadas', + }, + parallelRun: 'Execução paralela', + disconnect: 'Desligar', + jumpToNode: 'Ir para este nó', + addParallelNode: 'Adicionar nó paralelo', }, env: { envPanelTitle: 'Variáveis de Ambiente', diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index f8177aeb06..689ebdead9 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -77,6 +77,22 @@ const translation = { importSuccess: 'Succesul importului', backupCurrentDraft: 'Backup curent draft', importDSLTip: 'Proiectul curent va fi suprascris. Exportați fluxul de lucru ca backup înainte de import.', + parallelTip: { + click: { + title: 'Clic', + desc: 'pentru a adăuga', + }, + drag: { + title: 'Glisa', + desc: 'pentru a vă conecta', + }, + depthLimit: 'Limita straturilor de imbricare paralelă a {{num}} straturi', + limit: 'Paralelismul este limitat la {{num}} ramuri.', + }, + parallelRun: 'Rulare paralelă', + disconnect: 'Deconecta', + jumpToNode: 'Sari la acest nod', + addParallelNode: 'Adăugare nod paralel', }, env: { envPanelTitle: 'Variabile de Mediu', diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index b31b798bd3..9d3ce1235c 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Перезаписать и импортировать', importFailure: 'Ошибка импорта', importSuccess: 'Импорт успешно завершен', + parallelTip: { + click: { + title: 'Щелчок', + desc: 'добавить', + }, + drag: { + title: 'Волочить', + desc: 'для подключения', + }, + limit: 'Параллелизм ограничен ветвями {{num}}.', + depthLimit: 'Ограничение на количество слоев параллельной вложенности {{num}}', + }, + parallelRun: 'Параллельный прогон', + disconnect: 'Разъединять', + jumpToNode: 'Перейти к этому узлу', + addParallelNode: 'Добавить параллельный узел', }, env: { envPanelTitle: 'Переменные среды', diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index 856c5d518f..96313d6d6b 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Üzerine Yaz ve İçe Aktar', importFailure: 'İçe Aktarma Başarısız', importSuccess: 'İçe Aktarma Başarılı', + parallelTip: { + click: { + desc: 'Eklemek için', + title: 'Tık', + }, + drag: { + title: 'Sürükleme', + desc: 'Bağlanmak için', + }, + depthLimit: '{{num}} katmanlarının paralel iç içe geçme katmanı sınırı', + limit: 'Paralellik {{num}} dallarıyla sınırlıdır.', + }, + jumpToNode: 'Bu düğüme atla', + addParallelNode: 'Paralel Düğüm Ekle', + disconnect: 'Ayırmak', + parallelRun: 'Paralel Koşu', }, env: { envPanelTitle: 'Çevre Değişkenleri', diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index c6b98d1b5b..03471348c8 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -77,6 +77,22 @@ const translation = { chooseDSL: 'Виберіть файл DSL(yml)', backupCurrentDraft: 'Резервна поточна чернетка', importDSLTip: 'Поточна чернетка буде перезаписана. Експортуйте робочий процес як резервну копію перед імпортом.', + parallelTip: { + click: { + title: 'Натисніть', + desc: 'щоб додати', + }, + drag: { + title: 'Перетягувати', + desc: 'Щоб підключити', + }, + limit: 'Паралелізм обмежується {{num}} гілками.', + depthLimit: 'Обмеження рівня паралельного вкладеності шарів {{num}}', + }, + disconnect: 'Відключити', + parallelRun: 'Паралельний біг', + jumpToNode: 'Перейти до цього вузла', + addParallelNode: 'Додати паралельний вузол', }, env: { envPanelTitle: 'Змінні середовища', diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index d850d1a732..5be19ab7fd 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Ghi đè và nhập', importDSL: 'Nhập DSL', syncingData: 'Đồng bộ hóa dữ liệu, chỉ vài giây.', + parallelTip: { + click: { + title: 'Bấm', + desc: 'để thêm', + }, + drag: { + title: 'Kéo', + desc: 'Để kết nối', + }, + limit: 'Song song được giới hạn trong các nhánh {{num}}.', + depthLimit: 'Giới hạn lớp lồng song song của {{num}} layer', + }, + parallelRun: 'Chạy song song', + disconnect: 'Ngắt kết nối', + jumpToNode: 'Chuyển đến nút này', + addParallelNode: 'Thêm nút song song', }, env: { envPanelTitle: 'Biến Môi Trường', diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index 618fa8ce2e..eef3ffaebd 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -77,6 +77,22 @@ const translation = { syncingData: '同步數據,只需幾秒鐘。', importDSLTip: '當前草稿將被覆蓋。在導入之前將工作流匯出為備份。', importFailure: '匯入失敗', + parallelTip: { + click: { + title: '點擊', + desc: '添加', + }, + drag: { + title: '拖动', + desc: '連接', + }, + limit: '並行度僅限於 {{num}} 個分支。', + depthLimit: '並行嵌套層限制為 {{num}} 個層', + }, + parallelRun: '並行運行', + disconnect: '斷開', + jumpToNode: '跳轉到此節點', + addParallelNode: '添加並行節點', }, env: { envPanelTitle: '環境變數', diff --git a/web/package.json b/web/package.json index 374286f8f7..197a1e7e05 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "dify-web", - "version": "0.8.0", + "version": "0.8.2", "private": true, "engines": { "node": ">=18.17.0"