mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
feat: optimize the efficiency of generating chatbot conversation name (#3472)
This commit is contained in:
parent
8f8e9de601
commit
12f1ce4794
|
@ -98,6 +98,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
)
|
)
|
||||||
|
|
||||||
self._stream_generate_routes = self._get_stream_generate_routes()
|
self._stream_generate_routes = self._get_stream_generate_routes()
|
||||||
|
self._conversation_name_generate_thread = None
|
||||||
|
|
||||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||||
"""
|
"""
|
||||||
|
@ -108,6 +109,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
db.session.refresh(self._user)
|
db.session.refresh(self._user)
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
# start generate conversation name thread
|
||||||
|
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||||
|
self._conversation,
|
||||||
|
self._application_generate_entity.query
|
||||||
|
)
|
||||||
|
|
||||||
generator = self._process_stream_response()
|
generator = self._process_stream_response()
|
||||||
if self._stream:
|
if self._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
|
@ -278,6 +285,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if self._conversation_name_generate_thread:
|
||||||
|
self._conversation_name_generate_thread.join()
|
||||||
|
|
||||||
def _save_message(self) -> None:
|
def _save_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
Save message.
|
Save message.
|
||||||
|
|
|
@ -97,6 +97,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._conversation_name_generate_thread = None
|
||||||
|
|
||||||
def process(self) -> Union[
|
def process(self) -> Union[
|
||||||
ChatbotAppBlockingResponse,
|
ChatbotAppBlockingResponse,
|
||||||
CompletionAppBlockingResponse,
|
CompletionAppBlockingResponse,
|
||||||
|
@ -110,6 +112,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||||
db.session.refresh(self._message)
|
db.session.refresh(self._message)
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||||
|
# start generate conversation name thread
|
||||||
|
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||||
|
self._conversation,
|
||||||
|
self._application_generate_entity.query
|
||||||
|
)
|
||||||
|
|
||||||
generator = self._process_stream_response()
|
generator = self._process_stream_response()
|
||||||
if self._stream:
|
if self._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
|
@ -256,6 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if self._conversation_name_generate_thread:
|
||||||
|
self._conversation_name_generate_thread.join()
|
||||||
|
|
||||||
def _save_message(self) -> None:
|
def _save_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
Save message.
|
Save message.
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
|
from threading import Thread
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from flask import Flask, current_app
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
AdvancedChatAppGenerateEntity,
|
AdvancedChatAppGenerateEntity,
|
||||||
AgentChatAppGenerateEntity,
|
AgentChatAppGenerateEntity,
|
||||||
|
@ -19,9 +22,10 @@ from core.app.entities.task_entities import (
|
||||||
MessageReplaceStreamResponse,
|
MessageReplaceStreamResponse,
|
||||||
MessageStreamResponse,
|
MessageStreamResponse,
|
||||||
)
|
)
|
||||||
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import MessageAnnotation, MessageFile
|
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,6 +38,59 @@ class MessageCycleManage:
|
||||||
]
|
]
|
||||||
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
|
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
|
||||||
|
|
||||||
|
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
|
||||||
|
"""
|
||||||
|
Generate conversation name.
|
||||||
|
:param conversation: conversation
|
||||||
|
:param query: query
|
||||||
|
:return: thread
|
||||||
|
"""
|
||||||
|
is_first_message = self._application_generate_entity.conversation_id is None
|
||||||
|
extras = self._application_generate_entity.extras
|
||||||
|
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
|
||||||
|
|
||||||
|
if auto_generate_conversation_name and is_first_message:
|
||||||
|
# start generate thread
|
||||||
|
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
|
||||||
|
'flask_app': current_app._get_current_object(),
|
||||||
|
'conversation_id': conversation.id,
|
||||||
|
'query': query
|
||||||
|
})
|
||||||
|
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
return thread
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _generate_conversation_name_worker(self,
|
||||||
|
flask_app: Flask,
|
||||||
|
conversation_id: str,
|
||||||
|
query: str):
|
||||||
|
with flask_app.app_context():
|
||||||
|
# get conversation and message
|
||||||
|
conversation = (
|
||||||
|
db.session.query(Conversation)
|
||||||
|
.filter(Conversation.id == conversation_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if conversation.mode != AppMode.COMPLETION.value:
|
||||||
|
app_model = conversation.app
|
||||||
|
if not app_model:
|
||||||
|
return
|
||||||
|
|
||||||
|
# generate conversation name
|
||||||
|
try:
|
||||||
|
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
|
||||||
|
conversation.name = name
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
db.session.merge(conversation)
|
||||||
|
db.session.commit()
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||||
"""
|
"""
|
||||||
Handle annotation reply.
|
Handle annotation reply.
|
||||||
|
|
|
@ -5,7 +5,6 @@ from .create_installed_app_when_app_created import handle
|
||||||
from .create_site_record_when_app_created import handle
|
from .create_site_record_when_app_created import handle
|
||||||
from .deduct_quota_when_messaeg_created import handle
|
from .deduct_quota_when_messaeg_created import handle
|
||||||
from .delete_installed_app_when_app_deleted import handle
|
from .delete_installed_app_when_app_deleted import handle
|
||||||
from .generate_conversation_name_when_first_message_created import handle
|
|
||||||
from .update_app_dataset_join_when_app_model_config_updated import handle
|
from .update_app_dataset_join_when_app_model_config_updated import handle
|
||||||
from .update_provider_last_used_at_when_messaeg_created import handle
|
from .update_provider_last_used_at_when_messaeg_created import handle
|
||||||
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
||||||
|
|
|
@ -1,32 +0,0 @@
|
||||||
from core.llm_generator.llm_generator import LLMGenerator
|
|
||||||
from events.message_event import message_was_created
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.model import AppMode
|
|
||||||
|
|
||||||
|
|
||||||
@message_was_created.connect
|
|
||||||
def handle(sender, **kwargs):
|
|
||||||
message = sender
|
|
||||||
conversation = kwargs.get('conversation')
|
|
||||||
is_first_message = kwargs.get('is_first_message')
|
|
||||||
extras = kwargs.get('extras', {})
|
|
||||||
|
|
||||||
auto_generate_conversation_name = True
|
|
||||||
if extras:
|
|
||||||
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
|
|
||||||
|
|
||||||
if auto_generate_conversation_name and is_first_message:
|
|
||||||
if conversation.mode != AppMode.COMPLETION.value:
|
|
||||||
app_model = conversation.app
|
|
||||||
if not app_model:
|
|
||||||
return
|
|
||||||
|
|
||||||
# generate conversation name
|
|
||||||
try:
|
|
||||||
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
|
|
||||||
conversation.name = name
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
db.session.merge(conversation)
|
|
||||||
db.session.commit()
|
|
Loading…
Reference in New Issue
Block a user