mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
Merge main
This commit is contained in:
commit
00d1c45518
|
@ -164,7 +164,7 @@ def initialize_extensions(app):
|
|||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in ["console", "inner_api"]:
|
||||
if request.blueprint not in {"console", "inner_api"}:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
|
|
@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
|
|||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
|
||||
"Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
|
@ -131,7 +131,7 @@ def reset_encrypt_key_pair():
|
|||
|
||||
click.echo(
|
||||
click.style(
|
||||
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
|
||||
"Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
|
|||
@click.command("vdb-migrate", help="migrate vector db.")
|
||||
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
||||
def vdb_migrate(scope: str):
|
||||
if scope in ["knowledge", "all"]:
|
||||
if scope in {"knowledge", "all"}:
|
||||
migrate_knowledge_vector_database()
|
||||
if scope in ["annotation", "all"]:
|
||||
if scope in {"annotation", "all"}:
|
||||
migrate_annotation_vector_database()
|
||||
|
||||
|
||||
|
@ -275,8 +275,7 @@ def migrate_knowledge_vector_database():
|
|||
for dataset in datasets:
|
||||
total_count = total_count + 1
|
||||
click.echo(
|
||||
f"Processing the {total_count} dataset {dataset.id}. "
|
||||
+ f"{create_count} created, {skipped_count} skipped."
|
||||
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo("Create dataset vdb index: {}".format(dataset.id))
|
||||
|
@ -411,7 +410,8 @@ def migrate_knowledge_vector_database():
|
|||
try:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
|
||||
f"Start to created vector index with {len(documents)} documents of {segments_count}"
|
||||
f" segments for dataset {dataset.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
@ -593,7 +593,7 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
|
|||
|
||||
click.echo(
|
||||
click.style(
|
||||
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
|
||||
"Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
|
|
@ -110,6 +110,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||
default=1000,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
|
@ -124,6 +125,7 @@ class PluginConfig(BaseSettings):
|
|||
default='dify-inner-api-key',
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Module URL configs
|
||||
|
@ -142,12 +144,12 @@ class EndpointConfig(BaseSettings):
|
|||
)
|
||||
|
||||
SERVICE_API_URL: str = Field(
|
||||
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
|
||||
description="Service API Url prefix. used to display Service API Base Url to the front-end.",
|
||||
default="",
|
||||
)
|
||||
|
||||
APP_WEB_URL: str = Field(
|
||||
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
|
||||
description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
@ -285,7 +287,7 @@ class LoggingConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
LOG_LEVEL: str = Field(
|
||||
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
|
||||
description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
|
||||
default="INFO",
|
||||
)
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
|||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.8.0",
|
||||
default="0.8.2",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
|
|
@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
|
|||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = args["desc"] if args["desc"] else ""
|
||||
copy_right = args["copyright"] if args["copyright"] else ""
|
||||
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
|
||||
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
|
||||
desc = args["desc"] or ""
|
||||
copy_right = args["copyright"] or ""
|
||||
privacy_policy = args["privacy_policy"] or ""
|
||||
custom_disclaimer = args["custom_disclaimer"] or ""
|
||||
else:
|
||||
desc = site.description if site.description else args["desc"] if args["desc"] else ""
|
||||
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
|
||||
privacy_policy = (
|
||||
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
|
||||
)
|
||||
custom_disclaimer = (
|
||||
site.custom_disclaimer
|
||||
if site.custom_disclaimer
|
||||
else args["custom_disclaimer"]
|
||||
if args["custom_disclaimer"]
|
||||
else ""
|
||||
)
|
||||
desc = site.description or args["desc"] or ""
|
||||
copy_right = site.copyright or args["copyright"] or ""
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
|
|||
def post(self, resource_id):
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_admin_or_owner:
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
|
|
@ -94,19 +94,15 @@ class ChatMessageTextApi(Resource):
|
|||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
||||
|
|
|
@ -20,7 +20,7 @@ from fields.conversation_fields import (
|
|||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
||||
|
||||
|
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
|
|||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
|
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
|
|||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
|
|
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
||||
|
@ -25,14 +25,17 @@ class DailyMessageStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(*) AS message_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -45,7 +48,7 @@ class DailyMessageStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -55,10 +58,10 @@ class DailyMessageStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -79,14 +82,17 @@ class DailyConversationStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT messages.conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -99,7 +105,7 @@ class DailyConversationStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -109,10 +115,10 @@ class DailyConversationStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -133,14 +139,17 @@ class DailyTerminalsStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -153,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -163,10 +172,10 @@ class DailyTerminalsStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -187,16 +196,18 @@ class DailyTokenCostStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
|
||||
sum(total_price) as total_price
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
|
||||
SUM(total_price) AS total_price
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -209,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -219,10 +230,10 @@ class DailyTokenCostStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -245,16 +256,26 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
||||
FROM conversations c
|
||||
JOIN messages m ON c.id = m.conversation_id
|
||||
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
m.conversation_id,
|
||||
COUNT(m.id) AS message_count
|
||||
FROM
|
||||
conversations c
|
||||
JOIN
|
||||
messages m
|
||||
ON c.id = m.conversation_id
|
||||
WHERE
|
||||
c.override_model_configs IS NULL
|
||||
AND c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -267,7 +288,7 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and c.created_at >= :start"
|
||||
sql_query += " AND c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -277,14 +298,19 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and c.created_at < :end"
|
||||
sql_query += " AND c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += """
|
||||
GROUP BY m.conversation_id) subquery
|
||||
LEFT JOIN conversations c on c.id=subquery.conversation_id
|
||||
GROUP BY date
|
||||
ORDER BY date"""
|
||||
GROUP BY m.conversation_id
|
||||
) subquery
|
||||
LEFT JOIN
|
||||
conversations c
|
||||
ON c.id = subquery.conversation_id
|
||||
GROUP BY
|
||||
date
|
||||
ORDER BY
|
||||
date"""
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -307,17 +333,21 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
|
||||
FROM messages m
|
||||
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
|
||||
WHERE m.app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(m.id) AS message_count,
|
||||
COUNT(mf.id) AS feedback_count
|
||||
FROM
|
||||
messages m
|
||||
LEFT JOIN
|
||||
message_feedbacks mf
|
||||
ON mf.message_id=m.id AND mf.rating='like'
|
||||
WHERE
|
||||
m.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -330,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and m.created_at >= :start"
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -340,10 +370,10 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and m.created_at < :end"
|
||||
sql_query += " AND m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -369,16 +399,17 @@ class AverageResponseTimeStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(provider_response_latency) as latency
|
||||
FROM messages
|
||||
WHERE app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(provider_response_latency) AS latency
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -391,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -401,10 +432,10 @@ class AverageResponseTimeStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -425,17 +456,20 @@ class TokensPerSecondStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
CASE
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
CASE
|
||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||
END as tokens_per_second
|
||||
FROM messages
|
||||
WHERE app_id = :app_id"""
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
@ -448,7 +482,7 @@ WHERE app_id = :app_id"""
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -458,10 +492,10 @@ WHERE app_id = :app_id"""
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
|
|
@ -502,6 +502,6 @@ api.add_resource(
|
|||
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
|
||||
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
|
||||
)
|
||||
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
|
||||
|
|
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
|
@ -26,16 +26,18 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(id) AS runs
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
|
@ -52,7 +54,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -62,10 +64,10 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -86,16 +88,18 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
|
@ -112,7 +116,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -122,10 +126,10 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -146,18 +150,18 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) as token_count
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) AS token_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
|
@ -174,7 +178,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
|
@ -184,10 +188,10 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
|
@ -213,27 +217,31 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
AVG(sub.interactions) as interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(SELECT
|
||||
date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM workflow_runs c
|
||||
WHERE c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY date, c.created_by) sub
|
||||
GROUP BY sub.date
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
AVG(sub.interactions) AS interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM
|
||||
workflow_runs c
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY
|
||||
date, c.created_by
|
||||
) sub
|
||||
GROUP BY
|
||||
sub.date"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
|
@ -262,7 +270,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
|
||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{end}}", "")
|
||||
|
|
|
@ -8,7 +8,7 @@ from constants.languages import supported_language
|
|||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, timezone
|
||||
from libs.helper import StrLen, email, timezone
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
|
@ -37,7 +37,7 @@ class ActivateApi(Resource):
|
|||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
|
|
|
@ -71,7 +71,7 @@ class OAuthCallback(Resource):
|
|||
|
||||
account = _generate_account(provider, user_info)
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
|
||||
return {"error": "Account is banned or closed."}, 403
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
|
@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
|
||||
if not account:
|
||||
# Create account
|
||||
account_name = user_info.name if user_info.name else "Dify"
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
)
|
||||
|
|
|
@ -399,7 +399,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {
|
||||
"api_base_url": (
|
||||
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
|
||||
)
|
||||
+ "/v1"
|
||||
}
|
||||
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
|
||||
|
||||
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
|
|
|
@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in ["completed", "error"]:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
|
@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
info_list = []
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in ["completed", "error"]:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
# format document files info
|
||||
|
@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
|
|||
db.session.commit()
|
||||
|
||||
elif action == "resume":
|
||||
if document.indexing_status not in ["paused", "error"]:
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
|
|
|
@ -18,9 +18,7 @@ class NotSetupError(BaseHTTPException):
|
|||
|
||||
class NotInitValidateError(BaseHTTPException):
|
||||
error_code = "not_init_validated"
|
||||
description = (
|
||||
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
|
||||
)
|
||||
description = "Init validation has not been completed yet. Please proceed with the init validation process first."
|
||||
code = 401
|
||||
|
||||
|
||||
|
|
|
@ -81,19 +81,15 @@ class ChatTextApi(InstalledAppResource):
|
|||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
||||
|
|
|
@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
|
|||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
|
|||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
|
|
@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
|
|||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
|
|||
def delete(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
|
|||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
|
|||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
|
|||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
|
|
@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
|
|||
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
||||
"is_pinned": installed_app.is_pinned,
|
||||
"last_used_at": installed_app.last_used_at,
|
||||
"editable": current_user.role in ["owner", "admin"],
|
||||
"editable": current_user.role in {"owner", "admin"},
|
||||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||
}
|
||||
for installed_app in installed_apps
|
||||
|
|
|
@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
|
|||
app_model = installed_app.app
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
|
|
@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
|
|||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
|
|
@ -4,7 +4,7 @@ from flask import session
|
|||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import str_len
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
|
|||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("password", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("password", type=StrLen(30), required=True, location="json")
|
||||
input_password = parser.parse_args()["password"]
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
|
|
|
@ -4,7 +4,7 @@ from flask import request
|
|||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import email, get_remote_ip, str_len
|
||||
from libs.helper import StrLen, email, get_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
@ -40,7 +40,7 @@ class SetupApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
|
|||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
|
||||
api.add_resource(
|
||||
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
|
||||
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
|
|
|
@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
args["provider_name"] if args["provider_name"] else "",
|
||||
args["provider_name"] or "",
|
||||
args["tool_name"],
|
||||
args["credentials"],
|
||||
args["parameters"],
|
||||
|
|
|
@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
raise TooManyFilesError()
|
||||
|
||||
extension = file.filename.split(".")[-1]
|
||||
if extension.lower() not in ["svg", "png"]:
|
||||
if extension.lower() not in {"svg", "png"}:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
|
|
|
@ -64,7 +64,8 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
||||
# The api of file upload is used in the multiple places,
|
||||
# so we need to check the source of the request from datasets
|
||||
source = request.args.get("source")
|
||||
if source == "datasets":
|
||||
abort(403, "The number of documents has reached the limit of your subscription.")
|
||||
|
|
|
@ -38,6 +38,7 @@ class PluginInvokeLLMApi(Resource):
|
|||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
|
@ -113,6 +114,7 @@ class PluginInvokeNodeApi(Resource):
|
|||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode):
|
||||
pass
|
||||
|
||||
|
||||
class PluginInvokeAppApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
|
@ -134,6 +136,7 @@ class PluginInvokeAppApi(Resource):
|
|||
PluginAppBackwardsInvocation.convert_to_event_stream(response)
|
||||
)
|
||||
|
||||
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
|
@ -145,6 +148,7 @@ class PluginInvokeEncryptApi(Resource):
|
|||
"""
|
||||
return PluginEncrypter.invoke_encrypt(tenant_model, payload)
|
||||
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
|
||||
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')
|
||||
|
|
|
@ -48,6 +48,7 @@ def get_tenant(view: Optional[Callable] = None):
|
|||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func):
|
||||
def decorated_view(*args, **kwargs):
|
||||
|
|
|
@ -63,6 +63,7 @@ def enterprise_inner_api_user_auth(view):
|
|||
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
|
|
@ -42,7 +42,7 @@ class AppParameterApi(Resource):
|
|||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
|
|
@ -79,19 +79,15 @@ class TextApi(Resource):
|
|||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
|
|
|
@ -96,7 +96,7 @@ class ChatApi(Resource):
|
|||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
|
|||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
|
|
@ -18,7 +18,7 @@ class ConversationApi(Resource):
|
|||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
|
|||
@marshal_with(simple_conversation_fields)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
|
|||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
|
|
@ -76,7 +76,7 @@ class MessageListApi(Resource):
|
|||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
|
|||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api import api
|
||||
|
@ -22,10 +23,12 @@ from core.errors.error import (
|
|||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -113,6 +116,30 @@ class WorkflowTaskStopApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@validate_app_token
|
||||
@marshal_with(workflow_app_log_pagination_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
app_model=app_model, args=args
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
|
||||
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
|
||||
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
|
||||
|
|
|
@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
|
|||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
|
|
@ -78,19 +78,15 @@ class TextApi(WebApiResource):
|
|||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
|
|||
class ChatApi(WebApiResource):
|
||||
def post(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
|
|||
class ChatStopApi(WebApiResource):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
|
|
|
@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
|
|||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
|
|||
class ConversationApi(WebApiResource):
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
|
|||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
|
|||
class ConversationPinApi(WebApiResource):
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
|
|||
class ConversationUnPinApi(WebApiResource):
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
|
|
@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
|
|||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
|
|
@ -80,7 +80,8 @@ def _validate_web_sso_token(decoded, system_features, app_code):
|
|||
if not source or source != "sso":
|
||||
raise WebSSOAuthRequiredError()
|
||||
|
||||
# Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
|
||||
# Check if SSO is not enforced for web, and if the token source is SSO,
|
||||
# raise an error and redirect to normal passport login
|
||||
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
|
||||
source = decoded.get("token_source")
|
||||
if source and source == "sso":
|
||||
|
|
|
@ -1 +1 @@
|
|||
import core.moderation.base
|
||||
import core.moderation.base
|
||||
|
|
|
@ -25,17 +25,19 @@ from models.model import Message
|
|||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ['wenxin']
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage] = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] = None
|
||||
_instruction: str = None
|
||||
_query: str = None
|
||||
_prompt_messages_tools: list[PromptMessage] = None
|
||||
|
||||
def run(self, message: Message,
|
||||
query: str,
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
def run(
|
||||
self,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
# check model mode
|
||||
if 'Observation' not in app_generate_entity.model_conf.stop:
|
||||
if "Observation" not in app_generate_entity.model_conf.stop:
|
||||
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_conf.stop.append('Observation')
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(
|
||||
instruction, inputs)
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
function_call_state = True
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
llm_usage = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict['usage']
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
|
@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
message_file_ids = []
|
||||
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message='',
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
messages_ids=message_file_ids
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
if iteration_step > 1:
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
|
@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(
|
||||
chunks, usage_dict)
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response='',
|
||||
thought='',
|
||||
action_str='',
|
||||
observation='',
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
# publish agent thought if it's first iteration
|
||||
if iteration_step == 1:
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
|
@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint='',
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=chunk
|
||||
),
|
||||
usage=None
|
||||
)
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
scratchpad.thought = scratchpad.thought.strip(
|
||||
) or 'I am thinking about how to help you'
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if 'usage' in usage_dict:
|
||||
increase_usage(llm_usage, usage_dict['usage'])
|
||||
if "usage" in usage_dict:
|
||||
increase_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
usage_dict['usage'] = LLMUsage.empty_usage()
|
||||
usage_dict["usage"] = LLMUsage.empty_usage()
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
||||
tool_input={
|
||||
scratchpad.action.action_name: scratchpad.action.action_input
|
||||
} if scratchpad.action else {},
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else "",
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought,
|
||||
observation='',
|
||||
observation="",
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict['usage']
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
if not scratchpad.is_final():
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
final_answer = ''
|
||||
final_answer = ""
|
||||
else:
|
||||
if scratchpad.action.action_name.lower() == "final answer":
|
||||
# action is final answer, return final answer directly
|
||||
try:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_answer = json.dumps(
|
||||
scratchpad.action.action_input)
|
||||
final_answer = json.dumps(scratchpad.action.action_input)
|
||||
elif isinstance(scratchpad.action.action_input, str):
|
||||
final_answer = scratchpad.action.action_input
|
||||
else:
|
||||
final_answer = f'{scratchpad.action.action_input}'
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
except json.JSONDecodeError:
|
||||
final_answer = f'{scratchpad.action.action_input}'
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
else:
|
||||
function_call_state = True
|
||||
# action is tool call, invoke tool
|
||||
|
@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name,
|
||||
tool_input={
|
||||
scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
thought=scratchpad.thought,
|
||||
observation={
|
||||
scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={
|
||||
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
observation={scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
llm_usage=usage_dict['usage']
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in self._prompt_messages_tools:
|
||||
|
@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage']
|
||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||
),
|
||||
system_fingerprint=''
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_name="",
|
||||
tool_input={},
|
||||
tool_invoke_meta={},
|
||||
thought=final_answer,
|
||||
observation={},
|
||||
answer=final_answer,
|
||||
messages_ids=[]
|
||||
messages_ids=[],
|
||||
)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
)), PublishFrom.APPLICATION_MANAGER)
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
"""
|
||||
handle invoke action
|
||||
:param action: action
|
||||
|
@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(
|
||||
tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file_id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
|
@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
"""
|
||||
convert dict to action
|
||||
"""
|
||||
return AgentScratchpadUnit.Action(
|
||||
action_name=action['action'],
|
||||
action_input=action['action_input']
|
||||
)
|
||||
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
|
||||
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
|
||||
"""
|
||||
|
@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
"""
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
|
@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
@abstractmethod
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
organize prompt messages
|
||||
organize prompt messages
|
||||
"""
|
||||
|
||||
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
format assistant message
|
||||
format assistant message
|
||||
"""
|
||||
message = ''
|
||||
message = ""
|
||||
for scratchpad in agent_scratchpad:
|
||||
if scratchpad.is_final():
|
||||
message += f"Final Answer: {scratchpad.agent_response}"
|
||||
|
@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
organize historic prompt messages
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
|
@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
if not current_scratchpad:
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or 'I am thinking about how to help you',
|
||||
action_str='',
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
|
@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(
|
||||
message.tool_calls[0].function.arguments)
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(
|
||||
current_scratchpad.action.to_dict()
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except:
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
|
@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
current_scratchpad.observation = message.content
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(
|
||||
content=self._format_assistant_message(scratchpads)
|
||||
))
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
|
||||
result.append(message)
|
||||
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(
|
||||
content=self._format_assistant_message(scratchpads)
|
||||
))
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
||||
historic_prompts = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=current_session_messages or [],
|
||||
history_messages=result,
|
||||
memory=self.memory
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
return historic_prompts
|
||||
|
|
|
@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||
prompt_entity = self.app_config.agent.prompt
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = first_prompt \
|
||||
.replace("{{instruction}}", self._instruction) \
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
|
||||
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return SystemPromptMessage(content=system_prompt)
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
|
@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
Organize
|
||||
"""
|
||||
# organize system prompt
|
||||
system_message = self._organize_system_prompt()
|
||||
|
@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||
if not agent_scratchpad:
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content='')
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
|
@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||
|
||||
if assistant_messages:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages([
|
||||
system_message,
|
||||
*query_messages,
|
||||
*assistant_messages,
|
||||
UserPromptMessage(content='continue')
|
||||
])
|
||||
historic_messages = self._organize_historic_prompt_messages(
|
||||
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
|
||||
)
|
||||
messages = [
|
||||
system_message,
|
||||
*historic_messages,
|
||||
*query_messages,
|
||||
*assistant_messages,
|
||||
UserPromptMessage(content='continue')
|
||||
UserPromptMessage(content="continue"),
|
||||
]
|
||||
else:
|
||||
# organize historic prompt messages
|
||||
|
|
|
@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
|||
prompt_entity = self.app_config.agent.prompt
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
|
||||
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
|
||||
|
@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
|||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
assistant_prompt = ''
|
||||
assistant_prompt = ""
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assistant_prompt += f"Final Answer: {unit.agent_response}"
|
||||
|
@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
|||
query_prompt = f"Question: {self._query}"
|
||||
|
||||
# join all messages
|
||||
prompt = system_prompt \
|
||||
.replace("{{historic_messages}}", historic_prompt) \
|
||||
.replace("{{agent_scratchpad}}", assistant_prompt) \
|
||||
prompt = (
|
||||
system_prompt.replace("{{historic_messages}}", historic_prompt)
|
||||
.replace("{{agent_scratchpad}}", assistant_prompt)
|
||||
.replace("{{query}}", query_prompt)
|
||||
)
|
||||
|
||||
return [UserPromptMessage(content=prompt)]
|
||||
return [UserPromptMessage(content=prompt)]
|
||||
|
|
|
@ -20,6 +20,7 @@ class AgentPromptEntity(BaseModel):
|
|||
"""
|
||||
Agent Prompt Entity.
|
||||
"""
|
||||
|
||||
first_prompt: str
|
||||
next_iteration: str
|
||||
|
||||
|
@ -33,6 +34,7 @@ class AgentScratchpadUnit(BaseModel):
|
|||
"""
|
||||
Action Entity.
|
||||
"""
|
||||
|
||||
action_name: str
|
||||
action_input: Union[dict, str]
|
||||
|
||||
|
@ -41,8 +43,8 @@ class AgentScratchpadUnit(BaseModel):
|
|||
Convert to dictionary.
|
||||
"""
|
||||
return {
|
||||
'action': self.action_name,
|
||||
'action_input': self.action_input,
|
||||
"action": self.action_name,
|
||||
"action_input": self.action_input,
|
||||
}
|
||||
|
||||
agent_response: Optional[str] = None
|
||||
|
@ -56,10 +58,10 @@ class AgentScratchpadUnit(BaseModel):
|
|||
Check if the scratchpad unit is final.
|
||||
"""
|
||||
return self.action is None or (
|
||||
'final' in self.action.action_name.lower() and
|
||||
'answer' in self.action.action_name.lower()
|
||||
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
|
||||
)
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
"""
|
||||
Agent Entity.
|
||||
|
@ -69,8 +71,9 @@ class AgentEntity(BaseModel):
|
|||
"""
|
||||
Agent Strategy.
|
||||
"""
|
||||
CHAIN_OF_THOUGHT = 'chain-of-thought'
|
||||
FUNCTION_CALLING = 'function-calling'
|
||||
|
||||
CHAIN_OF_THOUGHT = "chain-of-thought"
|
||||
FUNCTION_CALLING = "function-calling"
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
|
|
|
@ -24,11 +24,9 @@ from models.model import Message
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
def run(self,
|
||||
message: Message, query: str, **kwargs: Any
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run FunctionCall agent application
|
||||
"""
|
||||
|
@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
llm_usage = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict['usage']
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
|
@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
|
||||
message_file_ids = []
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message='',
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
messages_ids=message_file_ids
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
|
@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
# save full response
|
||||
response = ''
|
||||
response = ""
|
||||
|
||||
# save tool call names and inputs
|
||||
tool_call_names = ''
|
||||
tool_call_inputs = ''
|
||||
tool_call_names = ""
|
||||
tool_call_inputs = ""
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
|
@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
is_first_chunk = False
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_tool_calls(chunk))
|
||||
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
}, ensure_ascii=False)
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
})
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
|
@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result))
|
||||
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
}, ensure_ascii=False)
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
})
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if result.usage:
|
||||
increase_usage(llm_usage, result.usage)
|
||||
|
@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
response += result.message.content
|
||||
|
||||
if not result.message.content:
|
||||
result.message.content = ''
|
||||
result.message.content = ""
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
|
@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(
|
||||
content='',
|
||||
tool_calls=[]
|
||||
)
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls=[
|
||||
assistant_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call[0],
|
||||
type='function',
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call[1],
|
||||
arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
||||
)
|
||||
) for tool_call in tool_calls
|
||||
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
else:
|
||||
assistant_message.content = response
|
||||
|
||||
|
||||
self._current_thoughts.append(assistant_message)
|
||||
|
||||
# save thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
agent_thought=agent_thought,
|
||||
tool_name=tool_call_names,
|
||||
tool_input=tool_call_inputs,
|
||||
thought=response,
|
||||
|
@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
observation=None,
|
||||
answer=response,
|
||||
messages_ids=[],
|
||||
llm_usage=current_llm_usage
|
||||
llm_usage=current_llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
final_answer += response + '\n'
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
final_answer += response + "\n"
|
||||
|
||||
# call tools
|
||||
tool_responses = []
|
||||
|
@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": f"there is not a tool named {tool_call_name}",
|
||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
|
||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
||||
}
|
||||
else:
|
||||
# invoke tool
|
||||
|
@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file_id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": tool_invoke_response,
|
||||
"meta": tool_invoke_meta.to_dict()
|
||||
"meta": tool_invoke_meta.to_dict(),
|
||||
}
|
||||
|
||||
|
||||
tool_responses.append(tool_response)
|
||||
if tool_response['tool_response'] is not None:
|
||||
if tool_response["tool_response"] is not None:
|
||||
self._current_thoughts.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_response['tool_response'],
|
||||
content=tool_response["tool_response"],
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_call_name,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if len(tool_responses) > 0:
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
agent_thought=agent_thought,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
thought=None,
|
||||
tool_invoke_meta={
|
||||
tool_response['tool_call_name']: tool_response['meta']
|
||||
for tool_response in tool_responses
|
||||
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
|
||||
},
|
||||
observation={
|
||||
tool_response['tool_call_name']: tool_response['tool_response']
|
||||
tool_response["tool_call_name"]: tool_response["tool_response"]
|
||||
for tool_response in tool_responses
|
||||
},
|
||||
answer=None,
|
||||
messages_ids=message_file_ids
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
|
@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
)), PublishFrom.APPLICATION_MANAGER)
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
||||
"""
|
||||
|
@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
if llm_result_chunk.delta.message.tool_calls:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
|
||||
"""
|
||||
Check if there is any blocking tool call in llm result
|
||||
|
@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_tool_calls(
|
||||
self, llm_result_chunk: LLMResultChunk
|
||||
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
|
@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
tool_calls = []
|
||||
for prompt_message in llm_result_chunk.delta.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != '':
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append((
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
))
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
tool_calls = []
|
||||
for prompt_message in llm_result.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != '':
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append((
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
))
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _init_system_message(
|
||||
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
|
@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
return [
|
||||
SystemPromptMessage(content=prompt_template),
|
||||
]
|
||||
|
||||
|
||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
|
@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
|
@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = '\n'.join([
|
||||
content.data if content.type == PromptMessageContentType.TEXT else
|
||||
'[image]' if content.type == PromptMessageContentType.IMAGE else
|
||||
'[file]'
|
||||
for content in prompt_message.content
|
||||
])
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query, [])
|
||||
|
||||
|
@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [
|
||||
*self.history_prompt_messages,
|
||||
*query_prompt_messages,
|
||||
*self._current_thoughts
|
||||
]
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
|
|
|
@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
|
|||
|
||||
class CotAgentOutputParser:
|
||||
@classmethod
|
||||
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
|
||||
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def handle_react_stream_output(
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def parse_action(json_str):
|
||||
try:
|
||||
action = json.loads(json_str)
|
||||
|
@ -22,7 +23,7 @@ class CotAgentOutputParser:
|
|||
action = action[0]
|
||||
|
||||
for key, value in action.items():
|
||||
if 'input' in key.lower():
|
||||
if "input" in key.lower():
|
||||
action_input = value
|
||||
else:
|
||||
action_name = value
|
||||
|
@ -33,37 +34,37 @@ class CotAgentOutputParser:
|
|||
action_input=action_input,
|
||||
)
|
||||
else:
|
||||
return json_str or ''
|
||||
return json_str or ""
|
||||
except:
|
||||
return json_str or ''
|
||||
|
||||
return json_str or ""
|
||||
|
||||
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
|
||||
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
|
||||
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
|
||||
if not code_blocks:
|
||||
return
|
||||
for block in code_blocks:
|
||||
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
|
||||
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
|
||||
yield parse_action(json_text)
|
||||
|
||||
code_block_cache = ''
|
||||
|
||||
code_block_cache = ""
|
||||
code_block_delimiter_count = 0
|
||||
in_code_block = False
|
||||
json_cache = ''
|
||||
json_cache = ""
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
got_json = False
|
||||
|
||||
action_cache = ''
|
||||
action_str = 'action:'
|
||||
action_cache = ""
|
||||
action_str = "action:"
|
||||
action_idx = 0
|
||||
|
||||
thought_cache = ''
|
||||
thought_str = 'thought:'
|
||||
thought_cache = ""
|
||||
thought_str = "thought:"
|
||||
thought_idx = 0
|
||||
|
||||
for response in llm_response:
|
||||
if response.delta.usage:
|
||||
usage_dict['usage'] = response.delta.usage
|
||||
usage_dict["usage"] = response.delta.usage
|
||||
response = response.delta.message.content
|
||||
if not isinstance(response, str):
|
||||
continue
|
||||
|
@ -72,24 +73,24 @@ class CotAgentOutputParser:
|
|||
index = 0
|
||||
while index < len(response):
|
||||
steps = 1
|
||||
delta = response[index:index+steps]
|
||||
last_character = response[index-1] if index > 0 else ''
|
||||
delta = response[index : index + steps]
|
||||
last_character = response[index - 1] if index > 0 else ""
|
||||
|
||||
if delta == '`':
|
||||
if delta == "`":
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count += 1
|
||||
else:
|
||||
if not in_code_block:
|
||||
if code_block_delimiter_count > 0:
|
||||
yield code_block_cache
|
||||
code_block_cache = ''
|
||||
code_block_cache = ""
|
||||
else:
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||
if last_character not in ['\n', ' ', '']:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
@ -97,7 +98,7 @@ class CotAgentOutputParser:
|
|||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ''
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
|
@ -105,18 +106,18 @@ class CotAgentOutputParser:
|
|||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ''
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
else:
|
||||
if action_cache:
|
||||
yield action_cache
|
||||
action_cache = ''
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
|
||||
|
||||
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||
if last_character not in ['\n', ' ', '']:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
@ -124,7 +125,7 @@ class CotAgentOutputParser:
|
|||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ''
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
|
@ -132,31 +133,31 @@ class CotAgentOutputParser:
|
|||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ''
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
else:
|
||||
if thought_cache:
|
||||
yield thought_cache
|
||||
thought_cache = ''
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
|
||||
if code_block_delimiter_count == 3:
|
||||
if in_code_block:
|
||||
yield from extra_json_from_code_block(code_block_cache)
|
||||
code_block_cache = ''
|
||||
|
||||
code_block_cache = ""
|
||||
|
||||
in_code_block = not in_code_block
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block:
|
||||
# handle single json
|
||||
if delta == '{':
|
||||
if delta == "{":
|
||||
json_quote_count += 1
|
||||
in_json = True
|
||||
json_cache += delta
|
||||
elif delta == '}':
|
||||
elif delta == "}":
|
||||
json_cache += delta
|
||||
if json_quote_count > 0:
|
||||
json_quote_count -= 1
|
||||
|
@ -172,12 +173,12 @@ class CotAgentOutputParser:
|
|||
if got_json:
|
||||
got_json = False
|
||||
yield parse_action(json_cache)
|
||||
json_cache = ''
|
||||
json_cache = ""
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
yield delta.replace('`', '')
|
||||
yield delta.replace("`", "")
|
||||
|
||||
index += steps
|
||||
|
||||
|
@ -186,4 +187,3 @@ class CotAgentOutputParser:
|
|||
|
||||
if json_cache:
|
||||
yield parse_action(json_cache)
|
||||
|
||||
|
|
|
@ -41,7 +41,8 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
|
|||
{{historic_messages}}
|
||||
Question: {{query}}
|
||||
{{agent_scratchpad}}
|
||||
Thought:"""
|
||||
Thought:""" # noqa: E501
|
||||
|
||||
|
||||
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
|
||||
Thought:"""
|
||||
|
@ -86,19 +87,20 @@ Action:
|
|||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
|
||||
|
||||
REACT_PROMPT_TEMPLATES = {
|
||||
'english': {
|
||||
'chat': {
|
||||
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
|
||||
"english": {
|
||||
"chat": {
|
||||
"prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
|
||||
"agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
|
||||
},
|
||||
"completion": {
|
||||
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
|
||||
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
|
||||
},
|
||||
'completion': {
|
||||
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,34 +26,24 @@ class BaseAppConfigManager:
|
|||
config_dict = dict(config_dict.items())
|
||||
|
||||
additional_features = AppAdditionalFeatures()
|
||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.file_upload = FileUploadConfigManager.convert(
|
||||
config=config_dict,
|
||||
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
|
||||
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
|
||||
)
|
||||
|
||||
additional_features.opening_statement, additional_features.suggested_questions = \
|
||||
OpeningStatementConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.opening_statement, additional_features.suggested_questions = (
|
||||
OpeningStatementConfigManager.convert(config=config_dict)
|
||||
)
|
||||
|
||||
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
|
||||
additional_features.more_like_this = MoreLikeThisConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.speech_to_text = SpeechToTextConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.text_to_speech = TextToSpeechConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
|
||||
|
||||
return additional_features
|
||||
|
|
|
@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
|
|||
class SensitiveWordAvoidanceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
|
||||
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
|
||||
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
|
||||
if not sensitive_word_avoidance_dict:
|
||||
return None
|
||||
|
||||
if sensitive_word_avoidance_dict.get('enabled'):
|
||||
if sensitive_word_avoidance_dict.get("enabled"):
|
||||
return SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get('type'),
|
||||
config=sensitive_word_avoidance_dict.get('config'),
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config"),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
|
||||
-> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id, config: dict, only_structure_validate: bool = False
|
||||
) -> tuple[dict, list[str]]:
|
||||
if not config.get("sensitive_word_avoidance"):
|
||||
config["sensitive_word_avoidance"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
|
@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
|
|||
typ = config["sensitive_word_avoidance"]["type"]
|
||||
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
|
||||
|
||||
ModerationFactory.validate_config(
|
||||
name=typ,
|
||||
tenant_id=tenant_id,
|
||||
config=sensitive_word_avoidance_config
|
||||
)
|
||||
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
|
||||
|
||||
return config, ["sensitive_word_avoidance"]
|
||||
|
|
|
@ -12,67 +12,70 @@ class AgentConfigManager:
|
|||
|
||||
:param config: model config args
|
||||
"""
|
||||
if 'agent_mode' in config and config['agent_mode'] \
|
||||
and 'enabled' in config['agent_mode']:
|
||||
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
agent_strategy = agent_dict.get("strategy", "cot")
|
||||
|
||||
agent_dict = config.get('agent_mode', {})
|
||||
agent_strategy = agent_dict.get('strategy', 'cot')
|
||||
|
||||
if agent_strategy == 'function_call':
|
||||
if agent_strategy == "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
elif agent_strategy == 'cot' or agent_strategy == 'react':
|
||||
elif agent_strategy in {"cot", "react"}:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
# old configs, try to detect default strategy
|
||||
if config['model']['provider'] == 'openai':
|
||||
if config["model"]["provider"] == "openai":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get('tools', []):
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) >= 4:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties = {
|
||||
'provider_type': tool['provider_type'],
|
||||
'provider_id': tool['provider_id'],
|
||||
'tool_name': tool['tool_name'],
|
||||
'tool_parameters': tool.get('tool_parameters', {})
|
||||
"provider_type": tool["provider_type"],
|
||||
"provider_id": tool["provider_id"],
|
||||
"tool_name": tool["tool_name"],
|
||||
"tool_parameters": tool.get("tool_parameters", {}),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||
|
||||
if 'strategy' in config['agent_mode'] and \
|
||||
config['agent_mode']['strategy'] not in ['react_router', 'router']:
|
||||
agent_prompt = agent_dict.get('prompt', None) or {}
|
||||
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
|
||||
"react_router",
|
||||
"router",
|
||||
}:
|
||||
agent_prompt = agent_dict.get("prompt", None) or {}
|
||||
# check model mode
|
||||
model_mode = config.get('model', {}).get('mode', 'completion')
|
||||
if model_mode == 'completion':
|
||||
model_mode = config.get("model", {}).get("mode", "completion")
|
||||
if model_mode == "completion":
|
||||
agent_prompt_entity = AgentPromptEntity(
|
||||
first_prompt=agent_prompt.get('first_prompt',
|
||||
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
|
||||
next_iteration=agent_prompt.get('next_iteration',
|
||||
REACT_PROMPT_TEMPLATES['english']['completion'][
|
||||
'agent_scratchpad']),
|
||||
first_prompt=agent_prompt.get(
|
||||
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
|
||||
),
|
||||
next_iteration=agent_prompt.get(
|
||||
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
|
||||
),
|
||||
)
|
||||
else:
|
||||
agent_prompt_entity = AgentPromptEntity(
|
||||
first_prompt=agent_prompt.get('first_prompt',
|
||||
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
|
||||
next_iteration=agent_prompt.get('next_iteration',
|
||||
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
|
||||
first_prompt=agent_prompt.get(
|
||||
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
|
||||
),
|
||||
next_iteration=agent_prompt.get(
|
||||
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
|
||||
),
|
||||
)
|
||||
|
||||
return AgentEntity(
|
||||
provider=config['model']['provider'],
|
||||
model=config['model']['name'],
|
||||
provider=config["model"]["provider"],
|
||||
model=config["model"]["name"],
|
||||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get('max_iteration', 5)
|
||||
max_iteration=agent_dict.get("max_iteration", 5),
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
@ -15,39 +15,38 @@ class DatasetConfigManager:
|
|||
:param config: model config args
|
||||
"""
|
||||
dataset_ids = []
|
||||
if 'datasets' in config.get('dataset_configs', {}):
|
||||
datasets = config.get('dataset_configs', {}).get('datasets', {
|
||||
'strategy': 'router',
|
||||
'datasets': []
|
||||
})
|
||||
if "datasets" in config.get("dataset_configs", {}):
|
||||
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
|
||||
|
||||
for dataset in datasets.get('datasets', []):
|
||||
for dataset in datasets.get("datasets", []):
|
||||
keys = list(dataset.keys())
|
||||
if len(keys) == 0 or keys[0] != 'dataset':
|
||||
if len(keys) == 0 or keys[0] != "dataset":
|
||||
continue
|
||||
|
||||
dataset = dataset['dataset']
|
||||
dataset = dataset["dataset"]
|
||||
|
||||
if 'enabled' not in dataset or not dataset['enabled']:
|
||||
if "enabled" not in dataset or not dataset["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = dataset.get('id', None)
|
||||
dataset_id = dataset.get("id", None)
|
||||
if dataset_id:
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
if 'agent_mode' in config and config['agent_mode'] \
|
||||
and 'enabled' in config['agent_mode'] \
|
||||
and config['agent_mode']['enabled']:
|
||||
if (
|
||||
"agent_mode" in config
|
||||
and config["agent_mode"]
|
||||
and "enabled" in config["agent_mode"]
|
||||
and config["agent_mode"]["enabled"]
|
||||
):
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
|
||||
agent_dict = config.get('agent_mode', {})
|
||||
|
||||
for tool in agent_dict.get('tools', []):
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) == 1:
|
||||
# old standard
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != 'dataset':
|
||||
if key != "dataset":
|
||||
continue
|
||||
|
||||
tool_item = tool[key]
|
||||
|
@ -55,30 +54,28 @@ class DatasetConfigManager:
|
|||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = tool_item['id']
|
||||
dataset_id = tool_item["id"]
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
if len(dataset_ids) == 0:
|
||||
return None
|
||||
|
||||
# dataset configs
|
||||
if 'dataset_configs' in config and config.get('dataset_configs'):
|
||||
dataset_configs = config.get('dataset_configs')
|
||||
if "dataset_configs" in config and config.get("dataset_configs"):
|
||||
dataset_configs = config.get("dataset_configs")
|
||||
else:
|
||||
dataset_configs = {
|
||||
'retrieval_model': 'multiple'
|
||||
}
|
||||
query_variable = config.get('dataset_query_variable')
|
||||
dataset_configs = {"retrieval_model": "multiple"}
|
||||
query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if dataset_configs['retrieval_model'] == 'single':
|
||||
if dataset_configs["retrieval_model"] == "single":
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
)
|
||||
)
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
return DatasetEntity(
|
||||
|
@ -86,15 +83,15 @@ class DatasetConfigManager:
|
|||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=dataset_configs.get('top_k', 4),
|
||||
score_threshold=dataset_configs.get('score_threshold'),
|
||||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights'),
|
||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
|
||||
)
|
||||
top_k=dataset_configs.get("top_k", 4),
|
||||
score_threshold=dataset_configs.get("score_threshold"),
|
||||
reranking_model=dataset_configs.get("reranking_model"),
|
||||
weights=dataset_configs.get("weights"),
|
||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -111,13 +108,10 @@ class DatasetConfigManager:
|
|||
|
||||
# dataset_configs
|
||||
if not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {'retrieval_model': 'single'}
|
||||
config["dataset_configs"] = {"retrieval_model": "single"}
|
||||
|
||||
if not config["dataset_configs"].get("datasets"):
|
||||
config["dataset_configs"]["datasets"] = {
|
||||
"strategy": "router",
|
||||
"datasets": []
|
||||
}
|
||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
@ -125,8 +119,9 @@ class DatasetConfigManager:
|
|||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
need_manual_query_datasets = (config.get("dataset_configs")
|
||||
and config["dataset_configs"].get("datasets", {}).get("datasets"))
|
||||
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
|
||||
"datasets", {}
|
||||
).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
@ -148,10 +143,7 @@ class DatasetConfigManager:
|
|||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {
|
||||
"enabled": False,
|
||||
"tools": []
|
||||
}
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
@ -175,7 +167,7 @@ class DatasetConfigManager:
|
|||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
has_datasets = False
|
||||
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
|
||||
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
key = list(tool.keys())[0]
|
||||
if key == "dataset":
|
||||
|
@ -188,7 +180,7 @@ class DatasetConfigManager:
|
|||
if not isinstance(tool_item["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
||||
|
||||
if 'id' not in tool_item:
|
||||
if "id" not in tool_item:
|
||||
raise ValueError("id is required in dataset")
|
||||
|
||||
try:
|
||||
|
|
|
@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager
|
|||
|
||||
class ModelConfigConverter:
|
||||
@classmethod
|
||||
def convert(cls, app_config: EasyUIBasedAppConfig,
|
||||
skip_check: bool = False) \
|
||||
-> ModelConfigWithCredentialsEntity:
|
||||
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
|
||||
"""
|
||||
Convert app model config dict to entity.
|
||||
:param app_config: app config
|
||||
|
@ -25,9 +23,7 @@ class ModelConfigConverter:
|
|||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=app_config.tenant_id,
|
||||
provider=model_config.provider,
|
||||
model_type=ModelType.LLM
|
||||
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
|
@ -38,8 +34,7 @@ class ModelConfigConverter:
|
|||
|
||||
# check model credentials
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_config.model
|
||||
model_type=ModelType.LLM, model=model_config.model
|
||||
)
|
||||
|
||||
if model_credentials is None:
|
||||
|
@ -51,8 +46,7 @@ class ModelConfigConverter:
|
|||
if not skip_check:
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_config.model,
|
||||
model_type=ModelType.LLM
|
||||
model=model_config.model, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
|
@ -69,24 +63,18 @@ class ModelConfigConverter:
|
|||
# model config
|
||||
completion_params = model_config.parameters
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.mode
|
||||
if not model_mode:
|
||||
mode_enum = model_type_instance.get_model_mode(
|
||||
model=model_config.model,
|
||||
credentials=model_credentials
|
||||
)
|
||||
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
|
||||
|
||||
model_mode = mode_enum.value
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model_config.model,
|
||||
model_credentials
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
|
||||
if not skip_check and not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
|
|
@ -13,23 +13,23 @@ class ModelConfigManager:
|
|||
:param config: model config args
|
||||
"""
|
||||
# model config
|
||||
model_config = config.get('model')
|
||||
model_config = config.get("model")
|
||||
|
||||
if not model_config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
completion_params = model_config.get('completion_params')
|
||||
completion_params = model_config.get("completion_params")
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.get('mode')
|
||||
model_mode = model_config.get("mode")
|
||||
|
||||
return ModelConfigEntity(
|
||||
provider=config['model']['provider'],
|
||||
model=config['model']['name'],
|
||||
provider=config["model"]["provider"],
|
||||
model=config["model"]["name"],
|
||||
mode=model_mode,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
|
@ -43,7 +43,7 @@ class ModelConfigManager:
|
|||
:param tenant_id: tenant id
|
||||
:param config: app model config args
|
||||
"""
|
||||
if 'model' not in config:
|
||||
if "model" not in config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
if not isinstance(config["model"], dict):
|
||||
|
@ -52,17 +52,16 @@ class ModelConfigManager:
|
|||
# model.provider
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
if 'name' not in config["model"]:
|
||||
if "name" not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
models = provider_manager.get_configurations(tenant_id).get_models(
|
||||
provider=config["model"]["provider"],
|
||||
model_type=ModelType.LLM
|
||||
provider=config["model"]["provider"], model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if not models:
|
||||
|
@ -80,12 +79,12 @@ class ModelConfigManager:
|
|||
|
||||
# model.mode
|
||||
if model_mode:
|
||||
config['model']["mode"] = model_mode
|
||||
config["model"]["mode"] = model_mode
|
||||
else:
|
||||
config['model']["mode"] = "completion"
|
||||
config["model"]["mode"] = "completion"
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
if "completion_params" not in config["model"]:
|
||||
raise ValueError("model.completion_params is required")
|
||||
|
||||
config["model"]["completion_params"] = cls.validate_model_completion_params(
|
||||
|
@ -101,7 +100,7 @@ class ModelConfigManager:
|
|||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
if "stop" not in cp:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
|
|
@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
|
|||
if not config.get("prompt_type"):
|
||||
raise ValueError("prompt_type is required")
|
||||
|
||||
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
|
||||
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
|
||||
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
simple_prompt_template = config.get("pre_prompt", "")
|
||||
return PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
simple_prompt_template=simple_prompt_template
|
||||
)
|
||||
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
|
||||
else:
|
||||
advanced_chat_prompt_template = None
|
||||
chat_prompt_config = config.get("chat_prompt_config", {})
|
||||
if chat_prompt_config:
|
||||
chat_prompt_messages = []
|
||||
for message in chat_prompt_config.get("prompt", []):
|
||||
chat_prompt_messages.append({
|
||||
"text": message["text"],
|
||||
"role": PromptMessageRole.value_of(message["role"])
|
||||
})
|
||||
chat_prompt_messages.append(
|
||||
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
|
||||
)
|
||||
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
|
||||
messages=chat_prompt_messages
|
||||
)
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
|
||||
|
||||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = config.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params = {
|
||||
'prompt': completion_prompt_config['prompt']['text'],
|
||||
"prompt": completion_prompt_config["prompt"]["text"],
|
||||
}
|
||||
|
||||
if 'conversation_histories_role' in completion_prompt_config:
|
||||
completion_prompt_template_params['role_prefix'] = {
|
||||
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
|
||||
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
|
||||
if "conversation_histories_role" in completion_prompt_config:
|
||||
completion_prompt_template_params["role_prefix"] = {
|
||||
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
|
||||
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
|
@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
|
|||
return PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
advanced_chat_prompt_template=advanced_chat_prompt_template,
|
||||
advanced_completion_prompt_template=advanced_completion_prompt_template
|
||||
advanced_completion_prompt_template=advanced_completion_prompt_template,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
|
|||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
|
||||
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
|
||||
if config['prompt_type'] not in prompt_type_vals:
|
||||
if config["prompt_type"] not in prompt_type_vals:
|
||||
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
|
||||
|
||||
# chat_prompt_config
|
||||
|
@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
|
|||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required "
|
||||
"when prompt_type is advanced")
|
||||
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
|
||||
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
|
||||
raise ValueError(
|
||||
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
|
||||
)
|
||||
|
||||
model_mode_vals = [mode.value for mode in ModelMode]
|
||||
if config['model']["mode"] not in model_mode_vals:
|
||||
if config["model"]["mode"] not in model_mode_vals:
|
||||
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
|
||||
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
|
||||
|
||||
if not user_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
|
||||
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
|
||||
|
||||
if config['model']["mode"] == ModelMode.CHAT.value:
|
||||
prompt_list = config['chat_prompt_config']['prompt']
|
||||
if config["model"]["mode"] == ModelMode.CHAT.value:
|
||||
prompt_list = config["chat_prompt_config"]["prompt"]
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
|
|
|
@ -16,51 +16,49 @@ class BasicVariablesConfigManager:
|
|||
variable_entities = []
|
||||
|
||||
# old external_data_tools
|
||||
external_data_tools = config.get('external_data_tools', [])
|
||||
external_data_tools = config.get("external_data_tools", [])
|
||||
for external_data_tool in external_data_tools:
|
||||
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
|
||||
if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=external_data_tool['variable'],
|
||||
type=external_data_tool['type'],
|
||||
config=external_data_tool['config']
|
||||
variable=external_data_tool["variable"],
|
||||
type=external_data_tool["type"],
|
||||
config=external_data_tool["config"],
|
||||
)
|
||||
)
|
||||
|
||||
# variables and external_data_tools
|
||||
for variables in config.get('user_input_form', []):
|
||||
for variables in config.get("user_input_form", []):
|
||||
variable_type = list(variables.keys())[0]
|
||||
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
|
||||
variable = variables[variable_type]
|
||||
if 'config' not in variable:
|
||||
if "config" not in variable:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=variable['variable'],
|
||||
type=variable['type'],
|
||||
config=variable['config']
|
||||
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
||||
)
|
||||
)
|
||||
elif variable_type in [
|
||||
elif variable_type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.SELECT,
|
||||
]:
|
||||
}:
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get('variable'),
|
||||
description=variable.get('description'),
|
||||
label=variable.get('label'),
|
||||
required=variable.get('required', False),
|
||||
max_length=variable.get('max_length'),
|
||||
options=variable.get('options'),
|
||||
default=variable.get('default'),
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description"),
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options"),
|
||||
default=variable.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -99,17 +97,17 @@ class BasicVariablesConfigManager:
|
|||
variables = []
|
||||
for item in config["user_input_form"]:
|
||||
key = list(item.keys())[0]
|
||||
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
|
||||
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
|
||||
|
||||
form_item = item[key]
|
||||
if 'label' not in form_item:
|
||||
if "label" not in form_item:
|
||||
raise ValueError("label is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["label"], str):
|
||||
raise ValueError("label in user_input_form must be of string type")
|
||||
|
||||
if 'variable' not in form_item:
|
||||
if "variable" not in form_item:
|
||||
raise ValueError("variable is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["variable"], str):
|
||||
|
@ -117,26 +115,24 @@ class BasicVariablesConfigManager:
|
|||
|
||||
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
|
||||
if pattern.match(form_item["variable"]) is None:
|
||||
raise ValueError("variable in user_input_form must be a string, "
|
||||
"and cannot start with a number")
|
||||
raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
|
||||
|
||||
variables.append(form_item["variable"])
|
||||
|
||||
if 'required' not in form_item or not form_item["required"]:
|
||||
if "required" not in form_item or not form_item["required"]:
|
||||
form_item["required"] = False
|
||||
|
||||
if not isinstance(form_item["required"], bool):
|
||||
raise ValueError("required in user_input_form must be of boolean type")
|
||||
|
||||
if key == "select":
|
||||
if 'options' not in form_item or not form_item["options"]:
|
||||
if "options" not in form_item or not form_item["options"]:
|
||||
form_item["options"] = []
|
||||
|
||||
if not isinstance(form_item["options"], list):
|
||||
raise ValueError("options in user_input_form must be a list of strings")
|
||||
|
||||
if "default" in form_item and form_item['default'] \
|
||||
and form_item["default"] not in form_item["options"]:
|
||||
if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
|
||||
raise ValueError("default value in user_input_form must be in the options list")
|
||||
|
||||
return config, ["user_input_form"]
|
||||
|
@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
|
|||
typ = tool["type"]
|
||||
config = tool["config"]
|
||||
|
||||
ExternalDataToolFactory.validate_config(
|
||||
name=typ,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
|
||||
|
||||
return config, ["external_data_tools"]
|
||||
|
|
|
@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
|
|||
"""
|
||||
Model Config Entity.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
mode: Optional[str] = None
|
||||
|
@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
|
|||
"""
|
||||
Advanced Chat Message Entity.
|
||||
"""
|
||||
|
||||
text: str
|
||||
role: PromptMessageRole
|
||||
|
||||
|
@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
|
|||
"""
|
||||
Advanced Chat Prompt Template Entity.
|
||||
"""
|
||||
|
||||
messages: list[AdvancedChatMessageEntity]
|
||||
|
||||
|
||||
|
@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
|
|||
"""
|
||||
Role Prefix Entity.
|
||||
"""
|
||||
|
||||
user: str
|
||||
assistant: str
|
||||
|
||||
|
@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
|
|||
Prompt Type.
|
||||
'simple', 'advanced'
|
||||
"""
|
||||
SIMPLE = 'simple'
|
||||
ADVANCED = 'advanced'
|
||||
|
||||
SIMPLE = "simple"
|
||||
ADVANCED = "advanced"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'PromptType':
|
||||
def value_of(cls, value: str) -> "PromptType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
|
|||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid prompt type value {value}')
|
||||
raise ValueError(f"invalid prompt type value {value}")
|
||||
|
||||
prompt_type: PromptType
|
||||
simple_prompt_template: Optional[str] = None
|
||||
|
@ -87,7 +92,7 @@ class VariableEntityType(str, Enum):
|
|||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external-data-tool"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
|
@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
|
|||
"""
|
||||
External Data Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
|
@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
|||
Dataset Retrieve Strategy.
|
||||
'single' or 'multiple'
|
||||
"""
|
||||
SINGLE = 'single'
|
||||
MULTIPLE = 'multiple'
|
||||
|
||||
SINGLE = "single"
|
||||
MULTIPLE = "multiple"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'RetrieveStrategy':
|
||||
def value_of(cls, value: str) -> "RetrieveStrategy":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
|||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid retrieve strategy value {value}')
|
||||
raise ValueError(f"invalid retrieve strategy value {value}")
|
||||
|
||||
query_variable: Optional[str] = None # Only when app mode is completion
|
||||
|
||||
retrieve_strategy: RetrieveStrategy
|
||||
top_k: Optional[int] = None
|
||||
score_threshold: Optional[float] = .0
|
||||
rerank_mode: Optional[str] = 'reranking_model'
|
||||
score_threshold: Optional[float] = 0.0
|
||||
rerank_mode: Optional[str] = "reranking_model"
|
||||
reranking_model: Optional[dict] = None
|
||||
weights: Optional[dict] = None
|
||||
reranking_enabled: Optional[bool] = True
|
||||
|
||||
|
||||
|
||||
|
||||
class DatasetEntity(BaseModel):
|
||||
"""
|
||||
Dataset Config Entity.
|
||||
"""
|
||||
|
||||
dataset_ids: list[str]
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
|
||||
|
@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
|||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
|
@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
|
|||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
voice: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
|
@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
|
|||
"""
|
||||
Tracing Config Entity.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
tracing_provider: str
|
||||
|
||||
|
||||
|
||||
|
||||
class AppAdditionalFeatures(BaseModel):
|
||||
file_upload: Optional[FileExtraConfig] = None
|
||||
opening_statement: Optional[str] = None
|
||||
|
@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
|
|||
text_to_speech: Optional[TextToSpeechEntity] = None
|
||||
trace_config: Optional[TracingConfigEntity] = None
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""
|
||||
Application Config Entity.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
|
@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
|
|||
"""
|
||||
App Model Config From.
|
||||
"""
|
||||
ARGS = 'args'
|
||||
APP_LATEST_CONFIG = 'app-latest-config'
|
||||
CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
|
||||
|
||||
ARGS = "args"
|
||||
APP_LATEST_CONFIG = "app-latest-config"
|
||||
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
|
||||
|
||||
|
||||
class EasyUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Easy UI Based App Config Entity.
|
||||
"""
|
||||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
app_model_config_dict: dict
|
||||
|
@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
|
|||
"""
|
||||
Workflow UI Based App Config Entity.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
|
|
|
@ -13,21 +13,19 @@ class FileUploadConfigManager:
|
|||
:param config: model config args
|
||||
:param is_vision: if True, the feature is vision feature
|
||||
"""
|
||||
file_upload_dict = config.get('file_upload')
|
||||
file_upload_dict = config.get("file_upload")
|
||||
if file_upload_dict:
|
||||
if file_upload_dict.get('image'):
|
||||
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
|
||||
if file_upload_dict.get("image"):
|
||||
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
|
||||
image_config = {
|
||||
'number_limits': file_upload_dict['image']['number_limits'],
|
||||
'transfer_methods': file_upload_dict['image']['transfer_methods']
|
||||
"number_limits": file_upload_dict["image"]["number_limits"],
|
||||
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
|
||||
}
|
||||
|
||||
if is_vision:
|
||||
image_config['detail'] = file_upload_dict['image']['detail']
|
||||
image_config["detail"] = file_upload_dict["image"]["detail"]
|
||||
|
||||
return FileExtraConfig(
|
||||
image_config=image_config
|
||||
)
|
||||
return FileExtraConfig(image_config=image_config)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -49,21 +47,21 @@ class FileUploadConfigManager:
|
|||
if not config["file_upload"].get("image"):
|
||||
config["file_upload"]["image"] = {"enabled": False}
|
||||
|
||||
if config['file_upload']['image']['enabled']:
|
||||
number_limits = config['file_upload']['image']['number_limits']
|
||||
if config["file_upload"]["image"]["enabled"]:
|
||||
number_limits = config["file_upload"]["image"]["number_limits"]
|
||||
if number_limits < 1 or number_limits > 6:
|
||||
raise ValueError("number_limits must be in [1, 6]")
|
||||
|
||||
if is_vision:
|
||||
detail = config['file_upload']['image']['detail']
|
||||
if detail not in ['high', 'low']:
|
||||
detail = config["file_upload"]["image"]["detail"]
|
||||
if detail not in {"high", "low"}:
|
||||
raise ValueError("detail must be in ['high', 'low']")
|
||||
|
||||
transfer_methods = config['file_upload']['image']['transfer_methods']
|
||||
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
|
||||
if not isinstance(transfer_methods, list):
|
||||
raise ValueError("transfer_methods must be of list type")
|
||||
for method in transfer_methods:
|
||||
if method not in ['remote_url', 'local_file']:
|
||||
if method not in {"remote_url", "local_file"}:
|
||||
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
||||
|
||||
return config, ["file_upload"]
|
||||
|
|
|
@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
|
|||
:param config: model config args
|
||||
"""
|
||||
more_like_this = False
|
||||
more_like_this_dict = config.get('more_like_this')
|
||||
more_like_this_dict = config.get("more_like_this")
|
||||
if more_like_this_dict:
|
||||
if more_like_this_dict.get('enabled'):
|
||||
if more_like_this_dict.get("enabled"):
|
||||
more_like_this = True
|
||||
|
||||
return more_like_this
|
||||
|
@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("more_like_this"):
|
||||
config["more_like_this"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["more_like_this"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["more_like_this"], dict):
|
||||
raise ValueError("more_like_this must be of dict type")
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
|
||||
|
||||
class OpeningStatementConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[str, list]:
|
||||
|
@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
|
|||
:param config: model config args
|
||||
"""
|
||||
# opening statement
|
||||
opening_statement = config.get('opening_statement')
|
||||
opening_statement = config.get("opening_statement")
|
||||
|
||||
# suggested questions
|
||||
suggested_questions_list = config.get('suggested_questions')
|
||||
suggested_questions_list = config.get("suggested_questions")
|
||||
|
||||
return opening_statement, suggested_questions_list
|
||||
|
||||
|
|
|
@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
|
|||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = config.get('retriever_resource')
|
||||
retriever_resource_dict = config.get("retriever_resource")
|
||||
if retriever_resource_dict:
|
||||
if retriever_resource_dict.get('enabled'):
|
||||
if retriever_resource_dict.get("enabled"):
|
||||
show_retrieve_source = True
|
||||
|
||||
return show_retrieve_source
|
||||
|
@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("retriever_resource"):
|
||||
config["retriever_resource"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["retriever_resource"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["retriever_resource"], dict):
|
||||
raise ValueError("retriever_resource must be of dict type")
|
||||
|
|
|
@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
|
|||
:param config: model config args
|
||||
"""
|
||||
speech_to_text = False
|
||||
speech_to_text_dict = config.get('speech_to_text')
|
||||
speech_to_text_dict = config.get("speech_to_text")
|
||||
if speech_to_text_dict:
|
||||
if speech_to_text_dict.get('enabled'):
|
||||
if speech_to_text_dict.get("enabled"):
|
||||
speech_to_text = True
|
||||
|
||||
return speech_to_text
|
||||
|
@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("speech_to_text"):
|
||||
config["speech_to_text"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["speech_to_text"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["speech_to_text"], dict):
|
||||
raise ValueError("speech_to_text must be of dict type")
|
||||
|
|
|
@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
|||
:param config: model config args
|
||||
"""
|
||||
suggested_questions_after_answer = False
|
||||
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
|
||||
suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
|
||||
if suggested_questions_after_answer_dict:
|
||||
if suggested_questions_after_answer_dict.get('enabled'):
|
||||
if suggested_questions_after_answer_dict.get("enabled"):
|
||||
suggested_questions_after_answer = True
|
||||
|
||||
return suggested_questions_after_answer
|
||||
|
@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("suggested_questions_after_answer"):
|
||||
config["suggested_questions_after_answer"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["suggested_questions_after_answer"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"], dict):
|
||||
raise ValueError("suggested_questions_after_answer must be of dict type")
|
||||
|
||||
if "enabled" not in config["suggested_questions_after_answer"] or not \
|
||||
config["suggested_questions_after_answer"]["enabled"]:
|
||||
if (
|
||||
"enabled" not in config["suggested_questions_after_answer"]
|
||||
or not config["suggested_questions_after_answer"]["enabled"]
|
||||
):
|
||||
config["suggested_questions_after_answer"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
|
||||
|
|
|
@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
|
|||
:param config: model config args
|
||||
"""
|
||||
text_to_speech = None
|
||||
text_to_speech_dict = config.get('text_to_speech')
|
||||
text_to_speech_dict = config.get("text_to_speech")
|
||||
if text_to_speech_dict:
|
||||
if text_to_speech_dict.get('enabled'):
|
||||
if text_to_speech_dict.get("enabled"):
|
||||
text_to_speech = TextToSpeechEntity(
|
||||
enabled=text_to_speech_dict.get('enabled'),
|
||||
voice=text_to_speech_dict.get('voice'),
|
||||
language=text_to_speech_dict.get('language'),
|
||||
enabled=text_to_speech_dict.get("enabled"),
|
||||
voice=text_to_speech_dict.get("voice"),
|
||||
language=text_to_speech_dict.get("language"),
|
||||
)
|
||||
|
||||
return text_to_speech
|
||||
|
@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("text_to_speech"):
|
||||
config["text_to_speech"] = {
|
||||
"enabled": False,
|
||||
"voice": "",
|
||||
"language": ""
|
||||
}
|
||||
config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
|
||||
|
||||
if not isinstance(config["text_to_speech"], dict):
|
||||
raise ValueError("text_to_speech must be of dict type")
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
|
@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
|
|||
"""
|
||||
Advanced Chatbot App Config Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App,
|
||||
workflow: Workflow) -> AdvancedChatAppConfig:
|
||||
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
|
||||
features_dict = workflow.features_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||
app_id=app_model.id,
|
||||
app_mode=app_mode,
|
||||
workflow_id=workflow.id,
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
|
||||
config=features_dict
|
||||
),
|
||||
variables=WorkflowVariablesConfigManager.convert(
|
||||
workflow=workflow
|
||||
),
|
||||
additional_features=cls.convert_features(features_dict, app_mode)
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
|
||||
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
|
||||
additional_features=cls.convert_features(features_dict, app_mode),
|
||||
)
|
||||
|
||||
return app_config
|
||||
|
@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
||||
config=config,
|
||||
is_vision=False
|
||||
config=config, is_vision=False
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
|
@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# suggested_questions_after_answer
|
||||
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
|
||||
config)
|
||||
config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# speech_to_text
|
||||
|
@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
only_structure_validate=only_structure_validate
|
||||
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
|
@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
|||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
|
@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
|
|||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
|
@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
|
@ -54,7 +56,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
|
@ -63,14 +66,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
) -> Union[dict[str, Any], Generator[dict | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str | dict, None, None]:
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str | dict, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -81,44 +84,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not args.get('query'):
|
||||
raise ValueError('query is required')
|
||||
if not args.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
if not isinstance(query, str):
|
||||
raise ValueError('query must be a string')
|
||||
raise ValueError("query must be a string")
|
||||
|
||||
query = query.replace('\x00', '')
|
||||
inputs = args['inputs']
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get('auto_generate_name', False)
|
||||
}
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
conversation_id = args.get('conversation_id')
|
||||
conversation_id = args.get("conversation_id")
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
conversation = self._get_conversation_by_user(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
|
@ -140,7 +136,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
|
@ -150,16 +146,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: dict,
|
||||
stream: bool = True) \
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
def single_iteration_generate(
|
||||
self, app_model: App, workflow: Workflow, node_id: str, user: Account | EndUser, args: dict, stream: bool = True
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -171,16 +163,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
:param stream: is stream
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
|
@ -188,18 +177,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_config=app_config,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
},
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
)
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
|
@ -209,17 +195,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=None,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def _generate(self, *,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True) \
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -235,10 +223,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
is_first_conversation = True
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity, conversation)
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
|
@ -253,18 +238,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"context": contextvars.copy_context(),
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
@ -278,18 +266,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
user=user,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||
response=response,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
context: contextvars.Context) -> None:
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
context: contextvars.Context,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
@ -312,22 +300,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG", "false").lower() == 'true':
|
||||
if os.environ.get("DEBUG", "false").lower() == "true":
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
|
@ -373,7 +360,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
|
|
@ -21,14 +21,11 @@ class AudioTrunk:
|
|||
self.status = status
|
||||
|
||||
|
||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(),
|
||||
user="responding_tts",
|
||||
tenant_id=tenant_id,
|
||||
voice=voice
|
||||
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
|
||||
)
|
||||
|
||||
|
||||
|
@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
|
|||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(e)
|
||||
break
|
||||
audio_queue.put(AudioTrunk("finish", b''))
|
||||
audio_queue.put(AudioTrunk("finish", b""))
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
|
||||
def __init__(self, tenant_id: str, voice: str):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ''
|
||||
self.msg_text = ""
|
||||
self._audio_queue = queue.Queue()
|
||||
self._msg_queue = queue.Queue()
|
||||
self.match = re.compile(r'[。.!?]')
|
||||
self.match = re.compile(r"[。.!?]")
|
||||
self.model_manager = ModelManager()
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
tenant_id=self.tenant_id, model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices()
|
||||
values = [voice.get('value') for voice in self.voices]
|
||||
values = [voice.get("value") for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get('value')
|
||||
self.voice = self.voices[0].get("value")
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event = None
|
||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||
|
@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
|
|||
message = self._msg_queue.get()
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
|
||||
self.model_instance, self.tenant_id, self.voice)
|
||||
futures_result = self.executor.submit(
|
||||
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
|
@ -94,28 +90,27 @@ class AppGeneratorTTSPublisher:
|
|||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
self.msg_text += message.event.outputs.get('output', '')
|
||||
self.msg_text += message.event.outputs.get("output", "")
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = ''.join(sentence_arr)
|
||||
futures_result = self.executor.submit(_invoiceTTS, text_content,
|
||||
self.model_instance,
|
||||
self.tenant_id,
|
||||
self.voice)
|
||||
text_content = "".join(sentence_arr)
|
||||
futures_result = self.executor.submit(
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ''
|
||||
self.msg_text = ""
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||
def check_and_get_audio(self) -> AudioTrunk | None:
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
|
|
|
@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -66,14 +66,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
|
@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
|
@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
app_record=app_record,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id
|
||||
app_record=app_record,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id,
|
||||
):
|
||||
return
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=self.message,
|
||||
query=query,
|
||||
app_generate_entity=self.application_generate_entity
|
||||
app_record=app_record,
|
||||
message=self.message,
|
||||
query=query,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
):
|
||||
return
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
|
||||
ConversationVariable.app_id == self.conversation.app_id,
|
||||
ConversationVariable.conversation_id == self.conversation.id,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
|
@ -174,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
|
@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str
|
||||
self,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
|
@ -216,19 +217,15 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
self._complete_with_stream_output(
|
||||
text=str(e),
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
|
||||
)
|
||||
except ModerationError as e:
|
||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def handle_annotation_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
|
||||
def handle_annotation_reply(
|
||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||
) -> bool:
|
||||
"""
|
||||
Handle annotation reply
|
||||
:param app_record: app record
|
||||
|
@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
)
|
||||
|
||||
if annotation_reply:
|
||||
self._publish_event(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
|
||||
)
|
||||
self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
|
||||
|
||||
self._complete_with_stream_output(
|
||||
text=annotation_reply.content,
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
|
||||
text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _complete_with_stream_output(self,
|
||||
text: str,
|
||||
stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
)
|
||||
)
|
||||
self._publish_event(QueueTextChunkEvent(text=text))
|
||||
|
||||
self._publish_event(
|
||||
QueueStopEvent(stopped_by=stopped_by)
|
||||
)
|
||||
self._publish_event(QueueStopEvent(stopped_by=stopped_by))
|
||||
|
|
|
@ -27,15 +27,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
"""
|
||||
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
|
||||
response = {
|
||||
'event': 'message',
|
||||
'task_id': blocking_response.task_id,
|
||||
'id': blocking_response.data.id,
|
||||
'message_id': blocking_response.data.message_id,
|
||||
'conversation_id': blocking_response.data.conversation_id,
|
||||
'mode': blocking_response.data.mode,
|
||||
'answer': blocking_response.data.answer,
|
||||
'metadata': blocking_response.data.metadata,
|
||||
'created_at': blocking_response.data.created_at
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
"conversation_id": blocking_response.data.conversation_id,
|
||||
"mode": blocking_response.data.mode,
|
||||
"answer": blocking_response.data.answer,
|
||||
"metadata": blocking_response.data.metadata,
|
||||
"created_at": blocking_response.data.created_at,
|
||||
}
|
||||
|
||||
return response
|
||||
|
@ -49,13 +49,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
"""
|
||||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get('metadata', {})
|
||||
response['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[dict | str, Any, None]:
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
|
@ -66,14 +68,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'conversation_id': chunk.conversation_id,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
|
@ -84,7 +86,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[dict | str, Any, None]:
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
@ -95,20 +99,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'conversation_id': chunk.conversation_id,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
metadata = sub_stream_response_dict.get('metadata', {})
|
||||
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
|
|
|
@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||
_workflow: Workflow
|
||||
|
@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||
|
@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation,
|
||||
self._application_generate_entity.query
|
||||
self._conversation, self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
|
@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {}
|
||||
if stream_response.metadata:
|
||||
extras['metadata'] = stream_response.metadata
|
||||
extras["metadata"] = stream_response.metadata
|
||||
|
||||
return ChatbotAppBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
|
@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
message_id=self._message.id,
|
||||
answer=self._task_state.answer,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
**extras
|
||||
)
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
raise Exception('Queue listening stopped unexpectedly.')
|
||||
raise Exception("Queue listening stopped unexpectedly.")
|
||||
|
||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[ChatbotAppStreamResponse, Any, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
|
@ -176,32 +176,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
stream_response=stream_response
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
def _wrapper_process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
tts_publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
if (
|
||||
features_dict.get("text_to_speech")
|
||||
and features_dict["text_to_speech"].get("enabled")
|
||||
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
|
||||
):
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
|
@ -214,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
try:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.check_and_get_audio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
self,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
|
@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
db.session.close()
|
||||
|
||||
yield self._workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
|
||||
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
|
@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
|
@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
|
@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
|
@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
||||
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent):
|
||||
|
@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
|
||||
# Save message
|
||||
|
@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
self._refetch_message()
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
self._message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(self._message)
|
||||
|
@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
self._refetch_message()
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
self._message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(self._message)
|
||||
|
@ -466,13 +451,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
tts_publisher.publish(message=queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_to_stream_response(delta_text, self._message.id)
|
||||
yield self._message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
|
@ -502,8 +489,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
self._message.answer = self._task_state.answer
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
self._message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
|
||||
if graph_runtime_state and graph_runtime_state.llm_usage:
|
||||
usage = graph_runtime_state.llm_usage
|
||||
|
@ -523,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
application_generate_entity=self._application_generate_entity,
|
||||
conversation=self._conversation,
|
||||
is_first_message=self._application_generate_entity.conversation_id is None,
|
||||
extras=self._application_generate_entity.extras
|
||||
extras=self._application_generate_entity.extras,
|
||||
)
|
||||
|
||||
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||
|
@ -533,15 +521,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
"""
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata.copy()
|
||||
extras["metadata"] = self._task_state.metadata.copy()
|
||||
|
||||
if 'annotation_reply' in extras['metadata']:
|
||||
del extras['metadata']['annotation_reply']
|
||||
if "annotation_reply" in extras["metadata"]:
|
||||
del extras["metadata"]["annotation_reply"]
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message.id,
|
||||
**extras
|
||||
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
|
||||
)
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
|
@ -555,14 +541,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=self._task_state.answer
|
||||
), PublishFrom.TASK_PIPELINE
|
||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
return True
|
||||
else:
|
||||
|
|
|
@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
|
|||
"""
|
||||
Agent Chatbot App Config Entity.
|
||||
"""
|
||||
|
||||
agent: Optional[AgentEntity] = None
|
||||
|
||||
|
||||
class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Optional[Conversation] = None,
|
||||
override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
|
||||
def get_app_config(
|
||||
cls,
|
||||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Optional[Conversation] = None,
|
||||
override_config_dict: Optional[dict] = None,
|
||||
) -> AgentChatAppConfig:
|
||||
"""
|
||||
Convert app model config to agent chat app config
|
||||
:param app_model: app model
|
||||
|
@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
prompt_template=PromptTemplateConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
dataset=DatasetConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
agent=AgentConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
additional_features=cls.convert_features(config_dict, app_mode)
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
dataset=DatasetConfigManager.convert(config=config_dict),
|
||||
agent=AgentConfigManager.convert(config=config_dict),
|
||||
additional_features=cls.convert_features(config_dict, app_mode),
|
||||
)
|
||||
|
||||
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
|
||||
|
@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# suggested_questions_after_answer
|
||||
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
|
||||
config)
|
||||
config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# speech_to_text
|
||||
|
@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# dataset configs
|
||||
# dataset_query_variable
|
||||
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
|
||||
config)
|
||||
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
|
||||
tenant_id, app_mode, config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
|
||||
config)
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id, config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
related_config_keys = list(set(related_config_keys))
|
||||
|
@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {
|
||||
"enabled": False,
|
||||
"tools": []
|
||||
}
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
if not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
if config["agent_mode"]["strategy"] not in [member.value for member in
|
||||
list(PlanningStrategy.__members__.values())]:
|
||||
if config["agent_mode"]["strategy"] not in [
|
||||
member.value for member in list(PlanningStrategy.__members__.values())
|
||||
]:
|
||||
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
||||
|
||||
if not config["agent_mode"].get("tools"):
|
||||
|
@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
||||
|
||||
if key == "dataset":
|
||||
if 'id' not in tool_item:
|
||||
if "id" not in tool_item:
|
||||
raise ValueError("id is required in dataset")
|
||||
|
||||
try:
|
||||
|
|
|
@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
|
|||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
|
@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
|
|||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
|
@ -39,7 +40,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
|
@ -48,19 +50,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = False,
|
||||
) -> dict | Generator[dict | str, None, None]: ...
|
||||
|
||||
def generate(self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict | str, None, None]]:
|
||||
def generate(
|
||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
|
||||
) -> Union[dict, Generator[dict | str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -71,60 +71,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
:param stream: is stream
|
||||
"""
|
||||
if not stream:
|
||||
raise ValueError('Agent Chat App does not support blocking mode')
|
||||
raise ValueError("Agent Chat App does not support blocking mode")
|
||||
|
||||
if not args.get('query'):
|
||||
raise ValueError('query is required')
|
||||
if not args.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
if not isinstance(query, str):
|
||||
raise ValueError('query must be a string')
|
||||
raise ValueError("query must be a string")
|
||||
|
||||
query = query.replace('\x00', '')
|
||||
inputs = args['inputs']
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get('auto_generate_name', True)
|
||||
}
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
if args.get("conversation_id"):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
|
||||
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(
|
||||
app_model=app_model,
|
||||
conversation=conversation
|
||||
)
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
|
||||
# validate override model config
|
||||
override_model_config_dict = None
|
||||
if args.get('model_config'):
|
||||
if args.get("model_config"):
|
||||
if invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ValueError('Only in App debug mode can override model config')
|
||||
raise ValueError("Only in App debug mode can override model config")
|
||||
|
||||
# validate config
|
||||
override_model_config_dict = AgentChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=args.get('model_config')
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
)
|
||||
|
||||
# always enable retriever resource in debugger mode
|
||||
override_model_config_dict["retriever_resource"] = {
|
||||
"enabled": True
|
||||
}
|
||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
|
@ -133,7 +121,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=override_model_config_dict
|
||||
override_config_dict=override_model_config_dict,
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
|
@ -154,14 +142,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
call_depth=0,
|
||||
trace_manager=trace_manager
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity, conversation)
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
|
@ -170,17 +155,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
})
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
@ -194,13 +182,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
stream=stream,
|
||||
)
|
||||
|
||||
return AgentChatAppGenerateResponseConverter.convert(
|
||||
response=response,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(
|
||||
self, flask_app: Flask,
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
|
@ -229,18 +215,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
conversation=conversation,
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
|
|
|
@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
|
|||
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message, MessageAgentThought
|
||||
|
@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
|
|||
"""
|
||||
|
||||
def run(
|
||||
self, application_generate_entity: AgentChatAppGenerateEntity,
|
||||
self,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
|
@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
|
|||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
memory = None
|
||||
|
@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
|
|||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
# organize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
|
@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
|
|||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# moderation
|
||||
|
@ -103,15 +101,15 @@ class AgentChatAppRunner(AppRunner):
|
|||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text=str(e),
|
||||
stream=application_generate_entity.stream
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
|
|||
message=message,
|
||||
query=query,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
self.direct_output(
|
||||
|
@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
|
|||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text=annotation_reply.content,
|
||||
stream=application_generate_entity.stream
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
|
|||
app_id=app_record.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
|
@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
|
|||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
hosting_moderation_result = self.check_hosting_moderation(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
prompt_messages=prompt_messages
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
if hosting_moderation_result:
|
||||
|
@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
|
|||
agent_entity = app_config.agent
|
||||
|
||||
# load tool variables
|
||||
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id)
|
||||
tool_conversation_variables = self._load_tool_variables(
|
||||
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
|
||||
)
|
||||
|
||||
# convert db variables to tool variables
|
||||
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
|
||||
|
@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
|
|||
# init model instance
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
prompt_message, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
|
@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
|
|||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables,
|
||||
model_instance=model_instance
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
invoke_result = runner.run(
|
||||
|
@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
|
|||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
agent=True
|
||||
agent=True,
|
||||
)
|
||||
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tenant_id
|
||||
).first()
|
||||
tool_variables: ToolConversationVariables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if tool_variables:
|
||||
# save tool variables to session, so that we can update it later
|
||||
|
@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
|
|||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
variables_str='[]',
|
||||
variables_str="[]",
|
||||
)
|
||||
db.session.add(tool_variables)
|
||||
db.session.commit()
|
||||
|
||||
return tool_variables
|
||||
|
||||
def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
|
||||
|
||||
def _convert_db_variables_to_tool_variables(
|
||||
self, db_variables: ToolConversationVariables
|
||||
) -> ToolRuntimeVariablePool:
|
||||
"""
|
||||
convert db variables to tool variables
|
||||
"""
|
||||
return ToolRuntimeVariablePool(**{
|
||||
'conversation_id': db_variables.conversation_id,
|
||||
'user_id': db_variables.user_id,
|
||||
'tenant_id': db_variables.tenant_id,
|
||||
'pool': db_variables.variables
|
||||
})
|
||||
return ToolRuntimeVariablePool(
|
||||
**{
|
||||
"conversation_id": db_variables.conversation_id,
|
||||
"user_id": db_variables.user_id,
|
||||
"tenant_id": db_variables.tenant_id,
|
||||
"pool": db_variables.variables,
|
||||
}
|
||||
)
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
|
||||
message: Message) -> LLMUsage:
|
||||
def _get_usage_of_all_agent_thoughts(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, message: Message
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
Get usage of all agent thoughts
|
||||
:param model_config: model config
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
agent_thoughts = (db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.message_id == message.id).all())
|
||||
agent_thoughts = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
|
||||
)
|
||||
|
||||
all_message_tokens = 0
|
||||
all_answer_tokens = 0
|
||||
|
@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
|
|||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
return model_type_instance._calc_response_usage(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
all_message_tokens,
|
||||
all_answer_tokens
|
||||
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
|
||||
)
|
||||
|
|
|
@ -22,15 +22,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
:return:
|
||||
"""
|
||||
response = {
|
||||
'event': 'message',
|
||||
'task_id': blocking_response.task_id,
|
||||
'id': blocking_response.data.id,
|
||||
'message_id': blocking_response.data.message_id,
|
||||
'conversation_id': blocking_response.data.conversation_id,
|
||||
'mode': blocking_response.data.mode,
|
||||
'answer': blocking_response.data.answer,
|
||||
'metadata': blocking_response.data.metadata,
|
||||
'created_at': blocking_response.data.created_at
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
"conversation_id": blocking_response.data.conversation_id,
|
||||
"mode": blocking_response.data.mode,
|
||||
"answer": blocking_response.data.answer,
|
||||
"metadata": blocking_response.data.metadata,
|
||||
"created_at": blocking_response.data.created_at,
|
||||
}
|
||||
|
||||
return response
|
||||
|
@ -44,8 +44,8 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
"""
|
||||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get('metadata', {})
|
||||
response['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
|
||||
return response
|
||||
|
||||
|
@ -62,14 +62,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'conversation_id': chunk.conversation_id,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
|
@ -92,20 +92,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'conversation_id': chunk.conversation_id,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
metadata = sub_stream_response_dict.get('metadata', {})
|
||||
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
|
|
|
@ -13,11 +13,10 @@ class AppGenerateResponseConverter(ABC):
|
|||
_blocking_response_type: type[AppBlockingResponse]
|
||||
|
||||
@classmethod
|
||||
def convert(cls, response: Union[
|
||||
AppBlockingResponse,
|
||||
Generator[AppStreamResponse, Any, None]
|
||||
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||
def convert(
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
|
@ -52,8 +51,9 @@ class AppGenerateResponseConverter(ABC):
|
|||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
|
||||
-> Generator[str, None, None]:
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -64,24 +64,26 @@ class AppGenerateResponseConverter(ABC):
|
|||
:return:
|
||||
"""
|
||||
# show_retrieve_source
|
||||
if 'retriever_resources' in metadata:
|
||||
metadata['retriever_resources'] = []
|
||||
for resource in metadata['retriever_resources']:
|
||||
metadata['retriever_resources'].append({
|
||||
'segment_id': resource['segment_id'],
|
||||
'position': resource['position'],
|
||||
'document_name': resource['document_name'],
|
||||
'score': resource['score'],
|
||||
'content': resource['content'],
|
||||
})
|
||||
if "retriever_resources" in metadata:
|
||||
metadata["retriever_resources"] = []
|
||||
for resource in metadata["retriever_resources"]:
|
||||
metadata["retriever_resources"].append(
|
||||
{
|
||||
"segment_id": resource["segment_id"],
|
||||
"position": resource["position"],
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
"content": resource["content"],
|
||||
}
|
||||
)
|
||||
|
||||
# show annotation reply
|
||||
if 'annotation_reply' in metadata:
|
||||
del metadata['annotation_reply']
|
||||
if "annotation_reply" in metadata:
|
||||
del metadata["annotation_reply"]
|
||||
|
||||
# show usage
|
||||
if 'usage' in metadata:
|
||||
del metadata['usage']
|
||||
if "usage" in metadata:
|
||||
del metadata["usage"]
|
||||
|
||||
return metadata
|
||||
|
||||
|
@ -93,16 +95,16 @@ class AppGenerateResponseConverter(ABC):
|
|||
:return:
|
||||
"""
|
||||
error_responses = {
|
||||
ValueError: {'code': 'invalid_param', 'status': 400},
|
||||
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
|
||||
ValueError: {"code": "invalid_param", "status": 400},
|
||||
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
||||
QuotaExceededError: {
|
||||
'code': 'provider_quota_exceeded',
|
||||
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.",
|
||||
'status': 400
|
||||
"code": "provider_quota_exceeded",
|
||||
"message": "Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.",
|
||||
"status": 400,
|
||||
},
|
||||
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
|
||||
InvokeError: {'code': 'completion_request_error', 'status': 400}
|
||||
ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
|
||||
InvokeError: {"code": "completion_request_error", "status": 400},
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
|
@ -112,13 +114,13 @@ class AppGenerateResponseConverter(ABC):
|
|||
data = v
|
||||
|
||||
if data:
|
||||
data.setdefault('message', getattr(e, 'description', str(e)))
|
||||
data.setdefault("message", getattr(e, "description", str(e)))
|
||||
else:
|
||||
logging.error(e)
|
||||
data = {
|
||||
'code': 'internal_server_error',
|
||||
'message': 'Internal Server Error, please contact support.',
|
||||
'status': 500
|
||||
"code": "internal_server_error",
|
||||
"message": "Internal Server Error, please contact support.",
|
||||
"status": 500,
|
||||
}
|
||||
|
||||
return data
|
||||
|
|
|
@ -17,17 +17,17 @@ class BaseAppGenerator:
|
|||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||
user_input_value = inputs.get(var.variable)
|
||||
if var.required and not user_input_value:
|
||||
raise ValueError(f'{var.variable} is required in input form')
|
||||
raise ValueError(f"{var.variable} is required in input form")
|
||||
if not var.required and not user_input_value:
|
||||
# TODO: should we return None here if the default value is None?
|
||||
return var.default or ''
|
||||
return var.default or ""
|
||||
if (
|
||||
var.type
|
||||
in (
|
||||
in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
)
|
||||
}
|
||||
and user_input_value
|
||||
and not isinstance(user_input_value, str)
|
||||
):
|
||||
|
@ -35,7 +35,7 @@ class BaseAppGenerator:
|
|||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if '.' in user_input_value:
|
||||
if "." in user_input_value:
|
||||
return float(user_input_value)
|
||||
else:
|
||||
return int(user_input_value)
|
||||
|
@ -44,20 +44,20 @@ class BaseAppGenerator:
|
|||
if var.type == VariableEntityType.SELECT:
|
||||
options = var.options or []
|
||||
if user_input_value not in options:
|
||||
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
|
||||
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
||||
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||
|
||||
return user_input_value
|
||||
|
||||
def _sanitize_value(self, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return value.replace('\x00', '')
|
||||
return value.replace("\x00", "")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def convert_to_event_stream(cls, generator: Union[dict, Generator[dict| str, None, None]]):
|
||||
def convert_to_event_stream(cls, generator: Union[dict, Generator[dict | str, None, None]]):
|
||||
"""
|
||||
Convert messages into event stream
|
||||
"""
|
||||
|
|
|
@ -24,9 +24,7 @@ class PublishFrom(Enum):
|
|||
|
||||
|
||||
class AppQueueManager:
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom) -> None:
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
|
||||
|
@ -34,9 +32,10 @@ class AppQueueManager:
|
|||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
|
||||
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800,
|
||||
f"{user_prefix}-{self._user_id}")
|
||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
redis_client.setex(
|
||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||
)
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
|
@ -66,8 +65,7 @@ class AppQueueManager:
|
|||
# publish two messages to make sure the client can receive the stop signal
|
||||
# and stop listening after the stop signal processed
|
||||
self.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
if elapsed_time // 10 > last_ping_time:
|
||||
|
@ -88,9 +86,7 @@ class AppQueueManager:
|
|||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(
|
||||
error=e
|
||||
), pub_from)
|
||||
self.publish(QueueErrorEvent(error=e), pub_from)
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
|
@ -122,8 +118,8 @@ class AppQueueManager:
|
|||
if result is None:
|
||||
return
|
||||
|
||||
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
|
||||
user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
|
@ -168,10 +164,12 @@ class AppQueueManager:
|
|||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
|
||||
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed.")
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
|
||||
raise TypeError(
|
||||
"Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed."
|
||||
)
|
||||
|
||||
|
||||
class GenerateTaskStoppedException(Exception):
|
||||
class GenerateTaskStoppedError(Exception):
|
||||
pass
|
||||
|
|
|
@ -31,12 +31,15 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class AppRunner:
|
||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None) -> int:
|
||||
def get_pre_calculate_rest_tokens(
|
||||
self,
|
||||
app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get pre calculate rest tokens
|
||||
:param app_record: app record
|
||||
|
@ -49,18 +52,20 @@ class AppRunner:
|
|||
"""
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle,
|
||||
model=model_config.model
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
return -1
|
||||
|
@ -75,36 +80,39 @@ class AppRunner:
|
|||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(
|
||||
prompt_messages
|
||||
)
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
|
||||
if rest_tokens < 0:
|
||||
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
|
||||
"or shrink the max token, or switch to a llm with a larger token limit size.")
|
||||
raise InvokeBadRequestError(
|
||||
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
|
||||
"or shrink the max token, or switch to a llm with a larger token limit size."
|
||||
)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_messages: list[PromptMessage]):
|
||||
def recalc_llm_max_tokens(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
|
||||
):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle,
|
||||
model=model_config.model
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
return -1
|
||||
|
@ -112,27 +120,28 @@ class AppRunner:
|
|||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(
|
||||
prompt_messages
|
||||
)
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
if prompt_tokens + max_tokens > model_context_tokens:
|
||||
max_tokens = max(model_context_tokens - prompt_tokens, 16)
|
||||
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
model_config.parameters[parameter_rule.name] = max_tokens
|
||||
|
||||
def organize_prompt_messages(self, app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
def organize_prompt_messages(
|
||||
self,
|
||||
app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
:param context:
|
||||
|
@ -152,60 +161,54 @@ class AppRunner:
|
|||
app_mode=AppMode.value_of(app_record.mode),
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
query=query or "",
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
|
||||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
|
||||
prompt_template = CompletionModelPromptTemplate(
|
||||
text=advanced_completion_prompt_template.prompt
|
||||
)
|
||||
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
|
||||
|
||||
if advanced_completion_prompt_template.role_prefix:
|
||||
memory_config.role_prefix = MemoryConfig.RolePrefix(
|
||||
user=advanced_completion_prompt_template.role_prefix.user,
|
||||
assistant=advanced_completion_prompt_template.role_prefix.assistant
|
||||
assistant=advanced_completion_prompt_template.role_prefix.assistant,
|
||||
)
|
||||
else:
|
||||
prompt_template = []
|
||||
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
|
||||
prompt_template.append(ChatModelMessage(
|
||||
text=message.text,
|
||||
role=message.role
|
||||
))
|
||||
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
query=query or "",
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
def direct_output(self, queue_manager: AppQueueManager,
|
||||
app_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
prompt_messages: list,
|
||||
text: str,
|
||||
stream: bool,
|
||||
usage: Optional[LLMUsage] = None) -> None:
|
||||
def direct_output(
|
||||
self,
|
||||
queue_manager: AppQueueManager,
|
||||
app_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
prompt_messages: list,
|
||||
text: str,
|
||||
stream: bool,
|
||||
usage: Optional[LLMUsage] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
|
@ -222,17 +225,10 @@ class AppRunner:
|
|||
chunk = LLMResultChunk(
|
||||
model=app_generate_entity.model_conf.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=token)
|
||||
)
|
||||
delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
|
||||
)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=chunk
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
|
@ -242,15 +238,19 @@ class AppRunner:
|
|||
model=app_generate_entity.model_conf.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=usage if usage else LLMUsage.empty_usage()
|
||||
usage=usage or LLMUsage.empty_usage(),
|
||||
),
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
agent: bool = False) -> None:
|
||||
def _handle_invoke_result(
|
||||
self,
|
||||
invoke_result: Union[LLMResult, Generator],
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
agent: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
|
@ -260,21 +260,13 @@ class AppRunner:
|
|||
:return:
|
||||
"""
|
||||
if not stream:
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent
|
||||
)
|
||||
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
else:
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent
|
||||
)
|
||||
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
|
||||
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
|
||||
queue_manager: AppQueueManager,
|
||||
agent: bool) -> None:
|
||||
def _handle_invoke_result_direct(
|
||||
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
|
||||
) -> None:
|
||||
"""
|
||||
Handle invoke result direct
|
||||
:param invoke_result: invoke result
|
||||
|
@ -285,12 +277,13 @@ class AppRunner:
|
|||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=invoke_result,
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_invoke_result_stream(self, invoke_result: Generator,
|
||||
queue_manager: AppQueueManager,
|
||||
agent: bool) -> None:
|
||||
def _handle_invoke_result_stream(
|
||||
self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
|
||||
) -> None:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
|
@ -300,21 +293,13 @@ class AppRunner:
|
|||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
text = ''
|
||||
text = ""
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=result
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
queue_manager.publish(
|
||||
QueueAgentMessageEvent(
|
||||
chunk=result
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
text += result.delta.message.content
|
||||
|
||||
|
@ -331,25 +316,24 @@ class AppRunner:
|
|||
usage = LLMUsage.empty_usage()
|
||||
|
||||
llm_result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=usage
|
||||
model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
|
||||
)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=llm_result,
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def moderation_for_inputs(
|
||||
self, app_id: str,
|
||||
tenant_id: str,
|
||||
app_generate_entity: AppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
app_generate_entity: AppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
|
@ -367,14 +351,17 @@ class AppRunner:
|
|||
tenant_id=tenant_id,
|
||||
app_config=app_generate_entity.app_config,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
query=query or "",
|
||||
message_id=message_id,
|
||||
trace_manager=app_generate_entity.trace_manager
|
||||
trace_manager=app_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
prompt_messages: list[PromptMessage]) -> bool:
|
||||
def check_hosting_moderation(
|
||||
self,
|
||||
application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> bool:
|
||||
"""
|
||||
Check hosting moderation
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -384,8 +371,7 @@ class AppRunner:
|
|||
"""
|
||||
hosting_moderation_feature = HostingModerationFeature()
|
||||
moderation_result = hosting_moderation_feature.check(
|
||||
application_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages
|
||||
application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
if moderation_result:
|
||||
|
@ -393,18 +379,20 @@ class AppRunner:
|
|||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text="I apologize for any confusion, " \
|
||||
"but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||
stream=application_generate_entity.stream
|
||||
text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
|
||||
return moderation_result
|
||||
|
||||
def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tools: list[ExternalDataVariableEntity],
|
||||
inputs: dict,
|
||||
query: str) -> dict:
|
||||
def fill_in_inputs_from_external_data_tools(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tools: list[ExternalDataVariableEntity],
|
||||
inputs: dict,
|
||||
query: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Fill in variable inputs from external data tools if exists.
|
||||
|
||||
|
@ -417,18 +405,12 @@ class AppRunner:
|
|||
"""
|
||||
external_data_fetch_feature = ExternalDataFetch()
|
||||
return external_data_fetch_feature.fetch(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
|
||||
)
|
||||
|
||||
def query_app_annotations_to_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
|
||||
def query_app_annotations_to_reply(
|
||||
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
|
||||
) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Query app annotations to reply
|
||||
:param app_record: app record
|
||||
|
@ -440,9 +422,5 @@ class AppRunner:
|
|||
"""
|
||||
annotation_reply_feature = AnnotationReplyFeature()
|
||||
return annotation_reply_feature.query(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
invoke_from=invoke_from
|
||||
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
|
||||
)
|
||||
|
|
|
@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
|
|||
"""
|
||||
Chatbot App Config Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChatAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Optional[Conversation] = None,
|
||||
override_config_dict: Optional[dict] = None) -> ChatAppConfig:
|
||||
def get_app_config(
|
||||
cls,
|
||||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Optional[Conversation] = None,
|
||||
override_config_dict: Optional[dict] = None,
|
||||
) -> ChatAppConfig:
|
||||
"""
|
||||
Convert app model config to chat app config
|
||||
:param app_model: app model
|
||||
|
@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
if not override_config_dict:
|
||||
raise Exception('override_config_dict is required when config_from is ARGS')
|
||||
raise Exception("override_config_dict is required when config_from is ARGS")
|
||||
|
||||
config_dict = override_config_dict
|
||||
|
||||
|
@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
prompt_template=PromptTemplateConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
dataset=DatasetConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
additional_features=cls.convert_features(config_dict, app_mode)
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
dataset=DatasetConfigManager.convert(config=config_dict),
|
||||
additional_features=cls.convert_features(config_dict, app_mode),
|
||||
)
|
||||
|
||||
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
|
||||
|
@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# dataset_query_variable
|
||||
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
|
||||
config)
|
||||
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
|
||||
tenant_id, app_mode, config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# opening_statement
|
||||
|
@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# suggested_questions_after_answer
|
||||
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
|
||||
config)
|
||||
config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# speech_to_text
|
||||
|
@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
|
||||
config)
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id, config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
related_config_keys = list(set(related_config_keys))
|
||||
|
|
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
|||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||
|
@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
|
|||
class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
|
@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
|
@ -56,7 +58,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
) -> Union[dict, Generator[dict | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
|
@ -71,58 +74,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not args.get('query'):
|
||||
raise ValueError('query is required')
|
||||
if not args.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
if not isinstance(query, str):
|
||||
raise ValueError('query must be a string')
|
||||
raise ValueError("query must be a string")
|
||||
|
||||
query = query.replace('\x00', '')
|
||||
inputs = args['inputs']
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get('auto_generate_name', True)
|
||||
}
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
if args.get("conversation_id"):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
|
||||
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(
|
||||
app_model=app_model,
|
||||
conversation=conversation
|
||||
)
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
|
||||
# validate override model config
|
||||
override_model_config_dict = None
|
||||
if args.get('model_config'):
|
||||
if args.get("model_config"):
|
||||
if invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ValueError('Only in App debug mode can override model config')
|
||||
raise ValueError("Only in App debug mode can override model config")
|
||||
|
||||
# validate config
|
||||
override_model_config_dict = ChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=args.get('model_config')
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
)
|
||||
|
||||
# always enable retriever resource in debugger mode
|
||||
override_model_config_dict["retriever_resource"] = {
|
||||
"enabled": True
|
||||
}
|
||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
|
@ -131,7 +122,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=override_model_config_dict
|
||||
override_config_dict=override_model_config_dict,
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
|
@ -150,14 +141,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity, conversation)
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
|
@ -166,17 +154,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
})
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
@ -190,16 +181,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
stream=stream,
|
||||
)
|
||||
|
||||
return ChatAppGenerateResponseConverter.convert(
|
||||
response=response,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: ChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str) -> None:
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: ChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
@ -221,20 +212,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
|
|
|
@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
|||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
|
@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
|
|||
Chat Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: ChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
def run(
|
||||
self,
|
||||
application_generate_entity: ChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
|
|||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
memory = None
|
||||
|
@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
|
|||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
# organize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
|
@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
|
|||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# moderation
|
||||
|
@ -96,15 +96,15 @@ class ChatAppRunner(AppRunner):
|
|||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text=str(e),
|
||||
stream=application_generate_entity.stream
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
|
|||
message=message,
|
||||
query=query,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
self.direct_output(
|
||||
|
@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
|
|||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text=annotation_reply.content,
|
||||
stream=application_generate_entity.stream
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
|
|||
app_id=app_record.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
# get context from datasets
|
||||
|
@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
|
|||
app_record.id,
|
||||
message.id,
|
||||
application_generate_entity.user_id,
|
||||
application_generate_entity.invoke_from
|
||||
application_generate_entity.invoke_from,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
||||
|
@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
|
|||
files=files,
|
||||
query=query,
|
||||
context=context,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
hosting_moderation_result = self.check_hosting_moderation(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
prompt_messages=prompt_messages
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
if hosting_moderation_result:
|
||||
return
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recalc_llm_max_tokens(
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):
|
|||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream
|
||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
||||
)
|
||||
|
|
|
@ -22,15 +22,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
:return:
|
||||
"""
|
||||
response = {
|
||||
'event': 'message',
|
||||
'task_id': blocking_response.task_id,
|
||||
'id': blocking_response.data.id,
|
||||
'message_id': blocking_response.data.message_id,
|
||||
'conversation_id': blocking_response.data.conversation_id,
|
||||
'mode': blocking_response.data.mode,
|
||||
'answer': blocking_response.data.answer,
|
||||
'metadata': blocking_response.data.metadata,
|
||||
'created_at': blocking_response.data.created_at
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
"conversation_id": blocking_response.data.conversation_id,
|
||||
"mode": blocking_response.data.mode,
|
||||
"answer": blocking_response.data.answer,
|
||||
"metadata": blocking_response.data.metadata,
|
||||
"created_at": blocking_response.data.created_at,
|
||||
}
|
||||
|
||||
return response
|
||||
|
@ -44,8 +44,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
"""
|
||||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get('metadata', {})
|
||||
response['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
|
||||
return response
|
||||
|
||||
|
@ -62,14 +62,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'conversation_id': chunk.conversation_id,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
|
@ -92,20 +92,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'conversation_id': chunk.conversation_id,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
metadata = sub_stream_response_dict.get('metadata', {})
|
||||
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
|
|
|
@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
|
|||
"""
|
||||
Completion App Config Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
|
||||
def get_app_config(
|
||||
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
|
||||
) -> CompletionAppConfig:
|
||||
"""
|
||||
Convert app model config to completion app config
|
||||
:param app_model: app model
|
||||
|
@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
prompt_template=PromptTemplateConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
dataset=DatasetConfigManager.convert(
|
||||
config=config_dict
|
||||
),
|
||||
additional_features=cls.convert_features(config_dict, app_mode)
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
dataset=DatasetConfigManager.convert(config=config_dict),
|
||||
additional_features=cls.convert_features(config_dict, app_mode),
|
||||
)
|
||||
|
||||
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
|
||||
|
@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# dataset_query_variable
|
||||
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
|
||||
config)
|
||||
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
|
||||
tenant_id, app_mode, config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# text_to_speech
|
||||
|
@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
|
||||
config)
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id, config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
related_config_keys = list(set(related_config_keys))
|
||||
|
|
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
|||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||
|
@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
|
|||
class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
|
@ -41,7 +42,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
|
@ -72,12 +74,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
if not isinstance(query, str):
|
||||
raise ValueError('query must be a string')
|
||||
raise ValueError("query must be a string")
|
||||
|
||||
query = query.replace('\x00', '')
|
||||
inputs = args['inputs']
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {}
|
||||
|
||||
|
@ -85,41 +87,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
conversation = None
|
||||
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(
|
||||
app_model=app_model,
|
||||
conversation=conversation
|
||||
)
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
|
||||
# validate override model config
|
||||
override_model_config_dict = None
|
||||
if args.get('model_config'):
|
||||
if args.get("model_config"):
|
||||
if invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ValueError('Only in App debug mode can override model config')
|
||||
raise ValueError("Only in App debug mode can override model config")
|
||||
|
||||
# validate config
|
||||
override_model_config_dict = CompletionAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=args.get('model_config')
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
)
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
override_config_dict=override_model_config_dict
|
||||
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
|
@ -137,14 +129,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity)
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
|
@ -153,16 +142,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'message_id': message.id,
|
||||
})
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
@ -176,15 +168,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
stream=stream,
|
||||
)
|
||||
|
||||
return CompletionAppGenerateResponseConverter.convert(
|
||||
response=response,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message_id: str) -> None:
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
@ -203,20 +195,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
|
@ -225,12 +216,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
finally:
|
||||
db.session.close()
|
||||
|
||||
def generate_more_like_this(self, app_model: App,
|
||||
message_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[str, None, None]]:
|
||||
def generate_more_like_this(
|
||||
self,
|
||||
app_model: App,
|
||||
message_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -240,13 +233,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
).first()
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
@ -259,29 +256,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
app_model_config = message.app_model_config
|
||||
override_model_config_dict = app_model_config.to_dict()
|
||||
model_dict = override_model_config_dict['model']
|
||||
completion_params = model_dict.get('completion_params')
|
||||
completion_params['temperature'] = 0.9
|
||||
model_dict['completion_params'] = completion_params
|
||||
override_model_config_dict['model'] = model_dict
|
||||
model_dict = override_model_config_dict["model"]
|
||||
completion_params = model_dict.get("completion_params")
|
||||
completion_params["temperature"] = 0.9
|
||||
model_dict["completion_params"] = completion_params
|
||||
override_model_config_dict["model"] = model_dict
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
message.files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
override_config_dict=override_model_config_dict
|
||||
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
|
@ -295,14 +286,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras={}
|
||||
extras={},
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity)
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
|
@ -311,16 +299,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'message_id': message.id,
|
||||
})
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"message_id": message.id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
@ -334,7 +325,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
stream=stream,
|
||||
)
|
||||
|
||||
return CompletionAppGenerateResponseConverter.convert(
|
||||
response=response,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
|
|
@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Message
|
||||
|
@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
|
|||
Completion Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message) -> None:
|
||||
def run(
|
||||
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
|
|||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
# organize all inputs and template to prompt messages
|
||||
|
@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
|
|||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
# moderation
|
||||
|
@ -77,15 +77,15 @@ class CompletionAppRunner(AppRunner):
|
|||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text=str(e),
|
||||
stream=application_generate_entity.stream
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
|
|||
app_id=app_record.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
# get context from datasets
|
||||
|
@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
|
|||
app_record.id,
|
||||
message.id,
|
||||
application_generate_entity.user_id,
|
||||
application_generate_entity.invoke_from
|
||||
application_generate_entity.invoke_from,
|
||||
)
|
||||
|
||||
dataset_config = app_config.dataset
|
||||
|
@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
|
|||
invoke_from=application_generate_entity.invoke_from,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source,
|
||||
hit_callback=hit_callback,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
|
@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
|
|||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=context
|
||||
context=context,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
hosting_moderation_result = self.check_hosting_moderation(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
prompt_messages=prompt_messages
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
if hosting_moderation_result:
|
||||
return
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recalc_llm_max_tokens(
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):
|
|||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream
|
||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
||||
)
|
||||
|
|
@ -22,14 +22,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
:return:
|
||||
"""
|
||||
response = {
|
||||
'event': 'message',
|
||||
'task_id': blocking_response.task_id,
|
||||
'id': blocking_response.data.id,
|
||||
'message_id': blocking_response.data.message_id,
|
||||
'mode': blocking_response.data.mode,
|
||||
'answer': blocking_response.data.answer,
|
||||
'metadata': blocking_response.data.metadata,
|
||||
'created_at': blocking_response.data.created_at
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
"mode": blocking_response.data.mode,
|
||||
"answer": blocking_response.data.answer,
|
||||
"metadata": blocking_response.data.metadata,
|
||||
"created_at": blocking_response.data.created_at,
|
||||
}
|
||||
|
||||
return response
|
||||
|
@ -43,8 +43,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
"""
|
||||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get('metadata', {})
|
||||
response['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
|
||||
return response
|
||||
|
||||
|
@ -61,13 +61,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
|
@ -90,19 +90,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'message_id': chunk.message_id,
|
||||
'created_at': chunk.created_at
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.to_dict()
|
||||
metadata = sub_stream_response_dict.get('metadata', {})
|
||||
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
|
|
|
@ -8,7 +8,7 @@ from sqlalchemy import and_
|
|||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
|
@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
def _handle_response(
|
||||
self, application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
self,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
|
||||
]:
|
||||
"""
|
||||
Handle response.
|
||||
|
@ -70,24 +70,25 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
def _get_conversation_by_user(self, app_model: App, conversation_id: str,
|
||||
user: Union[Account, EndUser]) -> Conversation:
|
||||
def _get_conversation_by_user(
|
||||
self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
|
||||
) -> Conversation:
|
||||
conversation_filter = [
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.status == 'normal'
|
||||
Conversation.status == "normal",
|
||||
]
|
||||
|
||||
if isinstance(user, Account):
|
||||
|
@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
if conversation.status != 'normal':
|
||||
if conversation.status != "normal":
|
||||
raise ConversationCompletedError()
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_app_model_config(self, app_model: App,
|
||||
conversation: Optional[Conversation] = None) \
|
||||
-> AppModelConfig:
|
||||
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
|
||||
if conversation:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id,
|
||||
AppModelConfig.app_id == app_model.id
|
||||
).first()
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
|
@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
|
||||
return app_model_config
|
||||
|
||||
def _init_generate_records(self,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
],
|
||||
conversation: Optional[Conversation] = None) \
|
||||
-> tuple[Conversation, Message]:
|
||||
def _init_generate_records(
|
||||
self,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
conversation: Optional[Conversation] = None,
|
||||
) -> tuple[Conversation, Message]:
|
||||
"""
|
||||
Initialize generate records
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -147,11 +148,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
# get from source
|
||||
end_user_id = None
|
||||
account_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
from_source = 'api'
|
||||
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
from_source = "api"
|
||||
end_user_id = application_generate_entity.user_id
|
||||
else:
|
||||
from_source = 'console'
|
||||
from_source = "console"
|
||||
account_id = application_generate_entity.user_id
|
||||
|
||||
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
|
||||
|
@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
model_provider = application_generate_entity.model_conf.provider
|
||||
model_id = application_generate_entity.model_conf.model
|
||||
override_model_configs = None
|
||||
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
|
||||
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
|
||||
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
|
||||
AppMode.AGENT_CHAT,
|
||||
AppMode.CHAT,
|
||||
AppMode.COMPLETION,
|
||||
}:
|
||||
override_model_configs = app_config.app_model_config_dict
|
||||
|
||||
# get conversation introduction
|
||||
|
@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_config.app_mode.value,
|
||||
name='New conversation',
|
||||
name="New conversation",
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status='normal',
|
||||
status="normal",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
|
@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency='USD',
|
||||
currency="USD",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id
|
||||
from_account_id=account_id,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
|
@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
message_id=message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
belongs_to='user',
|
||||
belongs_to="user",
|
||||
url=file.url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=('account' if account_id else 'end_user'),
|
||||
created_by_role=("account" if account_id else "end_user"),
|
||||
created_by=account_id or end_user_id,
|
||||
)
|
||||
db.session.add(message_file)
|
||||
|
@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
:param conversation_id: conversation id
|
||||
:return: conversation
|
||||
"""
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
:param message_id: message id
|
||||
:return: message
|
||||
"""
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.id == message_id)
|
||||
.first()
|
||||
)
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
|
||||
return message
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
|
@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (
|
|||
|
||||
|
||||
class MessageBasedAppQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
conversation_id: str,
|
||||
app_mode: str,
|
||||
message_id: str) -> None:
|
||||
def __init__(
|
||||
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
|
||||
) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._conversation_id = str(conversation_id)
|
||||
|
@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
|||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
event=event,
|
||||
)
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
|
@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
|||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
event=event,
|
||||
)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueAdvancedChatMessageEndEvent):
|
||||
if isinstance(
|
||||
event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
|
||||
raise GenerateTaskStoppedError()
|
||||
|
|
|
@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
|
|||
"""
|
||||
Workflow App Config Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
|
|||
app_id=app_model.id,
|
||||
app_mode=app_mode,
|
||||
workflow_id=workflow.id,
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
|
||||
config=features_dict
|
||||
),
|
||||
variables=WorkflowVariablesConfigManager.convert(
|
||||
workflow=workflow
|
||||
),
|
||||
additional_features=cls.convert_features(features_dict, app_mode)
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
|
||||
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
|
||||
additional_features=cls.convert_features(features_dict, app_mode),
|
||||
)
|
||||
|
||||
return app_config
|
||||
|
@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
||||
config=config,
|
||||
is_vision=False
|
||||
config=config, is_vision=False
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
|
@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
|
|||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
only_structure_validate=only_structure_validate
|
||||
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
|||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
|
@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
|
|||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
|
@ -46,14 +47,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> dict: ...
|
||||
|
||||
@overload
|
||||
|
@ -76,7 +78,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
@ -90,26 +92,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
:param call_depth: call depth
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
inputs = args['inputs']
|
||||
inputs = args["inputs"]
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
|
@ -125,7 +120,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
trace_manager=trace_manager
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
|
@ -136,11 +131,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
stream=stream,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self, *,
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
|
@ -165,17 +161,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=app_model.mode
|
||||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'context': contextvars.copy_context(),
|
||||
'workflow_thread_pool_id': workflow_thread_pool_id
|
||||
})
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": contextvars.copy_context(),
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
|
@ -188,10 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
stream=stream,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerateResponseConverter.convert(
|
||||
response=response,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
|
@ -210,16 +206,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
:param stream: is stream
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
|
@ -230,13 +223,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
},
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
)
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
|
@ -246,14 +236,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
workflow_thread_pool_id: Optional[str] = None) -> None:
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
@ -270,22 +263,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
|
@ -294,14 +286,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False) -> Union[
|
||||
WorkflowAppBlockingResponse,
|
||||
Generator[WorkflowAppStreamResponse, None, None]
|
||||
]:
|
||||
def _handle_response(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -317,14 +309,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
|
@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (
|
|||
|
||||
|
||||
class WorkflowAppQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
app_mode: str) -> None:
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
@ -27,20 +24,19 @@ class WorkflowAppQueueManager(AppQueueManager):
|
|||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
message = WorkflowQueueMessage(
|
||||
task_id=self._task_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent):
|
||||
if isinstance(
|
||||
event,
|
||||
QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
|
|
|
@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
|
@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
|
@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
)
|
||||
else:
|
||||
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
|
@ -114,18 +113,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id
|
||||
thread_pool_id=self.workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
|
|
@ -46,12 +46,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'workflow_run_id': chunk.workflow_run_id,
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
|
@ -74,12 +74,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'workflow_run_id': chunk.workflow_run_id,
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
|
|
|
@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
self._workflow = workflow
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_id
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
|
@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
db.session.refresh(self._user)
|
||||
db.session.close()
|
||||
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> WorkflowAppBlockingResponse:
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
|
||||
"""
|
||||
To blocking response.
|
||||
:return:
|
||||
|
@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at)
|
||||
)
|
||||
finished_at=int(stream_response.data.finished_at),
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
raise Exception('Queue listening stopped unexpectedly.')
|
||||
raise Exception("Queue listening stopped unexpectedly.")
|
||||
|
||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> Generator[WorkflowAppStreamResponse, None, None]:
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[WorkflowAppStreamResponse, None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
|
@ -158,34 +160,34 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if isinstance(stream_response, WorkflowStartStreamResponse):
|
||||
workflow_run_id = stream_response.workflow_run_id
|
||||
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id=workflow_run_id,
|
||||
stream_response=stream_response
|
||||
)
|
||||
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
def _wrapper_process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
tts_publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
if (
|
||||
features_dict.get("text_to_speech")
|
||||
and features_dict["text_to_speech"].get("enabled")
|
||||
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
|
||||
):
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
|
@ -197,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
try:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.check_and_get_audio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
|
@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start()
|
||||
yield self._workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
|
||||
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
|
@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
|
@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
|
||||
outputs=json.dumps(event.outputs)
|
||||
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
|
||||
else None,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
|
||||
status=WorkflowRunStatus.FAILED
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else WorkflowRunStatus.STOPPED,
|
||||
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
|
@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
|
@ -387,14 +376,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
tts_publisher.publish(message=queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(delta_text)
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
|
||||
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
Save workflow app log.
|
||||
|
@ -417,14 +407,16 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
workflow_app_log.workflow_id = workflow_run.workflow_id
|
||||
workflow_app_log.workflow_run_id = workflow_run.id
|
||||
workflow_app_log.created_from = created_from.value
|
||||
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
|
||||
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
|
||||
workflow_app_log.created_by = self._user.id
|
||||
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse:
|
||||
def _text_chunk_to_stream_response(
|
||||
self, text: str, from_variable_selector: Optional[list[str]] = None
|
||||
) -> TextChunkStreamResponse:
|
||||
"""
|
||||
Handle completed event.
|
||||
:param text: text
|
||||
|
@ -432,7 +424,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
"""
|
||||
response = TextChunkStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=TextChunkStreamResponse.Data(text=text)
|
||||
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
|
@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
"""
|
||||
Init graph
|
||||
"""
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in iteration
|
||||
node_configs = [
|
||||
node for node in graph_config.get('nodes', [])
|
||||
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config['nodes'] = node_configs
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get('id') for node in node_configs]
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
edge_configs = [
|
||||
edge for edge in graph_config.get('edges', [])
|
||||
if (edge.get('source') is None or edge.get('source') in node_ids)
|
||||
and (edge.get('target') is None or edge.get('target') in node_ids)
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config['edges'] = edge_configs
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_id
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get('id') == node_id:
|
||||
if node.get("id") == node_id:
|
||||
iteration_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError('iteration node id not found in workflow graph')
|
||||
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
|
||||
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
|
@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=iteration_node_config
|
||||
graph_config=workflow.graph_dict, config=iteration_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
|
||||
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowStartedEvent(
|
||||
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
|
||||
)
|
||||
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
|
||||
)
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowSucceededEvent(outputs=event.outputs)
|
||||
)
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowFailedEvent(error=event.error)
|
||||
)
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
|
@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
|
@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result else {},
|
||||
in_iteration_id=event.in_iteration_id
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
|
@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result
|
||||
and event.route_node_state.node_run_result.error
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
in_iteration_id=event.in_iteration_id
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
|
@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=event.retriever_resources,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
|
@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
|
@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
|
@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
|
@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
metadata=event.metadata
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
|
@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self.queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
|
|
@ -30,169 +30,150 @@ _TEXT_COLOR_MAPPING = {
|
|||
|
||||
|
||||
class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.current_node_id = None
|
||||
|
||||
def on_event(
|
||||
self,
|
||||
event: GraphEngineEvent
|
||||
) -> None:
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.print_text("\n[GraphRunStartedEvent]", color='pink')
|
||||
self.print_text("\n[GraphRunStartedEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.print_text("\n[GraphRunSucceededEvent]", color='green')
|
||||
self.print_text("\n[GraphRunSucceededEvent]", color="green")
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
|
||||
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.on_workflow_node_execute_started(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_node_execute_started(event=event)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.on_workflow_node_execute_succeeded(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_node_execute_succeeded(event=event)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.on_workflow_node_execute_failed(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_node_execute_failed(event=event)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self.on_node_text_chunk(
|
||||
event=event
|
||||
)
|
||||
self.on_node_text_chunk(event=event)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self.on_workflow_parallel_started(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_parallel_started(event=event)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
self.on_workflow_parallel_completed(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_parallel_completed(event=event)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self.on_workflow_iteration_started(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_iteration_started(event=event)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self.on_workflow_iteration_next(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_iteration_next(event=event)
|
||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
self.on_workflow_iteration_completed(
|
||||
event=event
|
||||
)
|
||||
self.on_workflow_iteration_completed(event=event)
|
||||
else:
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
|
||||
|
||||
def on_workflow_node_execute_started(
|
||||
self,
|
||||
event: NodeRunStartedEvent
|
||||
) -> None:
|
||||
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='yellow')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='yellow')
|
||||
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="yellow")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="yellow")
|
||||
|
||||
def on_workflow_node_execute_succeeded(
|
||||
self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> None:
|
||||
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
self.print_text("\n[NodeRunSucceededEvent]", color='green')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='green')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='green')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='green')
|
||||
self.print_text("\n[NodeRunSucceededEvent]", color="green")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="green")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="green")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="green")
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color='green')
|
||||
self.print_text(
|
||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color='green')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color='green')
|
||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Process Data: "
|
||||
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
|
||||
color='green')
|
||||
color="green",
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(
|
||||
self,
|
||||
event: NodeRunFailedEvent
|
||||
) -> None:
|
||||
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
self.print_text("\n[NodeRunFailedEvent]", color='red')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='red')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='red')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='red')
|
||||
self.print_text("\n[NodeRunFailedEvent]", color="red")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="red")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="red")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="red")
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Error: {node_run_result.error}", color='red')
|
||||
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color='red')
|
||||
self.print_text(f"Error: {node_run_result.error}", color="red")
|
||||
self.print_text(
|
||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color='red')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color='red')
|
||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color="red",
|
||||
)
|
||||
self.print_text(
|
||||
f"Process Data: "
|
||||
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color="red",
|
||||
)
|
||||
self.print_text(
|
||||
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
def on_node_text_chunk(
|
||||
self,
|
||||
event: NodeRunStreamChunkEvent
|
||||
) -> None:
|
||||
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
||||
self.current_node_id = route_node_state.node_id
|
||||
self.print_text('\n[NodeRunStreamChunkEvent]')
|
||||
self.print_text("\n[NodeRunStreamChunkEvent]")
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}")
|
||||
|
||||
node_run_result = route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
|
||||
)
|
||||
|
||||
self.print_text(event.chunk_content, color="pink", end="")
|
||||
|
||||
def on_workflow_parallel_started(
|
||||
self,
|
||||
event: ParallelBranchRunStartedEvent
|
||||
) -> None:
|
||||
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
|
||||
"""
|
||||
Publish parallel started
|
||||
"""
|
||||
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
|
||||
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_parallel_completed(
|
||||
self,
|
||||
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel completed
|
||||
"""
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
color = 'blue'
|
||||
color = "blue"
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
color = 'red'
|
||||
color = "red"
|
||||
|
||||
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
|
||||
self.print_text(
|
||||
"\n[ParallelBranchRunSucceededEvent]"
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent)
|
||||
else "\n[ParallelBranchRunFailedEvent]",
|
||||
color=color,
|
||||
)
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||
if event.in_iteration_id:
|
||||
|
@ -201,43 +182,37 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
|
||||
def on_workflow_iteration_started(
|
||||
self,
|
||||
event: IterationRunStartedEvent
|
||||
) -> None:
|
||||
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[IterationRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text("\n[IterationRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_iteration_next(
|
||||
self,
|
||||
event: IterationRunNextEvent
|
||||
) -> None:
|
||||
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self.print_text("\n[IterationRunNextEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text(f"Iteration Index: {event.index}", color='blue')
|
||||
self.print_text("\n[IterationRunNextEvent]", color="blue")
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
self.print_text(f"Iteration Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_iteration_completed(
|
||||
self,
|
||||
event: IterationRunSucceededEvent | IterationRunFailedEvent
|
||||
) -> None:
|
||||
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text(
|
||||
"\n[IterationRunSucceededEvent]"
|
||||
if isinstance(event, IterationRunSucceededEvent)
|
||||
else "\n[IterationRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def print_text(
|
||||
self, text: str, color: Optional[str] = None, end: str = "\n"
|
||||
) -> None:
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
print(f'{text_to_print}', end=end)
|
||||
print(f"{text_to_print}", end=end)
|
||||
|
||||
def _get_colored_text(self, text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
|
|
|
@ -15,13 +15,14 @@ class InvokeFrom(Enum):
|
|||
"""
|
||||
Invoke From.
|
||||
"""
|
||||
SERVICE_API = 'service-api'
|
||||
WEB_APP = 'web-app'
|
||||
EXPLORE = 'explore'
|
||||
DEBUGGER = 'debugger'
|
||||
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'InvokeFrom':
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
@ -31,7 +32,7 @@ class InvokeFrom(Enum):
|
|||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid invoke from value {value}')
|
||||
raise ValueError(f"invalid invoke from value {value}")
|
||||
|
||||
def to_source(self) -> str:
|
||||
"""
|
||||
|
@ -40,21 +41,22 @@ class InvokeFrom(Enum):
|
|||
:return: source
|
||||
"""
|
||||
if self == InvokeFrom.WEB_APP:
|
||||
return 'web_app'
|
||||
return "web_app"
|
||||
elif self == InvokeFrom.DEBUGGER:
|
||||
return 'dev'
|
||||
return "dev"
|
||||
elif self == InvokeFrom.EXPLORE:
|
||||
return 'explore_app'
|
||||
return "explore_app"
|
||||
elif self == InvokeFrom.SERVICE_API:
|
||||
return 'api'
|
||||
return "api"
|
||||
|
||||
return 'dev'
|
||||
return "dev"
|
||||
|
||||
|
||||
class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
"""
|
||||
Model Config With Credentials Entity.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
model_schema: AIModelEntity
|
||||
|
@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
|
|||
"""
|
||||
App Generate Entity.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
|
||||
# app config
|
||||
|
@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
|||
"""
|
||||
Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
# app config
|
||||
app_config: EasyUIBasedAppConfig
|
||||
model_conf: ModelConfigWithCredentialsEntity
|
||||
|
@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
|||
"""
|
||||
Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
|||
"""
|
||||
Completion Application Generate Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
|||
"""
|
||||
Agent Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
|||
"""
|
||||
Advanced Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
|
@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
|||
"""
|
||||
Single Iteration Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
|
||||
class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
Workflow Application Generate Entity.
|
||||
"""
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
|
@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||
"""
|
||||
Single Iteration Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user