feat: universal chat in explore (#649)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>

parent 94b54b7ca9
commit 4fdb37771a
@@ -19,7 +19,7 @@ def check_file_for_chinese_comments(file_path):

 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']

     for root, _, files in os.walk("."):
         for file in files:
@@ -22,7 +22,7 @@ from extensions.ext_database import db
 from extensions.ext_login import login_manager

 # DO NOT REMOVE BELOW
-from models import model, account, dataset, web, task, source
+from models import model, account, dataset, web, task, source, tool
 from events import event_handlers
 # DO NOT REMOVE ABOVE
@@ -18,7 +18,10 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

 # Import workspace controllers
-from .workspace import workspace, members, providers, account
+from .workspace import workspace, members, model_providers, account, tool_providers

 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
+
+# Import universal chat controllers
+from .universal_chat import chat, conversation, message, parameter, audio
@@ -24,6 +24,7 @@ model_config_fields = {
     'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
+    'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
     'more_like_this': fields.Raw(attribute='more_like_this_dict'),
     'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
     'model': fields.Raw(attribute='model_dict'),
     'user_input_form': fields.Raw(attribute='user_input_form_list'),
     'pre_prompt': fields.String,
@ -96,7 +97,8 @@ class AppListApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()),
|
||||
db.select(App).where(App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == False).order_by(App.created_at.desc()),
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False)
|
||||
|
@ -147,6 +149,7 @@ class AppListApi(Resource):
|
|||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
pre_prompt=model_configuration['pre_prompt'],
|
||||
|
@ -438,6 +441,7 @@ class AppCopy(Resource):
|
|||
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
|
||||
speech_to_text=app_config.speech_to_text,
|
||||
more_like_this=app_config.more_like_this,
|
||||
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
|
||||
model=app_config.model,
|
||||
user_input_form=app_config.user_input_form,
|
||||
pre_prompt=app_config.pre_prompt,
|
||||
|
@@ -163,7 +163,7 @@ class CompletionConversationApi(Resource):

         if args['end']:
             end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
-            end_datetime = end_datetime.replace(second=0)
+            end_datetime = end_datetime.replace(second=59)

             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
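Note on this fix (and the identical one in the ChatConversationApi hunk below): args['end'] is parsed with minute precision, so clamping seconds to 0 silently excluded records created during the final minute of the selected range, while 59 makes the bound inclusive. A minimal standard-library sketch of the difference, with an illustrative timestamp:

from datetime import datetime

# Hypothetical record created at 12:30:45, with the filter end set to "12:30".
record_created_at = datetime(2023, 7, 1, 12, 30, 45)
end_datetime = datetime.strptime('2023-07-01 12:30', '%Y-%m-%d %H:%M')

# With second=0 the record falls just outside the range; second=59 keeps it.
assert not record_created_at <= end_datetime.replace(second=0)
assert record_created_at <= end_datetime.replace(second=59)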
@ -322,7 +322,7 @@ class ChatConversationApi(Resource):
|
|||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
|
@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
             suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
+            speech_to_text=json.dumps(model_configuration['speech_to_text']),
             more_like_this=json.dumps(model_configuration['more_like_this']),
             sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
             model=json.dumps(model_configuration['model']),
             user_input_form=json.dumps(model_configuration['user_input_form']),
             pre_prompt=model_configuration['pre_prompt'],
@@ -65,7 +65,10 @@ class ConversationApi(InstalledAppResource):
             raise NotChatAppError()

         conversation_id = str(c_id)
-        ConversationService.delete(app_model, conversation_id, current_user)
+        try:
+            ConversationService.delete(app_model, conversation_id, current_user)
+        except ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
         WebConversationService.unpin(app_model, conversation_id, current_user)

         return {"result": "success"}, 204
@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
|
|||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import InstalledApp
|
||||
|
||||
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
@ -27,16 +31,17 @@ class AppParameterApi(InstalledAppResource):
|
|||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, installed_app):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
|
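The speech_to_text gate introduced here recurs verbatim in the service_api and web controllers further down in this diff. As a standalone sketch of the rule (the function below is illustrative and not part of the commit; the string 'openai' mirrors ProviderName.OPENAI.value):

def speech_to_text_setting(config_dict: dict, provider_name: str) -> dict:
    # Speech-to-text is only exposed when the tenant's default provider
    # for 'whisper-1' is OpenAI; otherwise it is reported as disabled.
    return config_dict if provider_name == 'openai' else {'enabled': False}

assert speech_to_text_setting({'enabled': True}, 'openai') == {'enabled': True}
assert speech_to_text_setting({'enabled': True}, 'other-provider') == {'enabled': False}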
api/controllers/console/universal_chat/audio.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# -*- coding:utf-8 -*-
import logging

from flask import request
from werkzeug.exceptions import InternalServerError

import services
from controllers.console import api
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
    NoAudioUploadedError, AudioTooLargeError, \
    UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
    UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import AppModelConfig


class UniversalChatAudioApi(UniversalChatResource):
    def post(self, universal_app):
        app_model = universal_app
        app_model_config: AppModelConfig = app_model.app_model_config

        if not app_model_config.speech_to_text_dict['enabled']:
            raise AppUnavailableError()

        file = request.files['file']

        try:
            response = AudioService.transcript(
                tenant_id=app_model.tenant_id,
                file=file,
            )

            return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, LLMAuthorizationError) as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
api/controllers/console/universal_chat/chat.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import json
import logging
from typing import Generator, Union

from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError, NotFound

import services
from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.constant import llm_constant
from core.conversation_message_task import PubHandler
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
    LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from libs.helper import uuid_value
from services.completion_service import CompletionService


class UniversalChatApi(UniversalChatResource):
    def post(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, required=True, location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('model', type=str, required=True, location='json')
        parser.add_argument('tools', type=list, required=True, location='json')
        args = parser.parse_args()

        app_model_config = app_model.app_model_config

        # update app model config
        args['model_config'] = app_model_config.to_dict()
        args['model_config']['model']['name'] = args['model']

        if not llm_constant.models[args['model']]:
            raise ValueError("Model not exists.")

        args['model_config']['model']['provider'] = llm_constant.models[args['model']]
        args['model_config']['agent_mode']['tools'] = args['tools']

        args['inputs'] = {}

        del args['model']
        del args['tools']

        try:
            response = CompletionService.completion(
                app_model=app_model,
                user=current_user,
                args=args,
                from_source='console',
                streaming=True,
                is_model_config_override=True,
            )

            return compact_response(response)
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
            raise ConversationCompletedError()
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, LLMAuthorizationError) as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


class UniversalChatStopApi(UniversalChatResource):
    def post(self, universal_app, task_id):
        PubHandler.stop(current_user, task_id)

        return {'result': 'success'}, 200


def compact_response(response: Union[dict | Generator]) -> Response:
    if isinstance(response, dict):
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            try:
                for chunk in response:
                    yield chunk
            except services.errors.conversation.ConversationNotExistsError:
                yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
            except services.errors.conversation.ConversationCompletedError:
                yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
            except services.errors.app_model_config.AppModelConfigBrokenError:
                logging.exception("App model config broken.")
                yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
            except ProviderTokenNotInitError:
                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
            except QuotaExceededError:
                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
            except ModelCurrentlyNotSupportError:
                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                    LLMRateLimitError, LLMAuthorizationError) as e:
                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
            except ValueError as e:
                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
            except Exception:
                logging.exception("internal server error.")
                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')


api.add_resource(UniversalChatApi, '/universal-chat/messages')
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
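A hedged client sketch for the streaming endpoint registered above. The endpoint path and JSON fields come from the parser in UniversalChatApi; the base URL, the /console/api mount point, the cookie name, and the use of requests are assumptions for illustration:

import json

import requests  # any HTTP client with response streaming works


def stream_universal_chat(query: str, session_cookie: str, conversation_id: str = None):
    payload = {'query': query, 'model': 'gpt-3.5-turbo-16k', 'tools': []}
    if conversation_id:
        payload['conversation_id'] = conversation_id

    # compact_response() emits "data: {...}\n\n" frames (text/event-stream).
    with requests.post('http://localhost:5001/console/api/universal-chat/messages',
                       json=payload,
                       cookies={'session': session_cookie},  # hypothetical auth cookie
                       stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line.startswith('data: '):
                yield json.loads(line[len('data: '):])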
api/controllers/console/universal_chat/conversation.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import fields, reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService

conversation_fields = {
    'id': fields.String,
    'name': fields.String,
    'inputs': fields.Raw,
    'status': fields.String,
    'introduction': fields.String,
    'created_at': TimestampField,
    'model_config': fields.Raw,
}

conversation_infinite_scroll_pagination_fields = {
    'limit': fields.Integer,
    'has_more': fields.Boolean,
    'data': fields.List(fields.Nested(conversation_fields))
}


class UniversalChatConversationListApi(UniversalChatResource):

    @marshal_with(conversation_infinite_scroll_pagination_fields)
    def get(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('last_id', type=uuid_value, location='args')
        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
        parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
        args = parser.parse_args()

        pinned = None
        if 'pinned' in args and args['pinned'] is not None:
            pinned = True if args['pinned'] == 'true' else False

        try:
            return WebConversationService.pagination_by_last_id(
                app_model=app_model,
                user=current_user,
                last_id=args['last_id'],
                limit=args['limit'],
                pinned=pinned
            )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")


class UniversalChatConversationApi(UniversalChatResource):
    def delete(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        try:
            ConversationService.delete(app_model, conversation_id, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")

        WebConversationService.unpin(app_model, conversation_id, current_user)

        return {"result": "success"}, 204


class UniversalChatConversationRenameApi(UniversalChatResource):

    @marshal_with(conversation_fields)
    def post(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
        args = parser.parse_args()

        try:
            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")


class UniversalChatConversationPinApi(UniversalChatResource):

    def patch(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        try:
            WebConversationService.pin(app_model, conversation_id, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")

        return {"result": "success"}


class UniversalChatConversationUnPinApi(UniversalChatResource):
    def patch(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)
        WebConversationService.unpin(app_model, conversation_id, current_user)

        return {"result": "success"}


api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
api/controllers/console/universal_chat/message.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# -*- coding:utf-8 -*-
import logging

from flask_login import current_user
from flask_restful import reqparse, fields, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError

import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService


class UniversalChatMessageListApi(UniversalChatResource):
    feedback_fields = {
        'rating': fields.String
    }

    agent_thought_fields = {
        'id': fields.String,
        'chain_id': fields.String,
        'message_id': fields.String,
        'position': fields.Integer,
        'thought': fields.String,
        'tool': fields.String,
        'tool_input': fields.String,
        'created_at': TimestampField
    }

    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
        'inputs': fields.Raw,
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

    message_infinite_scroll_pagination_fields = {
        'limit': fields.Integer,
        'has_more': fields.Boolean,
        'data': fields.List(fields.Nested(message_fields))
    }

    @marshal_with(message_infinite_scroll_pagination_fields)
    def get(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
        parser.add_argument('first_id', type=uuid_value, location='args')
        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
        args = parser.parse_args()

        try:
            return MessageService.pagination_by_first_id(app_model, current_user,
                                                         args['conversation_id'], args['first_id'], args['limit'])
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.message.FirstMessageNotExistsError:
            raise NotFound("First Message Not Exists.")


class UniversalChatMessageFeedbackApi(UniversalChatResource):
    def post(self, universal_app, message_id):
        app_model = universal_app
        message_id = str(message_id)

        parser = reqparse.RequestParser()
        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
        args = parser.parse_args()

        try:
            MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")

        return {'result': 'success'}


class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
    def get(self, universal_app, message_id):
        app_model = universal_app
        message_id = str(message_id)

        try:
            questions = MessageService.get_suggested_questions_after_answer(
                app_model=app_model,
                user=current_user,
                message_id=message_id
            )
        except MessageNotExistsError:
            raise NotFound("Message not found")
        except ConversationNotExistsError:
            raise NotFound("Conversation not found")
        except SuggestedQuestionsAfterAnswerDisabledError:
            raise AppSuggestedQuestionsAfterAnswerDisabledError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, LLMAuthorizationError) as e:
            raise CompletionRequestError(str(e))
        except Exception:
            logging.exception("internal server error.")
            raise InternalServerError()

        return {'data': questions}


api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
api/controllers/console/universal_chat/parameter.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields

from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource

from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App


class UniversalChatParameterApi(UniversalChatResource):
    """Resource for app variables."""
    parameters_fields = {
        'opening_statement': fields.String,
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
    }

    @marshal_with(parameters_fields)
    def get(self, universal_app: App):
        """Retrieve app parameters."""
        app_model = universal_app
        app_model_config = app_model.app_model_config
        provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')

        return {
            'opening_statement': app_model_config.opening_statement,
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
        }


api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
api/controllers/console/universal_chat/wraps.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import json
from functools import wraps

from flask_login import login_required, current_user
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from models.model import App, AppModelConfig


def universal_chat_app_required(view=None):
    def decorator(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            # get universal chat app
            universal_app = db.session.query(App).filter(
                App.tenant_id == current_user.current_tenant_id,
                App.is_universal == True
            ).first()

            if universal_app is None:
                # create universal app if not exists
                universal_app = App(
                    tenant_id=current_user.current_tenant_id,
                    name='Universal Chat',
                    mode='chat',
                    is_universal=True,
                    icon='',
                    icon_background='',
                    api_rpm=0,
                    api_rph=0,
                    enable_site=False,
                    enable_api=False,
                    status='normal'
                )

                db.session.add(universal_app)
                db.session.flush()

                app_model_config = AppModelConfig(
                    provider="",
                    model_id="",
                    configs={},
                    opening_statement='',
                    suggested_questions=json.dumps([]),
                    suggested_questions_after_answer=json.dumps({'enabled': True}),
                    speech_to_text=json.dumps({'enabled': True}),
                    more_like_this=None,
                    sensitive_word_avoidance=None,
                    model=json.dumps({
                        "provider": "openai",
                        "name": "gpt-3.5-turbo-16k",
                        "completion_params": {
                            "max_tokens": 800,
                            "temperature": 0.8,
                            "top_p": 1,
                            "presence_penalty": 0,
                            "frequency_penalty": 0
                        }
                    }),
                    user_input_form=json.dumps([]),
                    pre_prompt='',
                    agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
                )

                app_model_config.app_id = universal_app.id
                db.session.add(app_model_config)
                db.session.flush()

                universal_app.app_model_config_id = app_model_config.id
                db.session.commit()

            return view(universal_app, *args, **kwargs)
        return decorated

    if view:
        return decorator(view)
    return decorator


class UniversalChatResource(Resource):
    # must be reversed if there are multiple decorators
    method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
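Aside on the "must be reversed" comment above: Flask-RESTful applies method_decorators by looping meth = decorator(meth) over the list, so the first entry ends up innermost and the last entry runs first. Here setup_required therefore runs first, and universal_chat_app_required wraps the view directly, which is what injects the universal_app argument. A self-contained sketch of that ordering (the names below are stand-ins, not part of the commit):

def logged(tag):
    def decorator(fn):
        def wrapped(*args, **kwargs):
            print(f'enter {tag}')
            return fn(*args, **kwargs)
        return wrapped
    return decorator


def handler():
    return 'ok'


# Mirrors Flask-RESTful's loop: the list is applied inside-out.
for decorator in [logged('inner'), logged('outer')]:
    handler = decorator(handler)

handler()  # prints "enter outer", then "enter inner"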
api/controllers/console/workspace/tool_providers.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import json

from flask_login import login_required, current_user
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.provider.tool_provider_service import ToolProviderService
from extensions.ext_database import db
from models.tool import ToolProvider, ToolProviderName


class ToolProviderListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        tenant_id = current_user.current_tenant_id

        tool_credential_dict = {}
        for tool_name in ToolProviderName:
            tool_credential_dict[tool_name.value] = {
                'tool_name': tool_name.value,
                'is_enabled': False,
                'credentials': None
            }

        tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()

        for p in tool_providers:
            if p.is_enabled:
                tool_credential_dict[p.tool_name] = {
                    'tool_name': p.tool_name,
                    'is_enabled': p.is_enabled,
                    'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
                }

        return list(tool_credential_dict.values())


class ToolProviderCredentialsApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        if provider not in [p.value for p in ToolProviderName]:
            abort(404)

        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
                            f'current_role is {current_user.current_tenant.current_role}')

        parser = reqparse.RequestParser()
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        tool_provider_service = ToolProviderService(tenant_id, provider)

        try:
            tool_provider_service.credentials_validate(args['credentials'])
        except ToolValidateFailedError as ex:
            raise ValueError(str(ex))

        encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))

        tenant = current_user.current_tenant

        tool_provider_model = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == tenant.id,
            ToolProvider.tool_name == provider,
        ).first()

        # Only allow updating token for CUSTOM provider type
        if tool_provider_model:
            tool_provider_model.encrypted_credentials = encrypted_credentials
            tool_provider_model.is_enabled = True
        else:
            tool_provider_model = ToolProvider(
                tenant_id=tenant.id,
                tool_name=provider,
                encrypted_credentials=encrypted_credentials,
                is_enabled=True
            )
            db.session.add(tool_provider_model)

        db.session.commit()

        return {'result': 'success'}, 201


class ToolProviderCredentialsValidateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        if provider not in [p.value for p in ToolProviderName]:
            abort(404)

        parser = reqparse.RequestParser()
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        args = parser.parse_args()

        result = True
        error = None

        tenant_id = current_user.current_tenant_id

        tool_provider_service = ToolProviderService(tenant_id, provider)

        try:
            tool_provider_service.credentials_validate(args['credentials'])
        except ToolValidateFailedError as ex:
            result = False
            error = str(ex)

        response = {'result': 'success' if result else 'error'}

        if not result:
            response['error'] = error

        return response


api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
api.add_resource(ToolProviderCredentialsValidateApi,
                 '/workspaces/current/tool-providers/<provider>/credentials-validate')
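A hedged sketch of calling the validation endpoint registered above from an HTTP client. The route and the credentials field come from the resource itself; the base URL, the /console/api mount point, the provider name, the credential shape, and the cookie-based auth are all assumptions for illustration:

import requests  # assumed available; any HTTP client works

resp = requests.post(
    'http://localhost:5001/console/api/workspaces/current/tool-providers/serpapi/credentials-validate',
    json={'credentials': {'api_key': '...'}},          # shape depends on the provider
    cookies={'session': '<console session cookie>'},   # hypothetical auth
)
print(resp.json())  # {'result': 'success'} or {'result': 'error', 'error': '...'}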
@@ -4,6 +4,10 @@ from flask_restful import fields, marshal_with
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource

+from core.llm.llm_builder import LLMBuilder
+from models.provider import ProviderName
+from models.model import App
+

 class AppParameterApi(AppApiResource):
     """Resource for app variables."""
@@ -28,15 +32,16 @@ class AppParameterApi(AppApiResource):
     }

     @marshal_with(parameters_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
+        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')

         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict,
+            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
 from controllers.web import api
 from controllers.web.wraps import WebApiResource

+from core.llm.llm_builder import LLMBuilder
+from models.provider import ProviderName
+from models.model import App
+

 class AppParameterApi(WebApiResource):
     """Resource for app variables."""
@@ -27,15 +31,16 @@ class AppParameterApi(WebApiResource):
     }

     @marshal_with(parameters_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
+        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')

         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict,
+            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
@@ -62,7 +62,10 @@ class ConversationApi(WebApiResource):
             raise NotChatAppError()

         conversation_id = str(c_id)
-        ConversationService.delete(app_model, conversation_id, end_user)
+        try:
+            ConversationService.delete(app_model, conversation_id, end_user)
+        except ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
         WebConversationService.unpin(app_model, conversation_id, end_user)

         return {"result": "success"}, 204
api/core/agent/agent/calc_token_mixin.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import cast, List

from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage

from core.constant import llm_constant


class CalcTokenMixin:

    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
        llm = cast(ChatOpenAI, llm)
        return llm.get_num_tokens_from_messages(messages)

    def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
        """
        Get the tokens remaining for the model after excluding the messages' tokens and the completion's max tokens.

        :param llm:
        :param messages:
        :return:
        """
        llm = cast(ChatOpenAI, llm)
        llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
        completion_max_tokens = llm.max_tokens
        used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
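A worked instance of the budget arithmetic in get_message_rest_tokens, with illustrative numbers rather than real model limits:

# rest = context window - tokens reserved for the completion - prompt tokens.
llm_max_tokens = 4096        # hypothetical context window for the model
completion_max_tokens = 800  # llm.max_tokens reserved for the reply
used_tokens = 3500           # tokens already consumed by the messages

rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
assert rest_tokens == -204   # negative: callers must summarize before calling the model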
api/core/agent/agent/multi_dataset_router_agent.py (new file, 84 lines)
@@ -0,0 +1,84 @@
from typing import Tuple, List, Any, Union, Sequence, Optional, cast

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.tools import BaseTool

from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """
    A multi-dataset retrieval agent driven by a router.
    """

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used for the query.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            tool = next(iter(self.tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        return super().plan(intermediate_steps, callbacks, **kwargs)

    async def aplan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        llm.model_name = 'gpt-3.5-turbo'
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
            **kwargs,
        )
api/core/agent/agent/openai_function_call.py (new file, 120 lines)
@@ -0,0 +1,120 @@
from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional

import pytz
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
    _format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=cls.get_system_message(),
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used for the query.

        :param query:
        :return:
        """
        original_max_tokens = self.llm.max_tokens
        self.llm.max_tokens = 15

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=None
        )

        function_call = predicted_message.additional_kwargs.get("function_call", {})

        self.llm.max_tokens = original_max_tokens

        return True if function_call else False

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # summarize messages if rest_tokens < 0
        try:
            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision

    @classmethod
    def get_system_message(cls):
        # get current time
        current_time = datetime.now()
        current_timezone = pytz.timezone('UTC')
        current_time = current_timezone.localize(current_time)

        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "Current time: {}\n"
                                     "Respond directly if appropriate.".format(
                                         current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))

    def return_stopped_response(
            self,
            early_stopping_method: str,
            intermediate_steps: List[Tuple[AgentAction, str]],
            **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
api/core/agent/agent/openai_function_call_summarize_mixin.py (new file, 132 lines)
@@ -0,0 +1,132 @@
from typing import cast, List

from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from pydantic import BaseModel

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin


class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel

    def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        llm = cast(ChatOpenAI, llm)
        model, encoding = llm._get_encoding_model()
        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
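A worked instance of the per-message accounting above for a gpt-3.5-turbo-style model (tokens_per_message = 4, reply primed with 3), using made-up encode() counts rather than a real tokenizer:

# Suppose encoding.encode() yields 1 token per role string and 10 per content.
messages = [('system', 10), ('user', 10)]  # (role, content token count)
tokens_per_message = 4

num_tokens = sum(tokens_per_message + 1 + content for _, content in messages)
num_tokens += 3  # every reply is primed with <im_start>assistant
assert num_tokens == 33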
api/core/agent/agent/openai_multi_function_call.py (new file, 109 lines)
@@ -0,0 +1,109 @@
from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional

import pytz
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
    _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin


class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseMultiActionAgent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=cls.get_system_message(),
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used for the query.

        :param query:
        :return:
        """
        original_max_tokens = self.llm.max_tokens
        self.llm.max_tokens = 15

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=None
        )

        function_call = predicted_message.additional_kwargs.get("function_call", {})

        self.llm.max_tokens = original_max_tokens

        return True if function_call else False

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        # summarize messages if rest_tokens < 0
        try:
            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision

    @classmethod
    def get_system_message(cls):
        # get current time
        current_time = datetime.now()
        current_timezone = pytz.timezone('UTC')
        current_time = current_timezone.localize(current_time)

        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "Current time: {}\n"
                                     "Respond directly if appropriate.".format(
                                         current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
29
api/core/agent/agent/output_parser/structured_chat.py
Normal file
@@ -0,0 +1,29 @@
import json
import re
from typing import Union

from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
    logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException


class StructuredChatOutputParser(LCStructuredChatOutputParser):
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        try:
            action_match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
            if action_match is not None:
                response = json.loads(action_match.group(2).strip(), strict=False)
                if isinstance(response, list):
                    # gpt turbo frequently ignores the directive to emit a single action
                    logger.warning("Got multiple action responses: %s", response)
                    response = response[0]
                if response["action"] == "Final Answer":
                    return AgentFinish({"output": response["action_input"]}, text)
                else:
                    return AgentAction(
                        response["action"], response.get("action_input", {}), text
                    )
            else:
                return AgentFinish({"output": text}, text)
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e
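For reference, the kind of completion this parser is written for: a fenced JSON action blob, as demanded by the structured-chat format instructions. The values below are hypothetical, not repo output:

text = '''Thought: I should query the dataset.
Action:
```
{"action": "dataset-1234", "action_input": "what is universal chat?"}
```'''
# StructuredChatOutputParser().parse(text) returns
# AgentAction(tool="dataset-1234", tool_input="what is universal chat?", log=text);
# a blob whose "action" is "Final Answer" would yield an AgentFinish instead.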
182
api/core/agent/agent/structured_chat.py
Normal file
@@ -0,0 +1,182 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional

from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError


FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_llm: BaseLanguageModel

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used.
        Probing with an extra ReACT round just to decide whether an agent is needed is costly,
        so it is cheaper to always route the reasoning through the agent.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        return self.output_parser.parse(full_output)

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
            **kwargs: Any,
    ) -> Agent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            output_parser=output_parser,
            prefix=prefix,
            suffix=suffix,
            human_message_template=human_message_template,
            format_instructions=format_instructions,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
            **kwargs,
        )
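Why FORMAT_INSTRUCTIONS above uses quadruple braces: the string is rendered by str.format twice — once by the explicit format(tool_names=...) call in create_prompt, and again when the assembled ChatPromptTemplate renders — and each pass halves the braces (tool args schemas get the same four-brace escaping). A quick illustration with a hypothetical tool name:

blob = '{{{{"action": $TOOL_NAME}}}} Valid "action" values: "Final Answer" or {tool_names}'
after_first = blob.format(tool_names='"dataset-1234"')  # '{{{{' collapses to '{{'
after_second = after_first.format()                     # '{{' collapses to '{', what the model sees
print(after_second)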
@@ -1,86 +0,0 @@
from typing import Optional

from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder


class AgentBuilder:
    @classmethod
    def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                       dataset_tool_callback_handler: DatasetToolCallbackHandler,
                       agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=agent_loop_gather_callback_handler.model_name,
            temperature=0,
            max_tokens=1024,
            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
        )

        for tool in tools:
            tool.callbacks = [
                agent_loop_gather_callback_handler,
                dataset_tool_callback_handler,
                DifyStdOutCallbackHandler()
            ]

        prompt = cls.build_agent_prompt_template(
            tools=tools,
            memory=memory,
        )

        agent_llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
        )

        agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)

        agent_callback_manager = CallbackManager(
            [agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
        )

        agent_chain = AgentExecutor.from_agent_and_tools(
            tools=tools,
            agent=agent,
            memory=memory,
            callbacks=agent_callback_manager,
            max_iterations=6,
            early_stopping_method="generate",
            # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
        )

        return agent_chain

    @classmethod
    def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
        if memory:
            prompt = ConversationalAgent.create_prompt(
                tools=tools,
            )
        else:
            prompt = ZeroShotAgent.create_prompt(
                tools=tools,
            )

        return prompt

    @classmethod
    def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
        if memory:
            agent = ConversationalAgent(
                llm_chain=agent_llm_chain
            )
        else:
            agent = ZeroShotAgent(
                llm_chain=agent_llm_chain
            )

        return agent
121
api/core/agent/agent_executor.py
Normal file
@@ -0,0 +1,121 @@
import enum
import logging
from typing import Union, Optional

from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra

from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor

from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
    MULTI_FUNCTION_CALL = 'multi_function_call'


class AgentConfiguration(BaseModel):
    strategy: PlanningStrategy
    llm: BaseLanguageModel
    tools: list[BaseTool]
    summary_llm: BaseLanguageModel
    memory: Optional[BaseChatMemory] = None
    callbacks: Callbacks = None
    max_iterations: int = 6
    max_execution_time: Optional[float] = None
    early_stopping_method: str = "generate"
    # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True


class AgentExecuteResult(BaseModel):
    strategy: PlanningStrategy
    output: Optional[str]
    configuration: AgentConfiguration


class AgentExecutor:
    def __init__(self, configuration: AgentConfiguration):
        self.configuration = configuration
        self.agent = self._init_agent()

    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
        if self.configuration.strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                summary_llm=self.configuration.summary_llm,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history from memory
                summary_llm=self.configuration.summary_llm,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history from memory
                summary_llm=self.configuration.summary_llm,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
                llm=self.configuration.llm,
                tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                verbose=True
            )
        else:
            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")

        return agent

    def should_use_agent(self, query: str) -> bool:
        return self.agent.should_use_agent(query)

    def run(self, query: str) -> AgentExecuteResult:
        agent_executor = LCAgentExecutor.from_agent_and_tools(
            agent=self.agent,
            tools=self.configuration.tools,
            memory=self.configuration.memory,
            max_iterations=self.configuration.max_iterations,
            max_execution_time=self.configuration.max_execution_time,
            early_stopping_method=self.configuration.early_stopping_method,
            callbacks=self.configuration.callbacks
        )

        try:
            output = agent_executor.run(query)
        except Exception:
            logging.exception("agent_executor run failed")
            output = None

        return AgentExecuteResult(
            output=output,
            strategy=self.configuration.strategy,
            configuration=self.configuration
        )
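A minimal usage sketch for the executor above. The llm, summary_llm, and tools bindings are placeholders, not repo code:

# Hypothetical wiring: choose a planning strategy and run one query.
config = AgentConfiguration(
    strategy=PlanningStrategy.REACT,
    llm=llm,                  # any BaseLanguageModel
    summary_llm=summary_llm,  # cheaper model used only to summarize long histories
    tools=tools,              # list[BaseTool], e.g. DatasetRetrieverTool instances
)
executor = AgentExecutor(config)
if executor.should_use_agent("What does the handbook say about leave?"):
    result = executor.run("What does the handbook say about leave?")
    print(result.strategy, result.output)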
@@ -1,10 +1,12 @@
import json
import logging
import time

from typing import Any, Dict, List, Union, Optional

from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration

from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask

@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        self.conversation_message_task = conversation_message_task
        self._agent_loops = []
        self._current_loop = None
        self._message_agent_thought = None
        self.current_chain = None

    @property

@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
    def clear_agent_loops(self) -> None:
        self._agent_loops = []
        self._current_loop = None
        self._message_agent_thought = None

    @property
    def always_verbose(self) -> bool:

@@ -61,9 +65,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        # kwargs={}
        if self._current_loop and self._current_loop.status == 'llm_started':
            self._current_loop.status = 'llm_end'
            self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
            self._current_loop.completion = response.generations[0][0].text
            self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
            if response.llm_output:
                self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
            completion_generation = response.generations[0][0]
            if isinstance(completion_generation, ChatGeneration):
                completion_message = completion_generation.message
                if 'function_call' in completion_message.additional_kwargs:
                    self._current_loop.completion \
                        = json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
                else:
                    self._current_loop.completion = response.generations[0][0].text
            else:
                self._current_loop.completion = completion_generation.text

            if response.llm_output:
                self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

@@ -71,6 +87,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        logging.error(error)
        self._agent_loops = []
        self._current_loop = None
        self._message_agent_thought = None

    def on_tool_start(
        self,

@@ -89,15 +106,29 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
    ) -> Any:
        """Run on agent action."""
        tool = action.tool
        tool_input = action.tool_input
        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
        thought = action.log[:action_name_position].strip() if action.log else ''
        tool_input = json.dumps({"query": action.tool_input}
                                if isinstance(action.tool_input, str) else action.tool_input)
        completion = None
        if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
                or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
            thought = action.log.strip()
            completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
        else:
            action_name_position = action.log.index("Action:") if action.log else -1
            thought = action.log[:action_name_position].strip() if action.log else ''

        if self._current_loop and self._current_loop.status == 'llm_end':
            self._current_loop.status = 'agent_action'
            self._current_loop.thought = thought
            self._current_loop.tool_name = tool
            self._current_loop.tool_input = tool_input
            if completion is not None:
                self._current_loop.completion = completion

            self._message_agent_thought = self.conversation_message_task.on_agent_start(
                self.current_chain,
                self._current_loop
            )

    def on_tool_end(
        self,

@@ -120,10 +151,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

            self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
            self.conversation_message_task.on_agent_end(
                self._message_agent_thought, self.model_name, self._current_loop
            )

            self._agent_loops.append(self._current_loop)
            self._current_loop = None
            self._message_agent_thought = None

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

@@ -132,6 +166,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        logging.error(error)
        self._agent_loops = []
        self._current_loop = None
        self._message_agent_thought = None

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""

@@ -141,10 +176,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            self._current_loop.completed = True
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
            self._current_loop.thought = '[DONE]'
            self._message_agent_thought = self.conversation_message_task.on_agent_start(
                self.current_chain,
                self._current_loop
            )

            self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
            self.conversation_message_task.on_agent_end(
                self._message_agent_thought, self.model_name, self._current_loop
            )

            self._agent_loops.append(self._current_loop)
            self._current_loop = None
            self._message_agent_thought = None
        elif not self._current_loop and self._agent_loops:
            self._agent_loops[-1].status = 'agent_finish'
@@ -1,3 +1,4 @@
import json
import logging

from typing import Any, Dict, List, Union, Optional

@@ -43,9 +44,11 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
        input_str: str,
        **kwargs: Any,
    ) -> None:
        tool_name = serialized.get('name')
        dataset_id = tool_name[len("dataset-"):]
        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
        # tool_name = serialized.get('name')
        input_dict = json.loads(input_str.replace("'", "\""))
        dataset_id = input_dict.get('dataset_id')
        query = input_dict.get('query')
        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))

    def on_tool_end(
        self,
@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
    tool_output: str = None

    prompt: str = None
    prompt_tokens: int = None
    prompt_tokens: int = 0
    completion: str = None
    completion_tokens: int = None
    completion_tokens: int = 0

    latency: float = None
@@ -1,20 +1,18 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI


class LLMCallbackHandler(BaseCallbackHandler):
    raise_error: bool = True

    def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
    def __init__(self, llm: BaseLanguageModel,
                 conversation_message_task: ConversationMessageTask):
        self.llm = llm
        self.llm_message = LLMMessage()
@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
        self._current_chain_result = None
        self._current_chain_message = None
        self.conversation_message_task = conversation_message_task
        self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
            llm_constant.agent_model_name,
            conversation_message_task
        )
        self.agent_callback = None

    def clear_chain_results(self) -> None:
        self._current_chain_result = None
        self._current_chain_message = None
        self.agent_loop_gather_callback_handler.current_chain = None
        if self.agent_callback:
            self.agent_callback.current_chain = None

    @property
    def always_verbose(self) -> bool:

@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
                started_at=time.perf_counter()
            )
            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
            if self.agent_callback:
                self.agent_callback.current_chain = self._current_chain_message

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
@@ -1,32 +0,0 @@
from typing import Optional

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain


class ChainBuilder:
    @classmethod
    def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
        return ToolChain(
            tool=tool,
            input_key=kwargs.get('input_key', 'input'),
            output_key=kwargs.get('output_key', 'tool_output'),
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
            SensitiveWordAvoidanceChain]:
        sensitive_words = tool_config.get("words", "")
        if tool_config.get("enabled", False) \
                and sensitive_words:
            return SensitiveWordAvoidanceChain(
                sensitive_words=sensitive_words.split(","),
                canned_response=tool_config.get("canned_response", ''),
                output_key="sensitive_word_avoidance_output",
                callbacks=[DifyStdOutCallbackHandler()],
                **kwargs
            )

        return None
@@ -1,111 +0,0 @@
"""Base classes for LLM-powered router chains."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Type, cast, NamedTuple

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator

from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException

from libs.json_in_md_parser import parse_and_check_json_markdown


class Route(NamedTuple):
    destination: Optional[str]
    next_inputs: Dict[str, Any]


class LLMRouterChain(Chain):
    """A router chain that uses an LLM chain to perform routing."""

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @root_validator()
    def validate_prompt(cls, values: dict) -> dict:
        prompt = values["llm_chain"].prompt
        if prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        super()._validate_outputs(outputs)
        if not isinstance(outputs["next_inputs"], dict):
            raise ValueError

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        output = cast(
            Dict[str, Any],
            self.llm_chain.predict_and_parse(**inputs),
        )
        return output

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> LLMRouterChain:
        """Convenience constructor."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)

    @property
    def output_keys(self) -> List[str]:
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any]) -> Route:
        result = self(inputs)
        return Route(result["destination"], result["next_inputs"])


class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for the output of the router chain in the multi-prompt chain."""

    default_destination: str = "DEFAULT"
    next_inputs_type: Type = str
    next_inputs_inner_key: str = "input"

    def parse(self, text: str) -> Dict[str, Any]:
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
            if (
                parsed["destination"].strip().lower()
                == self.default_destination.lower()
            ):
                parsed["destination"] = None
            else:
                parsed["destination"] = parsed["destination"].strip()
            return parsed
        except Exception as e:
            raise OutputParserException(
                f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
            )
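For reference, the shape of routing output RouterOutputParser expects — a fenced JSON block naming a destination prompt and a possibly revised input (values below are hypothetical):

```json
{
    "destination": "88888888-4444-4444-4444-121212121212",
    "next_inputs": "what is universal chat?"
}
```

parse() wraps next_inputs into {"input": ...} and maps a "DEFAULT" destination to None.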
@@ -1,110 +0,0 @@
from typing import Optional, List, cast

from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory

from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db
from models.dataset import Dataset


class MainChainBuilder:
    @classmethod
    def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                                rest_tokens: int,
                                conversation_message_task: ConversationMessageTask):
        first_input_key = "input"
        final_output_key = "output"

        chains = []

        chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)

        # agent mode
        tool_chains, chains_output_key = cls.get_agent_chains(
            tenant_id=tenant_id,
            agent_mode=agent_mode,
            rest_tokens=rest_tokens,
            memory=memory,
            conversation_message_task=conversation_message_task
        )
        chains += tool_chains

        if chains_output_key:
            final_output_key = chains_output_key

        if len(chains) == 0:
            return None

        for chain in chains:
            chain = cast(Chain, chain)
            chain.callbacks.append(chain_callback_handler)

        # build main chain
        overall_chain = SequentialChain(
            chains=chains,
            input_variables=[first_input_key],
            output_variables=[final_output_key],
            memory=memory,  # only for use the memory prompt input key
        )

        return overall_chain

    @classmethod
    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
                         rest_tokens: int,
                         memory: Optional[BaseChatMemory],
                         conversation_message_task: ConversationMessageTask):
        # agent mode
        chains = []
        if agent_mode and agent_mode.get('enabled'):
            tools = agent_mode.get('tools', [])

            pre_fixed_chains = []
            # agent_tools = []
            datasets = []
            for tool in tools:
                tool_type = list(tool.keys())[0]
                tool_config = list(tool.values())[0]
                if tool_type == 'sensitive-word-avoidance':
                    chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
                    if chain:
                        pre_fixed_chains.append(chain)
                elif tool_type == "dataset":
                    # get dataset from dataset id
                    dataset = db.session.query(Dataset).filter(
                        Dataset.tenant_id == tenant_id,
                        Dataset.id == tool_config.get("id")
                    ).first()

                    if dataset:
                        datasets.append(dataset)

            # add pre-fixed chains
            chains += pre_fixed_chains

            if len(datasets) > 0:
                # tool to chain
                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                    tenant_id=tenant_id,
                    datasets=datasets,
                    conversation_message_task=conversation_message_task,
                    rest_tokens=rest_tokens,
                    callbacks=[DifyStdOutCallbackHandler()]
                )
                chains.append(multi_dataset_router_chain)

        final_output_key = cls.get_chains_output_key(chains)

        return chains, final_output_key

    @classmethod
    def get_chains_output_key(cls, chains: List[Chain]):
        if len(chains) > 0:
            return chains[-1].output_keys[0]
        return None
@@ -1,198 +0,0 @@
import math
import re
from typing import Mapping, List, Dict, Any, Optional

from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import Extra

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like, \
no any other string out of markdown code snippet:
```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ a potentially modified version of the original input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.

<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT >>
"""


class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    dataset_tools: Mapping[str, DatasetTool]
    """Map of name to candidate chains that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        return ["text"]

    @classmethod
    def from_datasets(
        cls,
        tenant_id: str,
        datasets: List[Dataset],
        conversation_message_task: ConversationMessageTask,
        rest_tokens: int,
        **kwargs: Any,
    ):
        """Convenience constructor for instantiating from destination prompts."""
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                        else ('useful for when you want to answer queries about the ' + d.name))
                        for d in datasets]
        destinations_str = "\n".join(destinations)
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )

        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )

        router_chain = LLMRouterChain.from_llm(llm, router_prompt)
        dataset_tools = {}
        for dataset in datasets:
            # skip datasets without any available documents
            if dataset.available_document_count == 0:
                continue

            # fulfill description when it is empty
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name

            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
            if k == 0:
                continue

            dataset_tool = DatasetTool(
                name=f"dataset-{dataset.id}",
                description=description,
                k=k,
                dataset=dataset,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
            )

            dataset_tools[str(dataset.id)] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space
        return context_limit_tokens // segment_max_tokens

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        if len(self.dataset_tools) == 0:
            return {"text": ''}
        elif len(self.dataset_tools) == 1:
            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}

        route = self.router_chain.route(inputs)

        destination = ''
        if route.destination:
            pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
            match = re.search(pattern, route.destination, re.IGNORECASE)
            if match:
                destination = match.group()

        if not destination:
            return {"text": ''}
        elif destination in self.dataset_tools:
            return {"text": self.dataset_tools[destination].run(
                route.next_inputs['input']
            )}
        else:
            raise ValueError(
                f"Received invalid destination chain name '{destination}'"
            )
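To make the _dynamic_calc_retrieve_k heuristic above concrete, a worked example under assumed numbers (segment_max_tokens = 500, so the default context budget is DEFAULT_K * 500 = 1000 tokens, with CONTEXT_TOKENS_PERCENT = 0.3):

# rest_tokens = 800:   800 < 1000, so k = 800 // 500 = 1 (shrink below the default)
# rest_tokens = 3000:  budget = floor(3000 * 0.3) = 900 <= 1000, so k = DEFAULT_K = 2
# rest_tokens = 10000: budget = floor(10000 * 0.3) = 3000 > 1000, so k = 3000 // 500 = 6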
@@ -1,51 +0,0 @@
from typing import List, Dict, Optional, Any

from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool


class ToolChain(Chain):
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    tool: BaseTool

    @property
    def _chain_type(self) -> str:
        return "tool_chain"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        input = inputs[self.input_key]
        output = self.tool.run(input, self.verbose)
        return {self.output_key: output}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""
        input = inputs[self.input_key]
        output = await self.tool.arun(input, self.verbose)
        return {self.output_key: output}
@@ -1,4 +1,5 @@
import logging
import re
from typing import Optional, List, Union, Tuple

from langchain.base_language import BaseLanguageModel

@@ -8,20 +9,21 @@ from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.fake import FakeLLM
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT

@@ -69,18 +71,33 @@ class Completion:
            streaming=streaming
        )

        # build main chain include agent
        main_chain = MainChainBuilder.to_langchain_components(
        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            rest_tokens=rest_tokens_for_context_and_memory,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
            app_model_config=app_model_config
        )

        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)
        # parse sensitive_word_avoidance_chain
        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
        if sensitive_word_avoidance_chain:
            query = sensitive_word_avoidance_chain.run(query)

        # get agent executor
        agent_executor = orchestrator_rule_parser.to_agent_executor(
            conversation_message_task=conversation_message_task,
            memory=memory,
            rest_tokens=rest_tokens_for_context_and_memory,
            chain_callback=chain_callback
        )

        # run agent executor
        agent_execute_result = None
        if agent_executor:
            should_use_agent = agent_executor.should_use_agent(query)
            if should_use_agent:
                agent_execute_result = agent_executor.run(query)

        # run the final llm
        try:

@@ -90,7 +107,7 @@ class Completion:
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming

@@ -105,9 +122,20 @@ class Completion:

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
        # When no extra pre prompt is specified,
        # the output of the agent can be used directly as the main output content without calling LLM again
        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                and agent_execute_result.strategy != PlanningStrategy.ROUTER:
            final_llm = FakeLLM(response=agent_execute_result.output,
                                origin_llm=agent_execute_result.configuration.llm,
                                streaming=streaming)
            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
            response = final_llm.generate([[HumanMessage(content=query)]])
            return response

        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,

@@ -122,7 +150,7 @@ class Completion:
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            agent_execute_result=agent_execute_result,
            memory=memory
        )

@@ -142,16 +170,9 @@ class Completion:
    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
                            pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            agent_execute_result: Optional[AgentExecuteResult],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
        # disable template string in query
        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
        # if query_params:
        #     for query_param in query_params:
        #         if query_param not in inputs:
        #             inputs[query_param] = '{{' + query_param + '}}'

        if mode == 'completion':
            prompt_template = JinjaPromptTemplate.from_template(
                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.

@@ -165,18 +186,13 @@ When answer to user:
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
""" if agent_execute_result else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{{query}}\n"
            )

            if chain_output:
                inputs['context'] = chain_output
                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
                # if context_params:
                #     for context_param in context_params:
                #         if context_param not in inputs:
                #             inputs[context_param] = '{{' + context_param + '}}'
            if agent_execute_result:
                inputs['context'] = agent_execute_result.output

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(

@@ -206,8 +222,8 @@ And answer according to the language of the user's question.
            if pre_prompt_inputs:
                human_inputs.update(pre_prompt_inputs)

            if chain_output:
                human_inputs['context'] = chain_output
            if agent_execute_result:
                human_inputs['context'] = agent_execute_result.output
                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>

@@ -240,18 +256,10 @@ And answer according to the language of the user's question.
                             - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                histories = cls.get_history_messages_from_memory(memory, rest_tokens)

                # disable template string in query
                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
                # if histories_params:
                #     for histories_param in histories_params:
                #         if histories_param not in human_inputs:
                #             human_inputs[histories_param] = '{{' + histories_param + '}}'

                human_message_prompt += "\n\n" if human_message_prompt else ""
                human_message_prompt += "Here is the chat histories between human and assistant, " \
                                        "inside <histories></histories> XML tags.\n\n<histories>"
                human_message_prompt += histories + "</histories>"
                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
                human_message_prompt += histories + "\n</histories>"

            human_message_prompt += query_prompt

@@ -263,10 +271,13 @@ And answer according to the language of the user's question.

        messages.append(human_message)

        return messages, ['\nHuman:']
        for message in messages:
            message.content = re.sub(r'<\|.*?\|>', '', message.content)

        return messages, ['\nHuman:', '</histories>']

    @classmethod
    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
    def get_llm_callbacks(cls, llm: BaseLanguageModel,
                          streaming: bool,
                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)

@@ -277,8 +288,7 @@ And answer according to the language of the user's question.

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> \
            str:
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]

@@ -329,7 +339,7 @@ And answer according to the language of the user's question.
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=None,
            agent_execute_result=None,
            memory=None
        )

@@ -379,6 +389,7 @@ And answer according to the language of the user's question.
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            agent_execute_result=None,
            memory=None
        )
@@ -52,7 +52,7 @@ class ConversationMessageTask:
            message=self.message,
            conversation=self.conversation,
            chain_pub=False,  # disabled currently
            agent_thought_pub=False  # disabled currently
            agent_thought_pub=True
        )

    def init(self):

@@ -69,6 +69,7 @@ class ConversationMessageTask:
            "suggested_questions": self.app_model_config.suggested_questions_list,
            "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
            "more_like_this": self.app_model_config.more_like_this_dict,
            "sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
            "user_input_form": self.app_model_config.user_input_form_list,
        }

@@ -207,7 +208,28 @@ class ConversationMessageTask:

        self._pub_handler.pub_chain(message_chain)

    def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
    def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
        message_agent_thought = MessageAgentThought(
            message_id=self.message.id,
            message_chain_id=message_chain.id,
            position=agent_loop.position,
            thought=agent_loop.thought,
            tool=agent_loop.tool_name,
            tool_input=agent_loop.tool_input,
            message=agent_loop.prompt,
            answer=agent_loop.completion,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(message_agent_thought)
        db.session.flush()

        self._pub_handler.pub_agent_thought(message_agent_thought)

        return message_agent_thought

    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
                     agent_loop: AgentLoop):
        agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
        agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']

@@ -222,34 +244,18 @@ class ConversationMessageTask:
            agent_answer_unit_price
        )

        message_agent_loop = MessageAgentThought(
            message_id=self.message.id,
            message_chain_id=message_chain.id,
            position=agent_loop.position,
            thought=agent_loop.thought,
            tool=agent_loop.tool_name,
            tool_input=agent_loop.tool_input,
            observation=agent_loop.tool_output,
            tool_process_data='',  # currently not support
            message=agent_loop.prompt,
            message_token=loop_message_tokens,
            message_unit_price=agent_message_unit_price,
            answer=agent_loop.completion,
            answer_token=loop_answer_tokens,
            answer_unit_price=agent_answer_unit_price,
            latency=agent_loop.latency,
            tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
            total_price=loop_total_price,
            currency=llm_constant.model_currency,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(message_agent_loop)
        message_agent_thought.observation = agent_loop.tool_output
        message_agent_thought.tool_process_data = ''  # currently not support
        message_agent_thought.message_token = loop_message_tokens
        message_agent_thought.message_unit_price = agent_message_unit_price
        message_agent_thought.answer_token = loop_answer_tokens
        message_agent_thought.answer_unit_price = agent_answer_unit_price
        message_agent_thought.latency = agent_loop.latency
        message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
        message_agent_thought.total_price = loop_total_price
        message_agent_thought.currency = llm_constant.model_currency
        db.session.flush()

        self._pub_handler.pub_agent_thought(message_agent_loop)

    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
        dataset_query = DatasetQuery(
            dataset_id=dataset_query_obj.dataset_id,

@@ -346,16 +352,14 @@ class PubHandler:
        content = {
            'event': 'agent_thought',
            'data': {
                'id': message_agent_thought.id,
                'task_id': self._task_id,
                'message_id': self._message.id,
                'chain_id': message_agent_thought.message_chain_id,
                'agent_thought_id': message_agent_thought.id,
                'position': message_agent_thought.position,
                'thought': message_agent_thought.thought,
                'tool': message_agent_thought.tool,
                'tool_input': message_agent_thought.tool_input,
                'observation': message_agent_thought.observation,
                'answer': message_agent_thought.answer,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id
            }

@@ -388,6 +392,15 @@ class PubHandler:
    def _is_stopped(self):
        return redis_client.get(self._stopped_cache_key) is not None

    @classmethod
    def ping(cls, user: Union[Account | EndUser], task_id: str):
        content = {
            'event': 'ping'
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    @classmethod
    def stop(cls, user: Union[Account | EndUser], task_id: str):
        stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)

@@ -1,7 +1,8 @@
import tempfile
from pathlib import Path
from typing import List, Union
from typing import List, Union, Optional

import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

@@ -13,6 +14,9 @@ from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile

SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"


class FileExtractor:
    @classmethod

@@ -22,22 +26,41 @@ class FileExtractor:
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            input_file = Path(file_path)
            delimiter = '\n'
            if input_file.suffix == '.xlsx':
                loader = ExcelLoader(file_path)
            elif input_file.suffix == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif input_file.suffix in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif input_file.suffix in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif input_file.suffix == '.docx':
                loader = Docx2txtLoader(file_path)
            elif input_file.suffix == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)
            return cls.load_from_file(file_path, return_text, upload_file)

            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
        response = requests.get(url, headers={
            "User-Agent": USER_AGENT
        })

        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(url).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            with open(file_path, 'wb') as file:
                file.write(response.content)

            return cls.load_from_file(file_path, return_text)

    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
        input_file = Path(file_path)
        delimiter = '\n'
        if input_file.suffix == '.xlsx':
            loader = ExcelLoader(file_path)
        elif input_file.suffix == '.pdf':
            loader = PdfLoader(file_path, upload_file=upload_file)
        elif input_file.suffix in ['.md', '.markdown']:
            loader = MarkdownLoader(file_path, autodetect_encoding=True)
        elif input_file.suffix in ['.htm', '.html']:
            loader = HTMLLoader(file_path)
        elif input_file.suffix == '.docx':
            loader = Docx2txtLoader(file_path)
        elif input_file.suffix == '.csv':
            loader = CSVLoader(file_path, autodetect_encoding=True)
        else:
            # txt
            loader = TextLoader(file_path, autodetect_encoding=True)

        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

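For orientation, a minimal sketch of how the new load_from_url entry point might be called; the URL is illustrative only. It downloads the resource to a temporary file and then dispatches on the file suffix exactly as load_from_file does above.

from core.data_loader.file_extractor import FileExtractor

# Fetch a remote PDF and get its extracted text back as a single string.
text = FileExtractor.load_from_url("https://example.com/paper.pdf", return_text=True)
print(text[:200])
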
59  api/core/llm/fake.py  Normal file
@@ -0,0 +1,59 @@
import time
from typing import List, Optional, Any, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel


class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    streaming: bool = False
    """Whether to stream the results or not."""
    response: str
    origin_llm: Optional[BaseLanguageModel] = None

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the canned response."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                    time.sleep(0.01)

        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {"token_usage": {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
        }}
        return ChatResult(generations=[generation], llm_output=llm_output)

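A minimal sketch of how this FakeLLM might be driven in a test; the canned "pong" response and the message content are made up, not part of the commit.

from langchain.schema import HumanMessage

from core.llm.fake import FakeLLM

llm = FakeLLM(response="pong")  # origin_llm left unset, so get_num_tokens() falls back to 0
result = llm._generate([HumanMessage(content="ping")])
assert result.generations[0].message.content == "pong"
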
@@ -10,6 +10,9 @@ from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName


AZURE_OPENAI_API_VERSION = '2023-07-01-preview'


class AzureProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
        return []

@@ -50,9 +53,10 @@ class AzureProvider:
        """
        config = self.get_provider_api_key(model_id=model_id)
        config['openai_api_type'] = 'azure'
        config['openai_api_version'] = AZURE_OPENAI_API_VERSION
        if model_id == 'text-embedding-ada-002':
            config['deployment'] = model_id.replace('.', '') if model_id else None
            config['chunk_size'] = 1
            config['chunk_size'] = 16
        else:
            config['deployment_name'] = model_id.replace('.', '') if model_id else None
        return config

@@ -69,7 +73,7 @@ class AzureProvider:
        except:
            config = {
                'openai_api_type': 'azure',
                'openai_api_version': '2023-03-15-preview',
                'openai_api_version': AZURE_OPENAI_API_VERSION,
                'openai_api_base': '',
                'openai_api_key': ''
            }

@@ -78,7 +82,7 @@ class AzureProvider:
        if not config.get('openai_api_key'):
            config = {
                'openai_api_type': 'azure',
                'openai_api_version': '2023-03-15-preview',
                'openai_api_version': AZURE_OPENAI_API_VERSION,
                'openai_api_base': '',
                'openai_api_key': ''
            }

@@ -100,7 +104,7 @@ class AzureProvider:
                raise ValueError('Config must be a object.')

            if 'openai_api_version' not in config:
                config['openai_api_version'] = '2023-03-15-preview'
                config['openai_api_version'] = AZURE_OPENAI_API_VERSION

            self.check_embedding_model(credentials=config)
        except ValidateFailedError as e:

@@ -119,7 +123,7 @@ class AzureProvider:
        """
        return json.dumps({
            'openai_api_type': 'azure',
            'openai_api_version': '2023-03-15-preview',
            'openai_api_version': AZURE_OPENAI_API_VERSION,
            'openai_api_base': config['openai_api_base'],
            'openai_api_key': self.encrypt_token(config['openai_api_key'])
        })

@@ -1,7 +1,8 @@
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Tuple, Union

from pydantic import root_validator

@@ -9,6 +10,11 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableAzureChatOpenAI(AzureChatOpenAI):
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API. Defaults to 5s connect / 300s read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""

@@ -71,3 +77,43 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
        params['model_kwargs'] = model_kwargs

        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            function_call: Optional[dict] = None
            for stream_resp in self.completion_with_retry(
                messages=message_dicts, **params
            ):
                if len(stream_resp["choices"]) > 0:
                    role = stream_resp["choices"][0]["delta"].get("role", role)
                    token = stream_resp["choices"][0]["delta"].get("content") or ""
                    inner_completion += token
                    _function_call = stream_resp["choices"][0]["delta"].get("function_call")
                    if _function_call:
                        if function_call is None:
                            function_call = _function_call
                        else:
                            function_call["arguments"] += _function_call["arguments"]
                    if run_manager:
                        run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                    "function_call": function_call,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

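A self-contained illustration of the delta-accumulation pattern that _generate uses when streaming: each chunk may carry a content fragment and/or a function_call fragment whose arguments string must be concatenated. The sample chunks below are hypothetical, shaped like what the OpenAI SDK yields, not a real API response.

chunks = [
    {"choices": [{"delta": {"role": "assistant", "content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}]},
    {"choices": [{"delta": {"function_call": {"name": "f", "arguments": '{"x"'}}}]},
    {"choices": [{"delta": {"function_call": {"arguments": ": 1}"}}}]},
]

inner_completion, function_call = "", None
for chunk in chunks:
    delta = chunk["choices"][0]["delta"]
    inner_completion += delta.get("content") or ""
    _fc = delta.get("function_call")
    if _fc:
        if function_call is None:
            function_call = _fc
        else:
            # later chunks only extend the arguments string
            function_call["arguments"] += _fc["arguments"]

assert inner_completion == "Hello"
assert function_call == {"name": "f", "arguments": '{"x": 1}'}
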
@@ -1,7 +1,7 @@
from langchain.callbacks.manager import Callbacks
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple

from pydantic import root_validator

@@ -11,6 +11,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureOpenAI(AzureOpenAI):
    openai_api_type: str = "azure"
    openai_api_version: str = ""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API. Defaults to 5s connect / 300s read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:

@@ -1,8 +1,10 @@
from typing import List, Optional, Any, Dict

from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator

from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions

@@ -12,6 +14,14 @@ class StreamableChatAnthropic(ChatAnthropic):
    Wrapper around Anthropic's large language model.
    """

    default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)

    @root_validator()
    def prepare_params(cls, values: Dict) -> Dict:
        values['model_name'] = values.get('model')
        values['max_tokens'] = values.get('max_tokens_to_sample')
        return values

    @handle_anthropic_exceptions
    def generate(
        self,

@@ -37,3 +47,16 @@ class StreamableChatAnthropic(ChatAnthropic):
        del params['presence_penalty']

        return params

    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{self.HUMAN_PROMPT} {message.content}"
        elif isinstance(message, AIMessage):
            message_text = f"{self.AI_PROMPT} {message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"<admin>{message.content}</admin>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

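For reference, a plain-Python sketch of the prompt text this override produces, assuming Anthropic's standard "\n\nHuman:" / "\n\nAssistant:" markers; the notable difference from the stock conversion is that system messages are wrapped in <admin> tags.

HUMAN_PROMPT = "\n\nHuman:"   # Anthropic's standard marker (assumed)
AI_PROMPT = "\n\nAssistant:"  # Anthropic's standard marker (assumed)

def to_text(role: str, content: str) -> str:
    # Mirrors the branches of _convert_one_message_to_text above.
    if role == "system":
        return f"<admin>{content}</admin>"
    if role == "human":
        return f"{HUMAN_PROMPT} {content}"
    if role == "ai":
        return f"{AI_PROMPT} {content}"
    raise ValueError(role)

prompt = "".join([to_text("system", "Be terse."), to_text("human", "Hi")])
assert prompt == "<admin>Be terse.</admin>\n\nHuman: Hi"
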
@@ -3,7 +3,7 @@ import os

from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Union, Tuple

from pydantic import root_validator

@@ -11,6 +11,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableChatOpenAI(ChatOpenAI):
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API. Defaults to 5s connect / 300s read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:

@@ -2,7 +2,7 @@ import os

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping
from typing import Optional, List, Dict, Any, Mapping, Union, Tuple
from langchain import OpenAI
from pydantic import root_validator

@@ -10,6 +10,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableOpenAI(OpenAI):
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API. Defaults to 5s connect / 300s read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:

277  api/core/orchestrator_rule_parser.py  Normal file
@@ -0,0 +1,277 @@
import math
from typing import Optional

from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field

from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig


class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""

    def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
        self.tenant_id = tenant_id
        self.app_model_config = app_model_config
        self.agent_summary_model_name = "gpt-3.5-turbo-16k"

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
            -> Optional[AgentExecutor]:
        if not self.app_model_config.agent_mode_dict:
            return None

        agent_mode_config = self.app_model_config.agent_mode_dict
        model_dict = self.app_model_config.model_dict

        chain = None
        if agent_mode_config and agent_mode_config.get('enabled'):
            tool_configs = agent_mode_config.get('tools', [])
            agent_model_name = model_dict.get('name', 'gpt-4')

            # add agent callback to record agent thoughts
            agent_callback = AgentLoopGatherCallbackHandler(
                model_name=agent_model_name,
                conversation_message_task=conversation_message_task
            )

            chain_callback.agent_callback = agent_callback

            agent_llm = LLMBuilder.to_llm(
                tenant_id=self.tenant_id,
                model_name=agent_model_name,
                temperature=0,
                max_tokens=1500,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()]
            )

            planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))

            # only OpenAI chat models (including Azure) support function call; fall back to ReACT otherwise
            if not isinstance(agent_llm, ChatOpenAI) \
                    and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
                planning_strategy = PlanningStrategy.REACT

            summary_llm = LLMBuilder.to_llm(
                tenant_id=self.tenant_id,
                model_name=self.agent_summary_model_name,
                temperature=0,
                max_tokens=500,
                callbacks=[DifyStdOutCallbackHandler()]
            )

            tools = self.to_tools(
                tool_configs=tool_configs,
                conversation_message_task=conversation_message_task,
                model_name=self.agent_summary_model_name,
                rest_tokens=rest_tokens,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()]
            )

            if len(tools) == 0:
                return None

            agent_configuration = AgentConfiguration(
                strategy=planning_strategy,
                llm=agent_llm,
                tools=tools,
                summary_llm=summary_llm,
                memory=memory,
                callbacks=[chain_callback, agent_callback],
                max_iterations=10,
                max_execution_time=400.0,
                early_stopping_method="generate"
            )

            return AgentExecutor(agent_configuration)

        return chain

    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
            -> Optional[SensitiveWordAvoidanceChain]:
        """
        Convert app sensitive word avoidance config to chain

        :param kwargs:
        :return:
        """
        if not self.app_model_config.sensitive_word_avoidance_dict:
            return None

        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
        sensitive_words = sensitive_word_avoidance_config.get("words", "")
        if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
            return SensitiveWordAvoidanceChain(
                sensitive_words=sensitive_words.split(","),
                canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
                output_key="sensitive_word_avoidance_output",
                callbacks=callbacks,
                **kwargs
            )

        return None

    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
                 model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
        """
        Convert app agent tool configs to tools

        :param rest_tokens:
        :param tool_configs: app agent tool configs
        :param model_name:
        :param conversation_message_task:
        :param callbacks:
        :return:
        """
        tools = []
        for tool_config in tool_configs:
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]
            if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
                continue

            tool = None
            if tool_type == "dataset":
                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(model_name)
            elif tool_type == "google_search":
                tool = self.to_google_search_tool()
            elif tool_type == "wikipedia":
                tool = self.to_wikipedia_tool()

            if tool:
                tool.callbacks.extend(callbacks)
                tools.append(tool)

        return tools

    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
                                  rest_tokens: int) \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
        :param rest_tokens:
        :param tool_config:
        :param conversation_message_task:
        :return:
        """
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        if dataset and dataset.available_document_count == 0:
            return None

        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
        )

        return tool

    def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]:
        """
        A tool for reading web pages

        :return:
        """
        summary_llm = LLMBuilder.to_llm(
            tenant_id=self.tenant_id,
            model_name=model_name,
            temperature=0,
            max_tokens=500,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        tool = WebReaderTool(
            llm=summary_llm,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_google_search_tool(self) -> Optional[BaseTool]:
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_wikipedia_tool(self) -> Optional[BaseTool]:
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        DEFAULT_K = 2
        CONTEXT_TOKENS_PERCENT = 0.3
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space
        return context_limit_tokens // segment_max_tokens

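To make the branch logic of _dynamic_calc_retrieve_k concrete, a worked example with hypothetical numbers (segment size 500 tokens, 6000 tokens of budget left):

DEFAULT_K = 2
segment_max_tokens = 500   # hypothetical segmentation rule
rest_tokens = 6000         # hypothetical remaining context budget

# 6000 >= 500 * 2, so the small-budget branch is skipped.
context_limit_tokens = int(rest_tokens * 0.3)  # 1800
# 1800 > 500 * 2, so k expands beyond the default of 2:
assert context_limit_tokens // segment_max_tokens == 3
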
@@ -1,87 +0,0 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)
        return str("\n".join([document.page_content for document in documents]))

105  api/core/tool/dataset_retriever_tool.py  Normal file
@@ -0,0 +1,105 @@
import re
from typing import Type

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset


class DatasetRetrieverToolInput(BaseModel):
    dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""
    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    k: int = 3

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description.replace('\n', '').replace('\r', '')
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description += '\nID of dataset MUST be ' + dataset.id
        return cls(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, dataset_id: str, query: str) -> str:
        pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
        match = re.search(pattern, dataset_id, re.IGNORECASE)
        if match:
            dataset_id = match.group()

        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            return f'[{self.name} failed to find dataset with id {dataset_id}.]'

        if dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            if self.k > 0:
                documents = vector_index.search(
                    query,
                    search_type='similarity',
                    search_kwargs={
                        'k': self.k
                    }
                )
            else:
                documents = []

        hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()

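A minimal usage sketch: this requires a Flask application context and a persisted Dataset row; `some_dataset` is a stand-in name, not from the commit.

tool = DatasetRetrieverTool.from_dataset(dataset=some_dataset, k=3)
# The agent supplies both arguments declared in DatasetRetrieverToolInput:
print(tool.run({"dataset_id": some_dataset.id, "query": "What is the refund policy?"}))
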
63  api/core/tool/provider/base.py  Normal file
@@ -0,0 +1,63 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional

from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName


class BaseToolProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    @abstractmethod
    def get_provider_name(self) -> ToolProviderName:
        raise NotImplementedError

    @abstractmethod
    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_to_func_kwargs(self) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_validate(self, credentials: dict):
        raise NotImplementedError

    def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
        """
        Returns the Provider instance for the given tenant_id and tool_name.
        """
        query = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == self.tenant_id,
            ToolProvider.tool_name == self.get_provider_name().value
        )

        if must_enabled:
            query = query.filter(ToolProvider.is_enabled == True)

        return query.first()

    def encrypt_token(self, token) -> str:
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)

        if obfuscated:
            return self._obfuscated_token(token)

        return token

    def _obfuscated_token(self, token: str) -> str:
        return token[:6] + '*' * (len(token) - 8) + token[-2:]

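For illustration, _obfuscated_token keeps the first six and last two characters and masks everything between; a standalone restatement of the rule, with a made-up token value:

def obfuscate(token: str) -> str:
    # same masking rule as BaseToolProvider._obfuscated_token
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

masked = obfuscate("sk-abcdef1234567890")
assert masked.startswith("sk-abc") and masked.endswith("90")
assert masked.count("*") == len("sk-abcdef1234567890") - 8
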
2  api/core/tool/provider/errors.py  Normal file
@@ -0,0 +1,2 @@
class ToolValidateFailedError(Exception):
    description = "Tool Provider Validate failed"

77  api/core/tool/provider/serpapi_provider.py  Normal file
@@ -0,0 +1,77 @@
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName


class SerpAPIToolProvider(BaseToolProvider):
    def get_provider_name(self) -> ToolProviderName:
        """
        Returns the name of the provider.

        :return:
        """
        return ToolProviderName.SERPAPI

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for SerpAPI as a dictionary.

        :param obfuscated: obfuscate credentials if True
        :return:
        """
        tool_provider = self.get_provider(must_enabled=True)
        if not tool_provider:
            return None

        credentials = tool_provider.credentials
        if not credentials:
            return None

        if credentials.get('api_key'):
            credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)

        return credentials

    def credentials_to_func_kwargs(self) -> Optional[dict]:
        """
        Returns the credentials function kwargs as a dictionary.

        :return:
        """
        credentials = self.get_credentials()
        if not credentials:
            return None

        return {
            'serpapi_api_key': credentials.get('api_key')
        }

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :return:
        """
        if 'api_key' not in credentials or not credentials.get('api_key'):
            raise ToolValidateFailedError("SerpAPI api_key is required.")

        api_key = credentials.get('api_key')

        try:
            OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
        except Exception as e:
            raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))

    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
        return credentials

43  api/core/tool/provider/tool_provider_service.py  Normal file
@@ -0,0 +1,43 @@
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider


class ToolProviderService:

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self._init_provider(tenant_id, provider_name)

    def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
        if provider_name == 'serpapi':
            return SerpAPIToolProvider(tenant_id)
        else:
            raise Exception('tool provider {} not found'.format(provider_name))

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for Tool as a dictionary.

        :param obfuscated:
        :return:
        """
        return self.provider.get_credentials(obfuscated)

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :raises: ValidateFailedError
        """
        return self.provider.credentials_validate(credentials)

    def encrypt_credentials(self, credentials: dict):
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        return self.provider.encrypt_credentials(credentials)

51  api/core/tool/serpapi_wrapper.py  Normal file
@@ -0,0 +1,51 @@
from langchain import SerpAPIWrapper
from pydantic import Field, BaseModel


class OptimizedSerpAPIInput(BaseModel):
    query: str = Field(..., description="search query.")


class OptimizedSerpAPIWrapper(SerpAPIWrapper):

    @staticmethod
    def _process_response(res: dict, num_results: int = 5) -> str:
        """Process response from SerpAPI."""
        if "error" in res.keys():
            raise ValueError(f"Got error from SerpAPI: {res['error']}")
        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
            res["answer_box"] = res["answer_box"][0]
        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res.keys()
            and "snippet_highlighted_words" in res["answer_box"].keys()
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif (
            "sports_results" in res.keys()
            and "game_spotlight" in res["sports_results"].keys()
        ):
            toret = res["sports_results"]["game_spotlight"]
        elif (
            "shopping_results" in res.keys()
            and "title" in res["shopping_results"][0].keys()
        ):
            toret = res["shopping_results"][:3]
        elif (
            "knowledge_graph" in res.keys()
            and "description" in res["knowledge_graph"].keys()
        ):
            toret = res["knowledge_graph"]["description"]
        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
            toret = ""
            for result in res["organic_results"][:num_results]:
                if "link" in result:
                    toret += "----------------\nlink: " + result["link"] + "\n"
                if "snippet" in result:
                    toret += "snippet: " + result["snippet"] + "\n"
        else:
            toret = "No good search result found"
        return "search result:\n" + toret

419  api/core/tool/web_reader_tool.py  Normal file
@@ -0,0 +1,419 @@
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Type

import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.base_language import BaseLanguageModel
from langchain.chains.summarize import load_summarize_chain
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:

{text}
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires extracting the summarizing content of the webpage, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character. "
                    "Use when the first response was truncated "
                    "and you want to continue reading the page. "
                    "The value cannot exceed 24000.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "web_reader"
    args_schema: Type[BaseModel] = WebReaderToolInput
    description: str = "use this to read a website. " \
                       "If you can answer the question based on the information provided, " \
                       "there is no need to use."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    llm: BaseLanguageModel

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        try:
            if not self.page_contents or self.url != url:
                page_contents = get_url(url)
                self.page_contents = page_contents
                self.url = url
            else:
                page_contents = self.page_contents
        except Exception as e:
            return f'Reading this website failed, caused by: {str(e)}.'

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )

            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            # only use the first 5 docs
            if len(docs) > 5:
                docs = docs[:5]

            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
            try:
                page_contents = chain.run(docs)
                # todo use cache
            except Exception as e:
                return f'Reading this website failed, caused by: {str(e)}.'
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]

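A quick illustration of the cursor-paging contract, using page_result as defined above; the text and numbers are arbitrary:

text = "x" * 10_000
first = page_result(text, cursor=0, max_length=4000)       # characters 0-3999
second = page_result(text, cursor=4000, max_length=4000)   # characters 4000-7999
assert len(first) == len(second) == 4000
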
def get_url(url: str) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))

    if head_response.status_code != 200:
        return "URL returned status code {}.".format(head_response.status_code)

    # check content-type
    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
        return FileExtractor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
        return get_url_from_newspaper3k(url)

    res = FULL_TEMPLATE.format(
        title=a['title'],
        authors=a['byline'],
        publish_date=a['date'],
        top_image="",
        text=a['plain_text'] if a['plain_text'] else "",
    )

    return res


def get_url_from_newspaper3k(url: str) -> str:
    a = Article(url)
    a.download()
    a.parse()

    res = FULL_TEMPLATE.format(
        title=a.title,
        authors=a.authors,
        publish_date=a.publish_date,
        top_image=a.top_image,
        text=a.text,
    )

    return res


def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, "r", encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())

    # Deleting files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None
    }
    # Populate article fields from readability fields where present
    if input_json:
        if "title" in input_json and input_json["title"]:
            article_json["title"] = input_json["title"]
        if "byline" in input_json and input_json["byline"]:
            article_json["byline"] = input_json["byline"]
        if "date" in input_json and input_json["date"]:
            article_json["date"] = input_json["date"]
        if "content" in input_json and input_json["content"]:
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if "textContent" in input_json and input_json["textContent"]:
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None


@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, 'html.parser')
    # Select all lists
    list_elements = soup.find_all(['ul', 'ol'])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements, and normalise it
    plain_text = normalise_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, 'html.parser')
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes)
                for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalise the extracted text string to a canonical representation
        plain_text = normalise_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalise_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type, call recursively on child nodes, replacing them
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate(
            [c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(
            stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalise_text(text):
    """Normalise unicode and whitespace."""
    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
    text = strip_control_characters(text)
    text = normalise_unicode(text)
    text = normalise_whitespace(text)
    return text


def strip_control_characters(text):
|
||||
"""Strip out unicode control characters which might break the parsing."""
|
||||
# Unicode control characters
|
||||
# [Cc]: Other, Control [includes new lines]
|
||||
# [Cf]: Other, Format
|
||||
# [Cn]: Other, Not Assigned
|
||||
# [Co]: Other, Private Use
|
||||
# [Cs]: Other, Surrogate
|
||||
control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
|
||||
retained_chars = ['\t', '\n', '\r', '\f']
|
||||
|
||||
# Remove non-printing control characters
|
||||
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
|
||||
|
||||
|
||||
def normalise_unicode(text):
|
||||
"""Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
normal_form = "NFKC"
|
||||
text = unicodedata.normalize(normal_form, text)
|
||||
return text
|
||||
|
||||
|
||||
def normalise_whitespace(text):
|
||||
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
|
||||
text = regex.sub(r"\s+", " ", text)
|
||||
# Remove leading and trailing whitespace
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
def is_leaf(element):
|
||||
return (element.name in ['p', 'li'])
|
||||
|
||||
|
||||
def is_text(element):
|
||||
return isinstance(element, NavigableString)
|
||||
|
||||
|
||||
def is_non_printing(element):
|
||||
return any(isinstance(element, _e) for _e in [Comment, CData])
|
||||
|
||||
|
||||
def add_content_digest(element):
|
||||
if not is_text(element):
|
||||
element["data-content-digest"] = content_digest(element)
|
||||
return element
|
||||
|
||||
|
||||
def content_digest(element):
|
||||
if is_text(element):
|
||||
# Hash
|
||||
trimmed_string = element.string.strip()
|
||||
if trimmed_string == "":
|
||||
digest = ""
|
||||
else:
|
||||
digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
|
||||
else:
|
||||
contents = element.contents
|
||||
num_contents = len(contents)
|
||||
if num_contents == 0:
|
||||
# No hash when no child elements exist
|
||||
digest = ""
|
||||
elif num_contents == 1:
|
||||
# If single child, use digest of child
|
||||
digest = content_digest(contents[0])
|
||||
else:
|
||||
# Build content digest from the "non-empty" digests of child nodes
|
||||
digest = hashlib.sha256()
|
||||
child_digests = list(
|
||||
filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
||||
for child in child_digests:
|
||||
digest.update(child.encode('utf-8'))
|
||||
digest = digest.hexdigest()
|
||||
return digest
|
|
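To see what these helpers do end to end, push a small HTML fragment through extract_text_blocks_as_plain_text. This sketch is illustrative only (it is not part of the commit) and assumes the imports already present at the top of this module (BeautifulSoup, unicodedata, regex):

    # Non-breaking space and repeated spaces in a paragraph, plus a two-item list
    html = "<p>Hello,\u00a0  world!</p><ul><li>first</li><li>second</li></ul>"
    blocks = extract_text_blocks_as_plain_text(html)
    # NFKC normalisation maps the non-breaking space to a plain space, whitespace
    # runs collapse, and the list folds into a single "* ..." paragraph:
    # [{'text': 'Hello, world!'}, {'text': '* first, * second,'}]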
@ -0,0 +1,32 @@
"""add is_universal in apps

Revision ID: 2beac44e5f5f
Revises: d3d503a3471c
Create Date: 2023-07-07 12:11:29.156057

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '2beac44e5f5f'
down_revision = 'a5b56fb053ef'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('apps', schema=None) as batch_op:
        batch_op.add_column(sa.Column('is_universal', sa.Boolean(), server_default=sa.text('false'), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('apps', schema=None) as batch_op:
        batch_op.drop_column('is_universal')

    # ### end Alembic commands ###
44
api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
Normal file
@ -0,0 +1,44 @@
"""add tool providers

Revision ID: 7ce5a52e4eee
Revises: 2beac44e5f5f
Create Date: 2023-07-10 10:26:50.074515

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '7ce5a52e4eee'
down_revision = '2beac44e5f5f'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tool_providers',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', postgresql.UUID(), nullable=False),
        sa.Column('tool_name', sa.String(length=40), nullable=False),
        sa.Column('encrypted_credentials', sa.Text(), nullable=True),
        sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )
    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column('sensitive_word_avoidance')

    op.drop_table('tool_providers')
    # ### end Alembic commands ###
@ -40,6 +40,7 @@ class App(db.Model):
     api_rph = db.Column(db.Integer, nullable=False)
     is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
+    is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
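The new column exists so the hidden universal-chat app can be excluded from ordinary app listings. A minimal query sketch (not part of the commit; tenant_id is assumed to be in scope):

    # List a workspace's apps while hiding the universal chat app
    visible_apps = db.session.query(App).filter(
        App.tenant_id == tenant_id,
        App.is_universal == False
    ).all()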
@ -88,6 +89,7 @@ class AppModelConfig(db.Model):
     user_input_form = db.Column(db.Text)
     pre_prompt = db.Column(db.Text)
     agent_mode = db.Column(db.Text)
+    sensitive_word_avoidance = db.Column(db.Text)

     @property
     def app(self):
@ -116,14 +118,35 @@ class AppModelConfig(db.Model):
     def more_like_this_dict(self) -> dict:
         return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}

+    @property
+    def sensitive_word_avoidance_dict(self) -> dict:
+        return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \
+            else {"enabled": False, "words": [], "canned_response": []}
+
     @property
     def user_input_form_list(self) -> dict:
         return json.loads(self.user_input_form) if self.user_input_form else []

     @property
     def agent_mode_dict(self) -> dict:
-        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "tools": []}
+        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}

+    def to_dict(self) -> dict:
+        return {
+            "provider": "",
+            "model_id": "",
+            "configs": {},
+            "opening_statement": self.opening_statement,
+            "suggested_questions": self.suggested_questions_list,
+            "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
+            "speech_to_text": self.speech_to_text_dict,
+            "more_like_this": self.more_like_this_dict,
+            "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
+            "model": self.model_dict,
+            "user_input_form": self.user_input_form_list,
+            "pre_prompt": self.pre_prompt,
+            "agent_mode": self.agent_mode_dict
+        }

 class RecommendedApp(db.Model):
     __tablename__ = 'recommended_apps'
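The new sensitive_word_avoidance_dict property falls back to a disabled default when the column is empty. A sketch of both paths (illustrative, not part of the commit):

    cfg = AppModelConfig()
    assert cfg.sensitive_word_avoidance_dict == {"enabled": False, "words": [], "canned_response": []}

    cfg.sensitive_word_avoidance = '{"enabled": true, "words": "spoiler", "canned_response": "No comment."}'
    assert cfg.sensitive_word_avoidance_dict["enabled"] is True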
@ -237,6 +260,9 @@ class Conversation(db.Model):
                     if 'speech_to_text' in override_model_configs else {"enabled": False}
                 model_config['more_like_this'] = override_model_configs['more_like_this'] \
                     if 'more_like_this' in override_model_configs else {"enabled": False}
+                model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \
+                    if 'sensitive_word_avoidance' in override_model_configs \
+                    else {"enabled": False, "words": [], "canned_response": []}
                 model_config['user_input_form'] = override_model_configs['user_input_form']
             else:
                 model_config['configs'] = override_model_configs
@ -253,6 +279,7 @@ class Conversation(db.Model):
             model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
             model_config['speech_to_text'] = app_model_config.speech_to_text_dict
             model_config['more_like_this'] = app_model_config.more_like_this_dict
+            model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
             model_config['user_input_form'] = app_model_config.user_input_form_list

         model_config['model_id'] = self.model_id
@ -393,6 +420,11 @@ class Message(db.Model):
     def in_debug_mode(self):
         return self.override_model_configs is not None

+    @property
+    def agent_thoughts(self):
+        return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id)\
+            .order_by(MessageAgentThought.position.asc()).all()
+

 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
47
api/models/tool.py
Normal file
@ -0,0 +1,47 @@
import json
from enum import Enum

from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db


class ToolProviderName(Enum):
    SERPAPI = 'serpapi'

    @staticmethod
    def value_of(value):
        for member in ToolProviderName:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class ToolProvider(db.Model):
    __tablename__ = 'tool_providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    tool_name = db.Column(db.String(40), nullable=False)
    encrypted_credentials = db.Column(db.Text, nullable=True)
    is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
    def credentials_is_set(self):
        """
        Returns True if the encrypted_credentials is not None, indicating that the token is set.
        """
        return self.encrypted_credentials is not None

    @property
    def credentials(self):
        """
        Returns the decrypted config.
        """
        return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None
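A hypothetical usage sketch (not part of the commit) for checking whether a tenant has configured SerpAPI before enabling the google_search tool; tenant_id and the 'api_key' field name are assumptions:

    provider = db.session.query(ToolProvider).filter(
        ToolProvider.tenant_id == tenant_id,
        ToolProvider.tool_name == ToolProviderName.SERPAPI.value,
        ToolProvider.is_enabled == True
    ).first()

    if provider and provider.credentials_is_set:
        api_key = provider.credentials.get('api_key')  # key name is an assumption

Note that despite its name, the credentials property only JSON-decodes the stored text; any actual decryption must happen before the value reaches encrypted_credentials.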
@ -10,8 +10,8 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.230
-openai~=0.27.5
+langchain==0.0.239
+openai~=0.27.8
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
 python-dotenv==1.0.0
@ -36,3 +36,8 @@ pypdfium2==4.16.0
 resend~=0.5.1
 pyjwt~=2.6.0
 anthropic~=0.3.4
+newspaper3k==0.2.8
+google-api-python-client==2.90.0
+wikipedia==1.4.0
+readabilipy==0.2.0
+google-search-results==2.4.2
@ -1,6 +1,7 @@
 import re
 import uuid

+from core.agent.agent_executor import PlanningStrategy
 from core.constant import llm_constant
 from models.account import Account
 from services.dataset_service import DatasetService
@ -31,6 +32,16 @@ MODELS_BY_APP_MODE = {
     ]
 }

+SUPPORT_AGENT_MODELS = [
+    "gpt-4",
+    "gpt-4-32k",
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-16k",
+]
+
+SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia"]
+
+
 class AppModelConfigService:
     @staticmethod
     def is_dataset_exists(account: Account, dataset_id: str) -> bool:
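For orientation, this is the shape of an agent_mode config that the validation below accepts. Values are illustrative only, and "router" assumes PlanningStrategy.ROUTER.value == "router":

    agent_mode = {
        "enabled": True,
        "strategy": "router",
        "tools": [
            {"web_reader": {"enabled": True}},
            {"dataset": {"enabled": True, "id": "a-dataset-uuid"}},  # dataset entries also require an id
        ],
    }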
@ -58,7 +69,8 @@ class AppModelConfigService:
         if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
                 llm_constant.max_context_token_length[model_name]:
             raise ValueError(
-                "max_tokens must be an integer greater than 0 and not exceeding the maximum value of the corresponding model")
+                "max_tokens must be an integer greater than 0 "
+                "and not exceeding the maximum value of the corresponding model")

         # temperature
         if 'temperature' not in cp:
@ -148,11 +160,6 @@ class AppModelConfigService:
         if not isinstance(config["speech_to_text"]["enabled"], bool):
             raise ValueError("enabled in speech_to_text must be of boolean type")

-        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
-
-        if config["speech_to_text"]["enabled"] and provider_name != 'openai':
-            raise ValueError("provider not support speech to text")
-
         # more_like_this
         if 'more_like_this' not in config or not config["more_like_this"]:
@ -169,6 +176,33 @@ class AppModelConfigService:
         if not isinstance(config["more_like_this"]["enabled"], bool):
             raise ValueError("enabled in more_like_this must be of boolean type")

+        # sensitive_word_avoidance
+        if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
+            config["sensitive_word_avoidance"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["sensitive_word_avoidance"], dict):
+            raise ValueError("sensitive_word_avoidance must be of dict type")
+
+        if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
+            config["sensitive_word_avoidance"]["enabled"] = False
+
+        if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
+            raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
+
+        if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
+            config["sensitive_word_avoidance"]["words"] = ""
+
+        if not isinstance(config["sensitive_word_avoidance"]["words"], str):
+            raise ValueError("words in sensitive_word_avoidance must be of string type")
+
+        if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
+            config["sensitive_word_avoidance"]["canned_response"] = ""
+
+        if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
+            raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
+
         # model
         if 'model' not in config:
             raise ValueError("model is required")
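A config fragment that passes these checks might look like the sketch below (illustrative only; the validator requires strings but does not prescribe a separator, so the comma format is an assumption). One inconsistency worth flagging: the model property added above defaults words and canned_response to lists, while this validator coerces them to empty strings.

    config["sensitive_word_avoidance"] = {
        "enabled": True,
        "words": "word1,word2",
        "canned_response": "I can't discuss that topic."
    }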
@ -274,6 +308,12 @@ class AppModelConfigService:
         if not isinstance(config["agent_mode"]["enabled"], bool):
             raise ValueError("enabled in agent_mode must be of boolean type")

+        if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]:
+            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
+
+        if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
+            raise ValueError("strategy in agent_mode must be in the specified strategy list")
+
         if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]:
             config["agent_mode"]["tools"] = []
@ -282,8 +322,8 @@ class AppModelConfigService:
         for tool in config["agent_mode"]["tools"]:
             key = list(tool.keys())[0]
-            if key not in ["sensitive-word-avoidance", "dataset"]:
-                raise ValueError("Keys in agent_mode.tools list can only be 'sensitive-word-avoidance' or 'dataset'")
+            if key not in SUPPORT_TOOLS:
+                raise ValueError("Keys in agent_mode.tools must be in the specified tool list")

             tool_item = tool[key]
@ -293,19 +333,7 @@ class AppModelConfigService:
             if not isinstance(tool_item["enabled"], bool):
                 raise ValueError("enabled in agent_mode.tools must be of boolean type")

-            if key == "sensitive-word-avoidance":
-                if "words" not in tool_item or not tool_item["words"]:
-                    tool_item["words"] = ""
-
-                if not isinstance(tool_item["words"], str):
-                    raise ValueError("words in sensitive-word-avoidance must be of string type")
-
-                if "canned_response" not in tool_item or not tool_item["canned_response"]:
-                    tool_item["canned_response"] = ""
-
-                if not isinstance(tool_item["canned_response"], str):
-                    raise ValueError("canned_response in sensitive-word-avoidance must be of string type")
-            elif key == "dataset":
+            if key == "dataset":
                 if 'id' not in tool_item:
                     raise ValueError("id is required in dataset")
@ -324,6 +352,7 @@ class AppModelConfigService:
             "suggested_questions_after_answer": config["suggested_questions_after_answer"],
             "speech_to_text": config["speech_to_text"],
             "more_like_this": config["more_like_this"],
+            "sensitive_word_avoidance": config["sensitive_word_avoidance"],
             "model": {
                 "provider": config["model"]["provider"],
                 "name": config["model"]["name"],
@ -37,6 +37,8 @@ class CompletionService:
         if not query:
             raise ValueError('query is required')

+        query = query.replace('\x00', '')
+
         conversation_id = args['conversation_id'] if 'conversation_id' in args else None

         conversation = None
@ -140,6 +142,7 @@ class CompletionService:
             suggested_questions=json.dumps(model_config['suggested_questions']),
             suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
             more_like_this=json.dumps(model_config['more_like_this']),
+            sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
             model=json.dumps(model_config['model']),
             user_input_form=json.dumps(model_config['user_input_form']),
             pre_prompt=model_config['pre_prompt'],
@ -171,7 +174,7 @@ class CompletionService:
         generate_worker_thread.start()

-        # wait for 5 minutes to close the thread
+        # wait for 10 minutes to close the thread
         cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)

         return cls.compact_response(pubsub, streaming)
@ -179,9 +182,9 @@ class CompletionService:
     @classmethod
     def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
         if isinstance(user, Account):
-            user = db.session.query(Account).get(user.id)
+            user = db.session.query(Account).filter(Account.id == user.id).first()
         elif isinstance(user, EndUser):
-            user = db.session.query(EndUser).get(user.id)
+            user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
         else:
             raise Exception("Unknown user type")
@ -226,12 +229,15 @@ class CompletionService:
     @classmethod
     def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
-        # wait for 5 minutes to close the thread
-        timeout = 300
+        # wait for 10 minutes to close the thread
+        timeout = 600

         def close_pubsub():
             sleep_iterations = 0
             while sleep_iterations < timeout and worker_thread.is_alive():
+                if sleep_iterations > 0 and sleep_iterations % 10 == 0:
+                    PubHandler.ping(user, generate_task_id)
+
                 time.sleep(1)
                 sleep_iterations += 1
|
@ -369,7 +375,7 @@ class CompletionService:
|
|||
if len(value) > max_length:
|
||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
||||
|
||||
filtered_inputs[variable] = value
|
||||
filtered_inputs[variable] = value.replace('\x00', '') if value else None
|
||||
|
||||
return filtered_inputs
|
||||
|
||||
|
@ -418,6 +424,10 @@ class CompletionService:
                     yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
+                elif event == 'agent_thought':
+                    yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
+                elif event == 'ping':
+                    yield "event: ping\n\n"
                 else:
                     yield "data: " + json.dumps(result) + "\n\n"
         except ValueError as e:
             if e.args[0] != "I/O operation on closed file.":  # ignore this error
                 logging.exception(e)
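A hypothetical client for the stream emitted above (not part of the commit; the endpoint URL, payload shape, and headers are assumptions). It skips the keep-alive pings, which carry no data payload:

    import json
    import requests

    url = "https://example.com/v1/chat-messages"   # placeholder endpoint
    payload = {"query": "hi", "response_mode": "streaming"}  # assumed payload shape
    headers = {"Authorization": "Bearer <api-key>"}  # placeholder auth

    with requests.post(url, json=payload, headers=headers, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or line.startswith("event: ping"):
                continue  # keep-alive ping, nothing to parse
            if line.startswith("data: "):
                chunk = json.loads(line[len("data: "):])
                if chunk.get("event") == "agent_thought":
                    print(chunk.get("thought"))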
@ -467,16 +477,14 @@ class CompletionService:
     def get_agent_thought_response_data(cls, data: dict):
         response_data = {
             'event': 'agent_thought',
-            'id': data.get('agent_thought_id'),
+            'id': data.get('id'),
-            'chain_id': data.get('chain_id'),
             'task_id': data.get('task_id'),
             'message_id': data.get('message_id'),
             'position': data.get('position'),
             'thought': data.get('thought'),
-            'tool': data.get('tool'),  # todo use real dataset obj replace it
+            'tool': data.get('tool'),
             'tool_input': data.get('tool_input'),
-            'observation': data.get('observation'),
             'answer': data.get('answer') if not data.get('thought') else '',
             'created_at': int(time.time())
         }