Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
takatost 2024-01-02 23:42:00 +08:00 committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@@ -0,0 +1,58 @@
name: Run Pytest

on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - deploy/dev
      - feat/model-runtime

jobs:
  test:
    runs-on: ubuntu-latest

    env:
      OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
      AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
      AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
      ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
      CHATGLM_API_BASE: http://a.abc.com:11451
      XINFERENCE_SERVER_URL: http://a.abc.com:11451
      XINFERENCE_GENERATION_MODEL_UID: generate
      XINFERENCE_CHAT_MODEL_UID: chat
      XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
      XINFERENCE_RERANK_MODEL_UID: rerank
      GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
      HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
      HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
      HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
      HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
      MOCK_SWITCH: true

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'

      - name: Cache pip dependencies
        uses: actions/cache@v2
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
          restore-keys: ${{ runner.os }}-pip-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          pip install -r api/requirements.txt

      - name: Run pytest
        run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
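For reference, a rough sketch (not part of the commit) of reproducing this job locally from the repository root. It assumes the mock layer is gated by MOCK_SWITCH and the dummy credentials in the job env above; the exact set of variables each provider test needs is an assumption.

import os
import pytest

# Mirror the workflow's mock credentials; the values are the same dummies as above.
os.environ["MOCK_SWITCH"] = "true"
os.environ["OPENAI_API_KEY"] = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz"

# Run a subset of the same integration tests the workflow runs.
pytest.main([
    "api/tests/integration_tests/model_runtime/openai",
    "api/tests/integration_tests/model_runtime/anthropic",
])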

View File

@@ -1,38 +0,0 @@
name: Run Pytest

on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - deploy/dev

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'

      - name: Cache pip dependencies
        uses: actions/cache@v2
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
          restore-keys: ${{ runner.os }}-pip-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          pip install -r api/requirements.txt

      - name: Run pytest
        run: pytest api/tests/unit_tests

View File

@@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r

Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!

+### Provider Integrations
+If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.

### i18n (Internationalization) Support
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.

View File

@@ -4,6 +4,21 @@
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
+       {
+           "name": "Python: Celery",
+           "type": "python",
+           "request": "launch",
+           "module": "celery",
+           "justMyCode": true,
+           "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+           "envFile": "${workspaceFolder}/.env",
+           "env": {
+               "FLASK_APP": "app.py",
+               "FLASK_DEBUG": "1",
+               "GEVENT_SUPPORT": "True"
+           },
+           "console": "integratedTerminal"
+       },
        {
            "name": "Python: Flask",
            "type": "python",

View File

@@ -34,9 +34,6 @@ RUN apt-get update \

COPY --from=base /pkg /usr/local
COPY . /app/api/

-RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
-ENV TRANSFORMERS_OFFLINE true

COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

View File

@@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
    from gevent import monkey
    monkey.patch_all()
-   if os.environ.get("VECTOR_STORE") == 'milvus':
-       import grpc.experimental.gevent
-       grpc.experimental.gevent.init_gevent()
+   # if os.environ.get("VECTOR_STORE") == 'milvus':
+   import grpc.experimental.gevent
+   grpc.experimental.gevent.init_gevent()
+
+import langchain
+langchain.verbose = True

import time
import logging

@@ -18,9 +21,8 @@ import threading
from flask import Flask, request, Response
from flask_cors import CORS

-from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-   ext_database, ext_storage, ext_mail, ext_code_based_extension
+   ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_login import login_manager

@@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
    register_blueprints(app)
    register_commands(app)

-   hosted.init_app(app)

    return app

@@ -95,6 +95,7 @@ def initialize_extensions(app):
    ext_celery.init_app(app)
    ext_login.init_app(app)
    ext_mail.init_app(app)
+   ext_hosting_provider.init_app(app)
    ext_sentry.init_app(app)

@@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
    if request.blueprint == 'console':
        # Check if the user_id contains a dot, indicating the old format
        auth_header = request.headers.get('Authorization', '')
-       if ' ' not in auth_header:
-           raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-       auth_scheme, auth_token = auth_header.split(None, 1)
-       auth_scheme = auth_scheme.lower()
-       if auth_scheme != 'bearer':
-           raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+       if not auth_header:
+           auth_token = request.args.get('_token')
+           if not auth_token:
+               raise Unauthorized('Invalid Authorization token.')
+       else:
+           if ' ' not in auth_header:
+               raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+           auth_scheme, auth_token = auth_header.split(None, 1)
+           auth_scheme = auth_scheme.lower()
+           if auth_scheme != 'bearer':
+               raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

        decoded = PassportService().verify(auth_token)
        user_id = decoded.get('user_id')

View File

@@ -12,16 +12,12 @@ import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
-from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound

from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelType
-from core.model_providers.providers.hosted import hosted_model_providers
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db

@@ -327,6 +323,8 @@ def create_qdrant_indexes():
            except NotFound:
                break

+       model_manager = ModelManager()
+
        page += 1
        for dataset in datasets:
            if dataset.index_struct_dict:

@@ -334,19 +332,23 @@ def create_qdrant_indexes():
                try:
                    click.echo('Create dataset qdrant index: {}'.format(dataset.id))
                    try:
-                       embedding_model = ModelFactory.get_embedding_model(
+                       embedding_model = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
-                           model_provider_name=dataset.embedding_model_provider,
-                           model_name=dataset.embedding_model
+                           provider=dataset.embedding_model_provider,
+                           model_type=ModelType.TEXT_EMBEDDING,
+                           model=dataset.embedding_model
                        )
                    except Exception:
                        try:
-                           embedding_model = ModelFactory.get_embedding_model(
-                               tenant_id=dataset.tenant_id
+                           embedding_model = model_manager.get_default_model_instance(
+                               tenant_id=dataset.tenant_id,
+                               model_type=ModelType.TEXT_EMBEDDING,
                            )
-                           dataset.embedding_model = embedding_model.name
-                           dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+                           dataset.embedding_model = embedding_model.model
+                           dataset.embedding_model_provider = embedding_model.provider
                        except Exception:
                            provider = Provider(
                                id='provider_id',
                                tenant_id=dataset.tenant_id,
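The change above is the template for most call-site updates in this commit: ModelFactory.get_embedding_model gives way to a ModelManager that is asked for a model instance by provider, model type, and model name, with the tenant-level default as the fallback. A condensed sketch of the new call shape, using only names that appear in this hunk; here `dataset` stands in for any object carrying the embedding settings.

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

model_manager = ModelManager()
try:
    # Resolve the embedding model the dataset was configured with.
    embedding_model = model_manager.get_model_instance(
        tenant_id=dataset.tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,
        model=dataset.embedding_model
    )
except Exception:
    # Fall back to the workspace default text-embedding model.
    embedding_model = model_manager.get_default_model_instance(
        tenant_id=dataset.tenant_id,
        model_type=ModelType.TEXT_EMBEDDING,
    )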

View File

@@ -87,7 +87,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
-       self.CURRENT_VERSION = "0.3.34"
+       self.CURRENT_VERSION = "0.4.0"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')

View File

@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

# Import workspace controllers
-from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
+from .workspace import workspace, members, model_providers, account, tool_providers, models

# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

View File

@@ -4,6 +4,10 @@ import logging
from datetime import datetime

from flask_login import current_user

+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden

@@ -13,9 +17,7 @@ from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
    app_detail_fields_with_site

@@ -73,39 +75,41 @@ class AppListApi(Resource):
            raise Forbidden()

        try:
-           default_model = ModelFactory.get_text_generation_model(
-               tenant_id=current_user.current_tenant_id
+           provider_manager = ProviderManager()
+           default_model_entity = provider_manager.get_default_model(
+               tenant_id=current_user.current_tenant_id,
+               model_type=ModelType.LLM
            )
        except (ProviderTokenNotInitError, LLMBadRequestError):
-           default_model = None
+           default_model_entity = None
        except Exception as e:
            logging.exception(e)
-           default_model = None
+           default_model_entity = None

        if args['model_config'] is not None:
            # validate config
            model_config_dict = args['model_config']

            # get model provider
-           model_provider = ModelProviderFactory.get_preferred_model_provider(
-               current_user.current_tenant_id,
-               model_config_dict["model"]["provider"]
+           model_manager = ModelManager()
+           model_instance = model_manager.get_default_model_instance(
+               tenant_id=current_user.current_tenant_id,
+               model_type=ModelType.LLM
            )

-           if not model_provider:
-               if not default_model:
-                   raise ProviderNotInitializeError(
-                       f"No Default System Reasoning Model available. Please configure "
-                       f"in the Settings -> Model Provider.")
-               else:
-                   model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
-                   model_config_dict["model"]["name"] = default_model.name
+           if not model_instance:
+               raise ProviderNotInitializeError(
+                   f"No Default System Reasoning Model available. Please configure "
+                   f"in the Settings -> Model Provider.")
+           else:
+               model_config_dict["model"]["provider"] = model_instance.provider
+               model_config_dict["model"]["name"] = model_instance.model

            model_configuration = AppModelConfigService.validate_configuration(
                tenant_id=current_user.current_tenant_id,
                account=current_user,
                config=model_config_dict,
-               mode=args['mode']
+               app_mode=args['mode']
            )

            app = App(

@@ -129,21 +133,27 @@ class AppListApi(Resource):
            app_model_config = AppModelConfig(**model_config_template['model_config'])

            # get model provider
-           model_provider = ModelProviderFactory.get_preferred_model_provider(
-               current_user.current_tenant_id,
-               app_model_config.model_dict["provider"]
-           )
+           model_manager = ModelManager()

-           if not model_provider:
-               if not default_model:
-                   raise ProviderNotInitializeError(
-                       f"No Default System Reasoning Model available. Please configure "
-                       f"in the Settings -> Model Provider.")
-               else:
-                   model_dict = app_model_config.model_dict
-                   model_dict['provider'] = default_model.model_provider.provider_name
-                   model_dict['name'] = default_model.name
-                   app_model_config.model = json.dumps(model_dict)
+           try:
+               model_instance = model_manager.get_default_model_instance(
+                   tenant_id=current_user.current_tenant_id,
+                   model_type=ModelType.LLM
+               )
+           except ProviderTokenNotInitError:
+               raise ProviderNotInitializeError(
+                   f"No Default System Reasoning Model available. Please configure "
+                   f"in the Settings -> Model Provider.")
+
+           if not model_instance:
+               raise ProviderNotInitializeError(
+                   f"No Default System Reasoning Model available. Please configure "
+                   f"in the Settings -> Model Provider.")
+           else:
+               model_dict = app_model_config.model_dict
+               model_dict['provider'] = model_instance.provider
+               model_dict['name'] = model_instance.model
+               app_model_config.model = json.dumps(model_dict)

            app.name = args['name']
            app.mode = args['mode']
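Two related helpers appear in this hunk and recur throughout the commit: ProviderManager.get_default_model, which looks up the workspace's default model configuration, and ModelManager.get_default_model_instance, which returns an instance that the identity fields are copied from. The split between "entity" and "instance" is an inference from the naming; a hedged sketch, where current_user and model_config_dict come from the surrounding controller:

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

# Inspect the configured default LLM for the tenant (configuration lookup only).
provider_manager = ProviderManager()
default_model_entity = provider_manager.get_default_model(
    tenant_id=current_user.current_tenant_id,
    model_type=ModelType.LLM
)

# Obtain the default model instance and copy its identity into the app's
# model config, as the controller above does.
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
    tenant_id=current_user.current_tenant_id,
    model_type=ModelType.LLM
)
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model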

View File

@@ -2,6 +2,8 @@
import logging

from flask import request

+from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError

@@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
    UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

@@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

View File

@@ -5,6 +5,10 @@ from typing import Generator, Union
import flask_login
from flask import Response, stream_with_context

+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound

@@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
    ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from flask_restful import Resource, reqparse

@@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
                app_model=app_model,
                user=account,
                args=args,
-               from_source='console',
+               invoke_from=InvokeFrom.DEBUGGER,
                streaming=streaming,
                is_model_config_override=True
            )

@@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

@@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource):
        account = flask_login.current_user

-       PubHandler.stop(account, task_id)
+       ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)

        return {'result': 'success'}, 200

@@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
                app_model=app_model,
                user=account,
                args=args,
-               from_source='console',
+               invoke_from=InvokeFrom.DEBUGGER,
                streaming=streaming,
                is_model_config_override=True
            )

@@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

@@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
            yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
        except ModelCurrentlyNotSupportError:
            yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
-           yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
+       except InvokeError as e:
+           yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
        except ValueError as e:
            yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
        except Exception:

@@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource):
        account = flask_login.current_user

-       PubHandler.stop(account, task_id)
+       ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)

        return {'result': 'success'}, 200
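Two patterns from this file repeat across every controller touched by the commit: the long tuple of provider-specific LLM* exceptions collapses into the single InvokeError base class from the new model runtime, and stop requests go through ApplicationQueueManager.set_stop_flag tagged with an InvokeFrom source instead of PubHandler.stop. A condensed sketch using only names from this diff; generate_response stands in for whatever service call the controller wraps.

from controllers.console.app.error import CompletionRequestError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.model_runtime.errors.invoke import InvokeError


def stop_debug_generation(task_id: str, account_id: str) -> None:
    # Console debugger stops are flagged with InvokeFrom.DEBUGGER.
    ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account_id)


def guarded_generate(generate_response):
    try:
        return generate_response()
    except InvokeError as e:
        # One handler replaces the old five-exception tuple.
        raise CompletionRequestError(str(e))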

View File

@@ -1,4 +1,6 @@
from flask_login import current_user

+from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from flask_restful import Resource, reqparse

@@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
-    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError


class RuleGenerateApi(Resource):

@@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))

        return rules

View File

@@ -14,8 +14,9 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
    AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value

@@ -208,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
        app_model = _get_app(app_id, 'completion')

        try:
-           response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+           response = CompletionService.generate_more_like_this(
+               app_model=app_model,
+               user=current_user,
+               message_id=message_id,
+               invoke_from=InvokeFrom.DEBUGGER,
+               streaming=streaming
+           )
            return compact_response(response)
        except MessageNotExistsError:
            raise NotFound("Message Not Exists.")

@@ -220,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

@@ -249,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        except ModelCurrentlyNotSupportError:
            yield "data: " + json.dumps(
                api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
        except ValueError as e:
            yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

@@ -290,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except Exception:
            logging.exception("internal server error.")

View File

@@ -31,7 +31,7 @@ class ModelConfigResource(Resource):
            tenant_id=current_user.current_tenant_id,
            account=current_user,
            config=request.json,
-           mode=app.mode
+           app_mode=app.mode
        )

        new_app_model_config = AppModelConfig(

View File

@@ -4,6 +4,8 @@ from flask import request, current_app
from flask_login import current_user

from controllers.console.apikey import api_key_list, api_key_fields
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden

@@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.models.entity.model_params import ModelType
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields

@@ -23,7 +24,6 @@ from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService
-from services.provider_service import ProviderService


def _validate_name(name):

@@ -55,16 +55,20 @@ class DatasetListApi(Resource):
            current_user.current_tenant_id, current_user)

        # check embedding setting
-       provider_service = ProviderService()
-       valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                ModelType.EMBEDDINGS.value)
-       # if len(valid_model_list) == 0:
-       #     raise ProviderNotInitializeError(
-       #         f"No Embedding Model available. Please configure a valid provider "
-       #         f"in the Settings -> Model Provider.")
+       provider_manager = ProviderManager()
+       configurations = provider_manager.get_configurations(
+           tenant_id=current_user.current_tenant_id
+       )
+
+       embedding_models = configurations.get_models(
+           model_type=ModelType.TEXT_EMBEDDING,
+           only_active=True
+       )
+
        model_names = []
-       for valid_model in valid_model_list:
-           model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+       for embedding_model in embedding_models:
+           model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item['indexing_technique'] == 'high_quality':

@@ -75,6 +79,7 @@ class DatasetListApi(Resource):
                    item['embedding_available'] = False
            else:
                item['embedding_available'] = True
+
        response = {
            'data': data,
            'has_more': len(datasets) == limit,

@@ -130,13 +135,20 @@ class DatasetApi(Resource):
            raise Forbidden(str(e))

        data = marshal(dataset, dataset_detail_fields)

        # check embedding setting
-       provider_service = ProviderService()
-       # get valid model list
-       valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                ModelType.EMBEDDINGS.value)
+       provider_manager = ProviderManager()
+       configurations = provider_manager.get_configurations(
+           tenant_id=current_user.current_tenant_id
+       )
+
+       embedding_models = configurations.get_models(
+           model_type=ModelType.TEXT_EMBEDDING,
+           only_active=True
+       )
+
        model_names = []
-       for valid_model in valid_model_list:
-           model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+       for embedding_model in embedding_models:
+           model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        if data['indexing_technique'] == 'high_quality':
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            if item_model in model_names:
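The dataset controllers now list available embedding models through the provider configuration layer instead of ProviderService.get_valid_model_list. A minimal sketch of the lookup, using only names from the hunk above; current_user is the Flask-Login user available in the surrounding controller.

from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
    tenant_id=current_user.current_tenant_id
)

# Only models from providers that are configured and active are returned.
embedding_models = configurations.get_models(
    model_type=ModelType.TEXT_EMBEDDING,
    only_active=True
)

# "model:provider" keys, matched against each dataset's embedding settings.
model_names = [f"{m.model}:{m.provider.provider}" for m in embedding_models]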

View File

@@ -2,8 +2,12 @@
from datetime import datetime
from typing import List

-from flask import request, current_app
+from flask import request
from flask_login import current_user

+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc

@@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.indexing_runner import IndexingRunner
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
    LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
    dataset_and_document_fields, document_status_fields

@@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
        args = parser.parse_args()
        if args['indexing_technique'] == 'high_quality':
            try:
-               ModelFactory.get_embedding_model(
-                   tenant_id=current_user.current_tenant_id
+               model_manager = ModelManager()
+               model_manager.get_default_model_instance(
+                   tenant_id=current_user.current_tenant_id,
+                   model_type=ModelType.TEXT_EMBEDDING
                )
-           except LLMBadRequestError:
+           except InvokeAuthorizationError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")

View File

@@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
from libs.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client

@@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
        if dataset.indexing_technique == 'high_quality':
            # check embedding model setting
            try:
-               ModelFactory.get_embedding_model(
+               model_manager = ModelManager()
+               model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
-                   model_provider_name=dataset.embedding_model_provider,
-                   model_name=dataset.embedding_model
+                   provider=dataset.embedding_model_provider,
+                   model_type=ModelType.TEXT_EMBEDDING,
+                   model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(

@@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
        # check embedding model setting
        if dataset.indexing_technique == 'high_quality':
            try:
-               ModelFactory.get_embedding_model(
+               model_manager = ModelManager()
+               model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
-                   model_provider_name=dataset.embedding_model_provider,
-                   model_name=dataset.embedding_model
+                   provider=dataset.embedding_model_provider,
+                   model_type=ModelType.TEXT_EMBEDDING,
+                   model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(

@@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
        if dataset.indexing_technique == 'high_quality':
            # check embedding model setting
            try:
-               ModelFactory.get_embedding_model(
+               model_manager = ModelManager()
+               model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
-                   model_provider_name=dataset.embedding_model_provider,
-                   model_name=dataset.embedding_model
+                   provider=dataset.embedding_model_provider,
+                   model_type=ModelType.TEXT_EMBEDDING,
+                   model=dataset.embedding_model
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(

View File

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
    LLMBadRequestError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService

View File

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
    NoAudioUploadedError, AudioTooLargeError, \
    UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
    UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError

@@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

View File

@@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.helper import uuid_value
from services.completion_service import CompletionService

@@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
                app_model=app_model,
                user=current_user,
                args=args,
-               from_source='console',
+               invoke_from=InvokeFrom.EXPLORE,
                streaming=streaming
            )

@@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

@@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
        if app_model.mode != 'completion':
            raise NotCompletionAppError()

-       PubHandler.stop(current_user, task_id)
+       ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

        return {'result': 'success'}, 200

@@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
                app_model=app_model,
                user=current_user,
                args=args,
-               from_source='console',
+               invoke_from=InvokeFrom.EXPLORE,
                streaming=streaming
            )

@@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

@@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
        if app_model.mode != 'chat':
            raise NotChatAppError()

-       PubHandler.stop(current_user, task_id)
+       ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

        return {'result': 'success'}, 200

@@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
    except ModelCurrentlyNotSupportError:
        yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-   except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-           LLMRateLimitError, LLMAuthorizationError) as e:
+   except InvokeError as e:
        yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
    except ValueError as e:
        yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@@ -5,7 +5,7 @@ from typing import Generator, Union
from flask import stream_with_context, Response
from flask_login import current_user
-from flask_restful import reqparse, fields, marshal_with
+from flask_restful import reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError

@@ -13,12 +13,14 @@ import services
from controllers.console import api
from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
    ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
-from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
+from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
+    NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
-from libs.helper import uuid_value, TimestampField
+from libs.helper import uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError

@@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
        streaming = args['response_mode'] == 'streaming'

        try:
-           response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+           response = CompletionService.generate_more_like_this(
+               app_model=app_model,
+               user=current_user,
+               message_id=message_id,
+               invoke_from=InvokeFrom.EXPLORE,
+               streaming=streaming
+           )
            return compact_response(response)
        except MessageNotExistsError:
            raise NotFound("Message Not Exists.")

@@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except ValueError as e:
            raise e

@@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
    except ModelCurrentlyNotSupportError:
        yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-   except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-           LLMRateLimitError, LLMAuthorizationError) as e:
+   except InvokeError as e:
        yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
    except ValueError as e:
        yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

@@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
-       except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-               LLMRateLimitError, LLMAuthorizationError) as e:
+       except InvokeError as e:
            raise CompletionRequestError(str(e))
        except Exception:
            logging.exception("internal server error.")

View File

@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \ NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource from controllers.console.universal_chat.wraps import UniversalChatResource
- from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e

View File

@ -12,9 +12,10 @@ from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \ from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource from controllers.console.universal_chat.wraps import UniversalChatResource
- from core.conversation_message_task import PubHandler
- from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
-     LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
+ from core.application_queue_manager import ApplicationQueueManager
+ from core.entities.application_entities import InvokeFrom
+ from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
from services.completion_service import CompletionService from services.completion_service import CompletionService
@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
args=args, args=args,
- from_source='console',
+ invoke_from=InvokeFrom.EXPLORE,
streaming=True, streaming=True,
is_model_config_override=True, is_model_config_override=True,
) )
@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e
@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):
class UniversalChatStopApi(UniversalChatResource): class UniversalChatStopApi(UniversalChatResource):
def post(self, universal_app, task_id): def post(self, universal_app, task_id):
- PubHandler.stop(current_user, task_id)
+ ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e: except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource from controllers.console.universal_chat.wraps import UniversalChatResource
- from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except Exception: except Exception:
logging.exception("internal server error.") logging.exception("internal server error.")

View File

@ -1,16 +1,19 @@
import io
from flask import send_file
from flask_login import current_user from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError from core.model_runtime.entities.model_entities import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from services.provider_service import ProviderService from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.billing_service import BillingService from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
class ModelProviderListApi(Resource): class ModelProviderListApi(Resource):
@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=False, nullable=True, location='args') parser.add_argument('model_type', type=str, required=False, nullable=True,
choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args() args = parser.parse_args()
provider_service = ProviderService() model_provider_service = ModelProviderService()
provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type')) provider_list = model_provider_service.get_provider_list(
tenant_id=tenant_id,
model_type=args.get('model_type')
)
return provider_list return jsonable_encoder({"data": provider_list})
class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credentials(
tenant_id=tenant_id,
provider=provider
)
return {
"credentials": credentials
}
class ModelProviderValidateApi(Resource): class ModelProviderValidateApi(Resource):
@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_name: str): def post(self, provider: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json') parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
provider_service = ProviderService() tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
result = True result = True
error = None error = None
try: try:
provider_service.custom_provider_config_validate( model_provider_service.provider_credentials_validate(
provider_name=provider_name, tenant_id=tenant_id,
config=args['config'] provider=provider,
credentials=args['credentials']
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
result = False result = False
@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
return response return response
class ModelProviderUpdateApi(Resource): class ModelProviderApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_name: str): def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json') parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
provider_service = ProviderService() model_provider_service = ModelProviderService()
try: try:
provider_service.save_custom_provider_config( model_provider_service.save_provider_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider_name=provider_name, provider=provider,
config=args['config'] credentials=args['credentials']
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider_name: str): def delete(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden() raise Forbidden()
provider_service = ProviderService() model_provider_service = ModelProviderService()
provider_service.delete_custom_provider( model_provider_service.remove_provider_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider_name=provider_name provider=provider
) )
return {'result': 'success'}, 204 return {'result': 'success'}, 204
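Provider credentials now go through ModelProviderService instead of the old ProviderService config methods; validation and persistence are separate calls that both raise CredentialsValidateFailedError on bad keys. A sketch of how the two fit together outside a Flask handler, assuming a known tenant id (the credential key name is an assumption; the real shape comes from the provider's credential schema):

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from services.model_provider_service import ModelProviderService

def set_provider_key(tenant_id: str, api_key: str):
    service = ModelProviderService()
    credentials = {'openai_api_key': api_key}  # assumed schema for the 'openai' provider
    try:
        # validate first, then persist
        service.provider_credentials_validate(tenant_id=tenant_id, provider='openai', credentials=credentials)
        service.save_provider_credentials(tenant_id=tenant_id, provider='openai', credentials=credentials)
    except CredentialsValidateFailedError as ex:
        return {'result': 'error', 'error': str(ex)}
    return {'result': 'success'}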
class ModelProviderModelValidateApi(Resource): class ModelProviderIconApi(Resource):
"""
Get model provider icon
"""
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_name: str): def get(self, provider: str, icon_type: str, lang: str):
parser = reqparse.RequestParser() model_provider_service = ModelProviderService()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') icon, mimetype = model_provider_service.get_model_provider_icon(
parser.add_argument('model_type', type=str, required=True, nullable=False, provider=provider,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json') icon_type=icon_type,
parser.add_argument('config', type=dict, required=True, nullable=False, location='json') lang=lang
args = parser.parse_args()
provider_service = ProviderService()
result = True
error = None
try:
provider_service.custom_provider_model_config_validate(
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class ModelProviderModelUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
try:
provider_service.add_or_save_custom_provider_model_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.delete_custom_provider_model(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type']
) )
return {'result': 'success'}, 204 return send_file(io.BytesIO(icon), mimetype=mimetype)
class PreferredProviderTypeUpdateApi(Resource): class PreferredProviderTypeUpdateApi(Resource):
@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_name: str): def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
choices=['system', 'custom'], location='json') choices=['system', 'custom'], location='json')
args = parser.parse_args() args = parser.parse_args()
provider_service = ProviderService() model_provider_service = ModelProviderService()
provider_service.switch_preferred_provider( model_provider_service.switch_preferred_provider(
tenant_id=current_user.current_tenant_id, tenant_id=tenant_id,
provider_name=provider_name, provider=provider,
preferred_provider_type=args['preferred_provider_type'] preferred_provider_type=args['preferred_provider_type']
) )
return {'result': 'success'} return {'result': 'success'}
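The preferred-provider switch keeps the same two-value choice as before. A hypothetical request body for POST .../model-providers/<provider>/preferred-provider-type:

payload = {
    # the parser only accepts 'system' or 'custom'
    'preferred_provider_type': 'custom'
}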
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
provider_service = ProviderService()
try:
parameter_rules = provider_service.get_model_parameter_rules(
tenant_id=current_user.current_tenant_id,
model_provider_name=provider_name,
model_name=args['model_name'],
model_type='text-generation'
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"Current Text Generation Model is invalid. Please switch to the available model.")
rules = {
k: {
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default,
'precision': v.precision
}
for k, v in vars(parameter_rules).items()
}
return rules
class ModelProviderPaymentCheckoutUrlApi(Resource): class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_name: str): def get(self, provider: str):
if provider_name != 'anthropic': if provider != 'anthropic':
raise ValueError(f'provider name {provider_name} is invalid') raise ValueError(f'provider name {provider} is invalid')
data = BillingService.get_model_provider_payment_link(provider_name=provider_name, data = BillingService.get_model_provider_payment_link(provider_name=provider,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
account_id=current_user.id) account_id=current_user.id)
return data return data
@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_name: str): def post(self, provider: str):
provider_service = ProviderService() model_provider_service = ModelProviderService()
result = provider_service.free_quota_submit( result = model_provider_service.free_quota_submit(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider_name=provider_name provider=provider
) )
return result return result
@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_name: str): def get(self, provider: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=False, nullable=True, location='args') parser.add_argument('token', type=str, required=False, nullable=True, location='args')
args = parser.parse_args() args = parser.parse_args()
provider_service = ProviderService() model_provider_service = ModelProviderService()
result = provider_service.free_quota_qualification_verify( result = model_provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider_name=provider_name, provider=provider,
token=args['token'] token=args['token']
) )
@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>') api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
api.add_resource(ModelProviderModelValidateApi, api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
'/workspaces/current/model-providers/<string:provider_name>/models/validate') api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
api.add_resource(ModelProviderModelUpdateApi, api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
'/workspaces/current/model-providers/<string:provider_name>/models') '<string:icon_type>/<string:lang>')
api.add_resource(PreferredProviderTypeUpdateApi, api.add_resource(PreferredProviderTypeUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type') '/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
api.add_resource(ModelProviderPaymentCheckoutUrlApi, api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url') '/workspaces/current/model-providers/<string:provider>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi, api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit') '/workspaces/current/model-providers/<string:provider>/free-quota-submit')
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify') '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')
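All provider routes now take <string:provider> and the credential endpoints move under /credentials. A hedged client-side sketch using requests; the base URL and the authenticated session are assumptions, not part of this diff:

import requests

CONSOLE_API = 'http://localhost:5001/console/api'  # assumed console API origin
session = requests.Session()                        # assumed to carry a valid console login

resp = session.post(
    f'{CONSOLE_API}/workspaces/current/model-providers/anthropic/credentials/validate',
    json={'credentials': {'anthropic_api_key': 'sk-ant-...'}},
)
print(resp.json())  # {'result': 'success'} or {'result': 'error', 'error': '...'}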

View File

@ -1,16 +1,17 @@
import logging import logging
from flask_login import current_user from flask_login import current_user
from libs.login import login_required from flask_restful import reqparse, Resource
from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_runtime.entities.model_entities import ModelType
from core.model_providers.models.entity.model_params import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError
from models.provider import ProviderType from core.model_runtime.utils.encoders import jsonable_encoder
from services.provider_service import ProviderService from libs.login import login_required
from services.model_provider_service import ModelProviderService
class DefaultModelApi(Resource): class DefaultModelApi(Resource):
@ -21,52 +22,20 @@ class DefaultModelApi(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=True, nullable=False, parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args') choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
provider_service = ProviderService() model_provider_service = ModelProviderService()
default_model = provider_service.get_default_model_of_model_type( default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=args['model_type'] model_type=args['model_type']
) )
if not default_model: return jsonable_encoder({
return None "data": default_model_entity
})
model_provider = ModelProviderFactory.get_preferred_model_provider(
tenant_id,
default_model.provider_name
)
if not model_provider:
return {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': default_model.provider_name
}
}
provider = model_provider.provider
rst = {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': provider.provider_name,
'provider_type': provider.provider_type
}
}
model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
if provider.provider_type == ProviderType.SYSTEM.value:
rst['model_provider']['quota_type'] = provider.quota_type
rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
rst['model_provider']['quota_limit'] = provider.quota_limit
rst['model_provider']['quota_used'] = provider.quota_used
return rst
@setup_required @setup_required
@login_required @login_required
@ -76,15 +45,26 @@ class DefaultModelApi(Resource):
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
provider_service = ProviderService() tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args['model_settings'] model_settings = args['model_settings']
for model_setting in model_settings: for model_setting in model_settings:
if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
raise ValueError('invalid model type')
if 'provider' not in model_setting:
continue
if 'model' not in model_setting:
raise ValueError('invalid model')
try: try:
provider_service.update_default_model_of_model_type( model_provider_service.update_default_model_of_model_type(
tenant_id=current_user.current_tenant_id, tenant_id=tenant_id,
model_type=model_setting['model_type'], model_type=model_setting['model_type'],
provider_name=model_setting['provider_name'], provider=model_setting['provider'],
model_name=model_setting['model_name'] model=model_setting['model']
) )
except Exception: except Exception:
logging.warning(f"{model_setting['model_type']} save error") logging.warning(f"{model_setting['model_type']} save error")
@ -92,22 +72,198 @@ class DefaultModelApi(Resource):
return {'result': 'success'} return {'result': 'success'}
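The default-model POST body is now a list of model_settings entries, each validated for model_type, provider and model. A hypothetical payload matching the checks above (the model type strings assume the ModelType enum values):

model_settings = [
    {'model_type': 'llm', 'provider': 'openai', 'model': 'gpt-3.5-turbo'},
    {'model_type': 'text-embedding', 'provider': 'openai', 'model': 'text-embedding-ada-002'}
]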
class ValidModelApi(Resource): class ModelProviderModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(
tenant_id=tenant_id,
provider=provider
)
return jsonable_encoder({
"data": models
})
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
try:
model_provider_service.save_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return {'result': 'success'}, 204
class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args()
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_model_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=args['model_type'],
model=args['model']
)
return {
"credentials": credentials
}
class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
result = True
error = None
try:
model_provider_service.model_credentials_validate(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
tenant_id=tenant_id,
provider=provider,
model=args['model']
)
return jsonable_encoder({
"data": parameter_rules
})
class ModelProviderAvailableModelApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, model_type): def get(self, model_type):
ModelType.value_of(model_type) tenant_id = current_user.current_tenant_id
provider_service = ProviderService() model_provider_service = ModelProviderService()
valid_models = provider_service.get_valid_model_list( models = model_provider_service.get_models_by_model_type(
tenant_id=current_user.current_tenant_id, tenant_id=tenant_id,
model_type=model_type model_type=model_type
) )
return valid_models return jsonable_encoder({
"data": models
})
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
api.add_resource(ModelProviderModelCredentialApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials')
api.add_resource(ModelProviderModelValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
api.add_resource(DefaultModelApi, '/workspaces/current/default-model') api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
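Listing models of a given type now goes through ModelProviderService and is serialized with jsonable_encoder, mirroring ModelProviderAvailableModelApi above. A small sketch under those assumptions (the helper name is hypothetical):

from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService

def list_models_of_type(tenant_id: str, model_type: str = 'llm'):
    models = ModelProviderService().get_models_by_model_type(
        tenant_id=tenant_id,
        model_type=model_type
    )
    return jsonable_encoder({'data': models})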

View File

@ -1,131 +0,0 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType
from services.provider_service import ProviderService
class ProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
"""
If the type is AZURE_OPENAI, decode and return the four fields azure_api_type, azure_api_version,
azure_api_base and azure_api_key as an object; azure_api_key shows its first six and last two
characters in plaintext, with the rest replaced by *.
For any other type, decode and return the token field directly, masked the same way: the first six
and last two characters in plaintext, the rest replaced by *.
"""
provider_service = ProviderService()
provider_info_list = provider_service.get_provider_list(tenant_id)
provider_list = [
{
'provider_name': p['provider_name'],
'provider_type': p['provider_type'],
'is_valid': p['is_valid'],
'last_used': p['last_used'],
'is_enabled': p['is_valid'],
**({
'quota_type': p['quota_type'],
'quota_limit': p['quota_limit'],
'quota_used': p['quota_used']
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
if p['config'] else None
}
for name, provider_info in provider_info_list.items()
for p in provider_info['providers']
]
return provider_list
class ProviderTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
provider_service = ProviderService()
try:
provider_service.save_custom_provider_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
result = True
error = None
try:
provider_service.custom_provider_config_validate(
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
endpoint='workspaces_current_providers_token') # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list

View File

@ -34,7 +34,6 @@ tenant_fields = {
'status': fields.String, 'status': fields.String,
'created_at': TimestampField, 'created_at': TimestampField,
'role': fields.String, 'role': fields.String,
'providers': fields.List(fields.Nested(provider_fields)),
'in_trial': fields.Boolean, 'in_trial': fields.Boolean,
'trial_end_reason': fields.String, 'trial_end_reason': fields.String,
'custom_config': fields.Raw(attribute='custom_config'), 'custom_config': fields.Raw(attribute='custom_config'),

View File

@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \ ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
ProviderNotSupportSpeechToTextError ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import AppApiResource
- from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig from models.model import App, AppModelConfig
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e

View File

@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \ ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError ProviderModelCurrentlyNotSupportError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import AppApiResource
- from core.conversation_message_task import PubHandler
- from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.application_queue_manager import ApplicationQueueManager
+ from core.entities.application_entities import InvokeFrom
+ from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
from services.completion_service import CompletionService from services.completion_service import CompletionService
@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
app_model=app_model, app_model=app_model,
user=end_user, user=end_user,
args=args, args=args,
- from_source='api',
+ invoke_from=InvokeFrom.SERVICE_API,
streaming=streaming, streaming=streaming,
) )
@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e
@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
- PubHandler.stop(end_user, task_id)
+ ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
app_model=app_model, app_model=app_model,
user=end_user, user=end_user,
args=args, args=args,
- from_source='api',
+ invoke_from=InvokeFrom.SERVICE_API,
streaming=streaming streaming=streaming
) )
@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e
@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
- PubHandler.stop(end_user, task_id)
+ ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e: except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@ -4,11 +4,11 @@ import services.dataset_service
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import current_user from libs.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.provider_service import ProviderService
def _validate_name(name): def _validate_name(name):
@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
datasets, total = DatasetService.get_datasets(page, limit, provider, datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user) tenant_id, current_user)
# check embedding setting # check embedding setting
- provider_service = ProviderService()
- valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-     ModelType.EMBEDDINGS.value)
+ provider_manager = ProviderManager()
+ configurations = provider_manager.get_configurations(
+     tenant_id=current_user.current_tenant_id
+ )
+ embedding_models = configurations.get_models(
+     model_type=ModelType.TEXT_EMBEDDING,
+     only_active=True
+ )
model_names = []
- for valid_model in valid_model_list:
-     model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+ for embedding_model in embedding_models:
+     model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields) data = marshal(datasets, dataset_detail_fields)
for item in data: for item in data:
if item['indexing_technique'] == 'high_quality': if item['indexing_technique'] == 'high_quality':
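The embedding-model check above replaces ProviderService.get_valid_model_list with the ProviderManager configurations API. A sketch of the same lookup as a standalone helper, assuming a tenant id (the function name is hypothetical):

from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

def active_embedding_model_names(tenant_id: str) -> list:
    configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
    embedding_models = configurations.get_models(
        model_type=ModelType.TEXT_EMBEDDING,
        only_active=True
    )
    # same "<model>:<provider>" strings the dataset listing builds above
    return [f'{m.model}:{m.provider.provider}' for m in embedding_models]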

View File

@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
NoFileUploadedError, TooManyFilesError NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from libs.login import current_user from libs.login import current_user
- from core.model_providers.error import ProviderTokenNotInitError
+ from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment

View File

@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
- from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
- from core.model_providers.model_factory import ModelFactory
+ from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
+ from core.model_manager import ModelManager
+ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from fields.segment_fields import segment_fields from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting # check embedding model setting
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
try: try:
- ModelFactory.get_embedding_model(
+ model_manager = ModelManager()
+ model_manager.get_model_instance(
      tenant_id=current_user.current_tenant_id,
-     model_provider_name=dataset.embedding_model_provider,
-     model_name=dataset.embedding_model
+     provider=dataset.embedding_model_provider,
+     model_type=ModelType.TEXT_EMBEDDING,
+     model=dataset.embedding_model
  )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting # check embedding model setting
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
try: try:
- ModelFactory.get_embedding_model(
+ model_manager = ModelManager()
+ model_manager.get_model_instance(
      tenant_id=current_user.current_tenant_id,
-     model_provider_name=dataset.embedding_model_provider,
-     model_name=dataset.embedding_model
+     provider=dataset.embedding_model_provider,
+     model_type=ModelType.TEXT_EMBEDDING,
+     model=dataset.embedding_model
  )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
# check embedding model setting # check embedding model setting
try: try:
- ModelFactory.get_embedding_model(
+ model_manager = ModelManager()
+ model_manager.get_model_instance(
      tenant_id=current_user.current_tenant_id,
-     model_provider_name=dataset.embedding_model_provider,
-     model_name=dataset.embedding_model
+     provider=dataset.embedding_model_provider,
+     model_type=ModelType.TEXT_EMBEDDING,
+     model=dataset.embedding_model
  )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
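Each of these segment endpoints now asks ModelManager for a model instance with an explicit ModelType instead of calling ModelFactory.get_embedding_model. A sketch of the guard as a standalone helper, assuming a dataset object with the same fields used above (the RuntimeError stands in for the controller's ProviderNotInitializeError):

from core.errors.error import LLMBadRequestError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

def ensure_embedding_model(tenant_id: str, dataset):
    try:
        ModelManager().get_model_instance(
            tenant_id=tenant_id,
            provider=dataset.embedding_model_provider,
            model_type=ModelType.TEXT_EMBEDDING,
            model=dataset.embedding_model
        )
    except LLMBadRequestError as e:
        raise RuntimeError(f'embedding model is not configured: {e}')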

View File

@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \ ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
- from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e

View File

@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \ ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
- from core.conversation_message_task import PubHandler
- from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.application_queue_manager import ApplicationQueueManager
+ from core.entities.application_entities import InvokeFrom
+ from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+ from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
from services.completion_service import CompletionService from services.completion_service import CompletionService
@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
app_model=app_model, app_model=app_model,
user=end_user, user=end_user,
args=args, args=args,
- from_source='api',
+ invoke_from=InvokeFrom.WEB_APP,
streaming=streaming streaming=streaming
) )
@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e
@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
- PubHandler.stop(end_user, task_id)
+ ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
app_model=app_model, app_model=app_model,
user=end_user, user=end_user,
args=args, args=args,
- from_source='api',
+ invoke_from=InvokeFrom.WEB_APP,
streaming=streaming streaming=streaming
) )
@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-     LLMRateLimitError, LLMAuthorizationError) as e:
+ except InvokeError as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e
@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
- PubHandler.stop(end_user, task_id)
+ ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, except InvokeError as e:
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e: except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ from core.entities.application_entities import InvokeFrom
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
try: try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app') response = CompletionService.generate_more_like_this(
app_model=app_model,
user=end_user,
message_id=message_id,
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
return compact_response(response) return compact_response(response)
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, except InvokeError as e:
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except ValueError as e: except ValueError as e:
raise e raise e
@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, except InvokeError as e:
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e: except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, except InvokeError as e:
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e)) raise CompletionRequestError(str(e))
except Exception: except Exception:
logging.exception("internal server error.") logging.exception("internal server error.")

View File

@ -0,0 +1,101 @@
import logging
from typing import Optional, List
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class AgentLLMCallback(Callback):
def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
self.agent_callback = agent_callback
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_before_invoke(
prompt_messages=prompt_messages
)
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
pass
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_after_invoke(
result=result
)
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_error(
error=ex
)
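The callback above only forwards the runtime's lifecycle hooks to the agent-loop gather handler. A hedged sketch of that relationship, with plain classes standing in for Dify's Callback and AgentLoopGatherCallbackHandler:

class LoopHandler:
    """Stand-in for AgentLoopGatherCallbackHandler: records what the LLM saw and produced."""
    def on_llm_before_invoke(self, prompt_messages):
        print(f"sending {len(prompt_messages)} prompt messages")
    def on_llm_after_invoke(self, result):
        print("got result:", result)
    def on_llm_error(self, error):
        print("invoke failed:", error)

class ForwardingCallback:
    """Stand-in for AgentLLMCallback: delegates each hook to the loop handler."""
    def __init__(self, handler):
        self.handler = handler
    def on_before_invoke(self, prompt_messages, **kwargs):
        self.handler.on_llm_before_invoke(prompt_messages=prompt_messages)
    def on_after_invoke(self, result, **kwargs):
        self.handler.on_llm_after_invoke(result=result)
    def on_invoke_error(self, ex, **kwargs):
        self.handler.on_llm_error(error=ex)

cb = ForwardingCallback(LoopHandler())
cb.on_before_invoke(prompt_messages=["hello"])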

View File

@ -1,28 +1,49 @@
from typing import List from typing import List, cast
from langchain.schema import BaseMessage from langchain.schema import BaseMessage
from core.model_providers.models.entity.message import to_prompt_messages from core.entities.application_entities import ModelConfigEntity
from core.model_providers.models.llm.base import BaseLLM from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin: class CalcTokenMixin:
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
return model_instance.get_num_tokens(to_prompt_messages(messages))
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
""" """
Got the rest tokens available for the model after excluding messages tokens and completion max tokens Got the rest tokens available for the model after excluding messages tokens and completion max tokens
:param llm: :param model_config:
:param messages: :param messages:
:return: :return:
""" """
llm_max_tokens = model_instance.model_rules.max_tokens.max model_type_instance = model_config.provider_model_bundle.model_type_instance
completion_max_tokens = model_instance.model_kwargs.max_tokens model_type_instance = cast(LargeLanguageModel, model_type_instance)
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens return rest_tokens
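The rest-token arithmetic above reduces to context_size minus the reserved completion budget minus the prompt tokens already consumed. A self-contained sketch of the same calculation (parameter names are illustrative, the real values come from the model schema's CONTEXT_SIZE property and the max_tokens parameter rule):

def rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    """Tokens still available for context after reserving the completion budget."""
    return context_size - max_tokens - prompt_tokens

# e.g. a 4096-token model, 1500 reserved for the answer, 900 already used by the prompt
assert rest_tokens(4096, 1500, 900) == 1696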

View File

@ -1,4 +1,3 @@
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import root_validator from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages from core.entities.application_entities import ModelConfigEntity
from core.model_providers.models.llm.base import BaseLLM from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM from core.third_party.langchain.llms.fake import FakeLLM
@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
""" """
A Multi Dataset Retrieve Agent driven by Router. A Multi Dataset Retrieve Agent driven by Router.
""" """
model_instance: BaseLLM model_config: ModelConfigEntity
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
agent_decision.return_values['output'] = '' agent_decision.return_values['output'] = ''
return agent_decision return agent_decision
except Exception as e: except Exception as e:
new_exception = self.model_instance.handle_exceptions(e) raise e
raise new_exception
def real_plan( def real_plan(
self, self,
@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs) prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages() messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages) prompt_messages = lc_messages_to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages, model_instance = ModelInstance(
functions=self.functions, provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
) )
ai_message = AIMessage( ai_message = AIMessage(
content=result.content, content=result.message.content or "",
additional_kwargs={ additional_kwargs={
'function_call': result.function_call 'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
} }
) )
@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_instance: BaseLLM, model_config: ModelConfigEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
system_message=system_message, system_message=system_message,
) )
return cls( return cls(
model_instance=model_instance, model_config=model_config,
llm=FakeLLM(response=''), llm=FakeLLM(response=''),
prompt=prompt, prompt=prompt,
tools=tools, tools=tools,
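The router agent now calls the model runtime directly and rebuilds LangChain's function_call kwargs from the first tool call on the result. A rough sketch of that mapping, with plain dictionaries standing in for the runtime's tool-call objects (the field layout is an assumption for illustration):

def to_additional_kwargs(tool_calls: list[dict]) -> dict:
    """Map runtime tool calls back to the shape LangChain's _parse_ai_message expects."""
    if not tool_calls:
        return {"function_call": None}
    first = tool_calls[0]
    return {
        "function_call": {
            "id": first["id"],
            "name": first["function"]["name"],
            "arguments": first["function"]["arguments"],
        }
    }

print(to_additional_kwargs([
    {"id": "call_1", "function": {"name": "dataset_lookup", "arguments": '{"query": "tax"}'}}
]))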

View File

@ -1,4 +1,4 @@
from typing import List, Tuple, Any, Union, Sequence, Optional from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import root_validator from pydantic import root_validator
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages from core.entities.application_entities import ModelConfigEntity
from core.model_providers.models.llm.base import BaseLLM from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.third_party.langchain.llms.fake import FakeLLM from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = "" moving_summary_buffer: str = ""
moving_summary_index: int = 0 moving_summary_index: int = 0
summary_model_instance: BaseLLM = None summary_model_config: ModelConfigEntity = None
model_instance: BaseLLM model_config: ModelConfigEntity
agent_llm_callback: Optional[AgentLLMCallback] = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_instance: BaseLLM, model_config: ModelConfigEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage( system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant." content="You are a helpful AI assistant."
), ),
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any, **kwargs: Any,
) -> BaseSingleActionAgent: ) -> BaseSingleActionAgent:
prompt = cls.create_prompt( prompt = cls.create_prompt(
@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
system_message=system_message, system_message=system_message,
) )
return cls( return cls(
model_instance=model_instance, model_config=model_config,
llm=FakeLLM(response=''), llm=FakeLLM(response=''),
prompt=prompt, prompt=prompt,
tools=tools, tools=tools,
callback_manager=callback_manager, callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
**kwargs, **kwargs,
) )
@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
:param query: :param query:
:return: :return:
""" """
original_max_tokens = self.model_instance.model_kwargs.max_tokens original_max_tokens = 0
self.model_instance.model_kwargs.max_tokens = 40 for parameter_rule in self.model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
self.model_config.parameters['max_tokens'] = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages() messages = prompt.to_messages()
try: try:
prompt_messages = to_prompt_messages(messages) prompt_messages = lc_messages_to_prompt_messages(messages)
result = self.model_instance.run( model_instance = ModelInstance(
messages=prompt_messages, provider_model_bundle=self.model_config.provider_model_bundle,
functions=self.functions, model=self.model_config.model,
callbacks=None )
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
) )
except Exception as e: except Exception as e:
new_exception = self.model_instance.handle_exceptions(e) raise e
raise new_exception
function_call = result.function_call self.model_config.parameters['max_tokens'] = original_max_tokens
self.model_instance.model_kwargs.max_tokens = original_max_tokens return True if result.message.tool_calls else False
return True if function_call else False
def plan( def plan(
self, self,
@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
prompt = self.prompt.format_prompt(**full_inputs) prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages() messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
# summarize messages if rest_tokens < 0 # summarize messages if rest_tokens < 0
try: try:
messages = self.summarize_messages_if_needed(messages, functions=self.functions) prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
except ExceededLLMTokensLimitError as e: except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e)) return AgentFinish(return_values={"output": str(e)}, log=str(e))
prompt_messages = to_prompt_messages(messages) model_instance = ModelInstance(
result = self.model_instance.run( provider_model_bundle=self.model_config.provider_model_bundle,
messages=prompt_messages, model=self.model_config.model,
functions=self.functions, )
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
) )
ai_message = AIMessage( ai_message = AIMessage(
content=result.content, content=result.message.content or "",
additional_kwargs={ additional_kwargs={
'function_call': result.function_call 'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
} }
) )
agent_decision = _parse_ai_message(ai_message) agent_decision = _parse_ai_message(ai_message)
@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
except ValueError: except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "") return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs) rest_tokens = self.get_message_rest_tokens(
self.model_config,
messages,
**kwargs
)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0: if rest_tokens >= 0:
return messages return messages
@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
ai_prefix="AI", ai_prefix="AI",
) )
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT) chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines) return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/ Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_instance.model_provider.provider_name == 'azure_openai': if model_config.provider == 'azure_openai':
model = model_instance.base_model_name model = model_config.model
model = model.replace("gpt-35", "gpt-3.5") model = model.replace("gpt-35", "gpt-3.5")
else: else:
model = model_instance.base_model_name model = model_config.credentials.get("base_model_name")
tiktoken_ = _import_tiktoken() tiktoken_ = _import_tiktoken()
try: try:

View File

@ -1,158 +0,0 @@
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@ -12,9 +12,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.chain.llm_chain import LLMChain from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode from core.entities.application_entities import ModelConfigEntity
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return True return True
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
try: try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e: except Exception as e:
new_exception = self.llm_chain.model_instance.handle_exceptions(e) raise e
raise new_exception
try: try:
agent_decision = self.output_parser.parse(full_output) agent_decision = self.output_parser.parse(full_output)
@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
except OutputParserException: except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "") "I don't know how to respond to that."}, "")
@classmethod @classmethod
def create_prompt( def create_prompt(
cls, cls,
@ -182,7 +180,7 @@ Thought: {agent_scratchpad}
return PromptTemplate(template=template, input_variables=input_variables) return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad( def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str: ) -> str:
agent_scratchpad = "" agent_scratchpad = ""
for action, observation in intermediate_steps: for action, observation in intermediate_steps:
@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
raise ValueError("agent_scratchpad should be of type string.") raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad: if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain) llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT: if llm_chain.model_config.mode == "chat":
return ( return (
f"This was your previous work " f"This was your previous work "
f"(but I haven't seen any of it! I only see what " f"(but I haven't seen any of it! I only see what "
@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_instance: BaseLLM, model_config: ModelConfigEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT: if model_config.mode == "chat":
prompt = cls.create_prompt( prompt = cls.create_prompt(
tools, tools,
prefix=prefix, prefix=prefix,
@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
format_instructions=format_instructions, format_instructions=format_instructions,
input_variables=input_variables input_variables=input_variables
) )
llm_chain = LLMChain( llm_chain = LLMChain(
model_instance=model_instance, model_config=model_config,
prompt=prompt, prompt=prompt,
callback_manager=callback_manager, callback_manager=callback_manager,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
) )
tool_names = [tool.name for tool in tools] tool_names = [tool.name for tool in tools]
_output_parser = output_parser _output_parser = output_parser
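The mode check above is what decides whether the scratchpad gets the chat-style framing preamble. A hedged, self-contained approximation of _construct_scratchpad (the wording of the preamble is paraphrased, not copied from the source):

def construct_scratchpad(mode: str, steps: list[tuple[str, str]]) -> str:
    """Illustrative: chat-mode models get a framing preamble around prior work."""
    scratchpad = ""
    for action_log, observation in steps:
        scratchpad += action_log
        scratchpad += f"\nObservation: {observation}\nThought:"
    if scratchpad and mode == "chat":
        return (
            "This was your previous work (but I haven't seen any of it! "
            f"I only see what you return as final answer):\n{scratchpad}"
        )
    return scratchpad

print(construct_scratchpad("chat", [("Action: lookup", "found 3 documents")]))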

View File

@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode from core.entities.application_entities import ModelConfigEntity
from core.model_providers.models.llm.base import BaseLLM from core.entities.message_entities import lc_messages_to_prompt_messages
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@ -54,7 +55,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = "" moving_summary_buffer: str = ""
moving_summary_index: int = 0 moving_summary_index: int = 0
summary_model_instance: BaseLLM = None summary_model_config: ModelConfigEntity = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs. **kwargs: User inputs.
@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts: if prompts:
messages = prompts[0].to_messages() messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages) prompt_messages = lc_messages_to_prompt_messages(messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
if rest_tokens < 0: if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs) full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try: try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e: except Exception as e:
new_exception = self.llm_chain.model_instance.handle_exceptions(e) raise e
raise new_exception
try: try:
agent_decision = self.output_parser.parse(full_output) agent_decision = self.output_parser.parse(full_output)
@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "") "I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs): def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_instance: if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation) should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps] for _, observation in should_summary_intermediate_steps]
@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
ai_prefix="AI", ai_prefix="AI",
) )
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT) chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines) return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod @classmethod
@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
raise ValueError("agent_scratchpad should be of type string.") raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad: if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain) llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT: if llm_chain.model_config.mode == "chat":
return ( return (
f"This was your previous work " f"This was your previous work "
f"(but I haven't seen any of it! I only see what " f"(but I haven't seen any of it! I only see what "
@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_instance: BaseLLM, model_config: ModelConfigEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None, memory_prompts: Optional[List[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT: if model_config.mode == "chat":
prompt = cls.create_prompt( prompt = cls.create_prompt(
tools, tools,
prefix=prefix, prefix=prefix,
@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
input_variables=input_variables, input_variables=input_variables,
) )
llm_chain = LLMChain( llm_chain = LLMChain(
model_instance=model_instance, model_config=model_config,
prompt=prompt, prompt=prompt,
callback_manager=callback_manager, callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
) )
tool_names = [tool.name for tool in tools] tool_names = [tool.name for tool in tools]
_output_parser = output_parser _output_parser = output_parser
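When rest_tokens drops below zero, the agent above summarizes earlier observations with the summary model before re-planning. A minimal sketch of that budget check, assuming a generic summarize callable in place of the SUMMARY_PROMPT chain:

def maybe_summarize(rest_tokens: int, observations: list[str], summarize) -> list[str]:
    """Illustrative: keep the newest observation verbatim, summarize the rest when over budget."""
    if rest_tokens >= 0 or len(observations) < 2:
        return observations
    summary = summarize("\n".join(observations[:-1]))
    return [summary, observations[-1]]

print(maybe_summarize(-50, ["obs one", "obs two", "obs three"],
                      lambda text: f"summary of: {text[:20]}..."))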

View File

@ -4,10 +4,10 @@ from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor from langchain.agents import AgentExecutor as LCAgentExecutor
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation from core.helper import moderation
from core.model_providers.error import LLMError from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_providers.models.llm.base import BaseLLM from core.model_runtime.errors.invoke import InvokeError
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel): class AgentConfiguration(BaseModel):
strategy: PlanningStrategy strategy: PlanningStrategy
model_instance: BaseLLM model_config: ModelConfigEntity
tools: list[BaseTool] tools: list[BaseTool]
summary_model_instance: BaseLLM = None summary_model_config: Optional[ModelConfigEntity] = None
memory: Optional[BaseChatMemory] = None memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None callbacks: Callbacks = None
max_iterations: int = 6 max_iterations: int = 6
max_execution_time: Optional[float] = None max_execution_time: Optional[float] = None
early_stopping_method: str = "generate" early_stopping_method: str = "generate"
agent_llm_callback: Optional[AgentLLMCallback] = None
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config: class Config:
@ -62,34 +65,42 @@ class AgentExecutor:
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT: if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance, model_config=self.configuration.model_config,
tools=self.configuration.tools, tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(), output_parser=StructuredChatOutputParser(),
summary_model_instance=self.configuration.summary_model_instance summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_instance else None, if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance, model_config=self.configuration.model_config,
tools=self.configuration.tools, tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
summary_model_instance=self.configuration.summary_model_instance if self.configuration.memory else None, # used for read chat histories memory
if self.configuration.summary_model_instance else None, summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.ROUTER: elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)] self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools( agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance, model_config=self.configuration.model_config,
tools=self.configuration.tools, tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None,
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)] self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance, model_config=self.configuration.model_config,
tools=self.configuration.tools, tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(), output_parser=StructuredChatOutputParser(),
verbose=True verbose=True
@ -104,11 +115,11 @@ class AgentExecutor:
def run(self, query: str) -> AgentExecuteResult: def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation( moderation_result = moderation.check_moderation(
self.configuration.model_instance.model_provider, self.configuration.model_config,
query query
) )
if not moderation_result: if moderation_result:
return AgentExecuteResult( return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy, strategy=self.configuration.strategy,
@ -118,7 +129,6 @@ class AgentExecutor:
agent_executor = LCAgentExecutor.from_agent_and_tools( agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent, agent=self.agent,
tools=self.configuration.tools, tools=self.configuration.tools,
memory=self.configuration.memory,
max_iterations=self.configuration.max_iterations, max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time, max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method, early_stopping_method=self.configuration.early_stopping_method,
@ -126,8 +136,8 @@ class AgentExecutor:
) )
try: try:
output = agent_executor.run(query) output = agent_executor.run(input=query)
except LLMError as ex: except InvokeError as ex:
raise ex raise ex
except Exception as ex: except Exception as ex:
logging.exception("agent_executor run failed") logging.exception("agent_executor run failed")
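_init_agent above is a straight dispatch from PlanningStrategy to an agent class. A compact sketch of the same dispatch shape (enum string values and the returned class names are illustrative labels, not imports of the real classes):

import enum

class PlanningStrategy(str, enum.Enum):
    REACT = "react"
    FUNCTION_CALL = "function_call"
    ROUTER = "router"
    REACT_ROUTER = "react_router"

def pick_agent(strategy: PlanningStrategy) -> str:
    """Illustrative mapping mirroring _init_agent's branches."""
    mapping = {
        PlanningStrategy.REACT: "AutoSummarizingStructuredChatAgent",
        PlanningStrategy.FUNCTION_CALL: "AutoSummarizingOpenAIFunctionCallAgent",
        PlanningStrategy.ROUTER: "MultiDatasetRouterAgent",
        PlanningStrategy.REACT_ROUTER: "StructuredMultiDatasetRouterAgent",
    }
    if strategy not in mapping:
        raise NotImplementedError(f"Unknown agent strategy: {strategy}")
    return mapping[strategy]

print(pick_agent(PlanningStrategy.ROUTER))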

View File

@ -0,0 +1,251 @@
import json
import logging
from typing import cast
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.app_runner.app_runner import AppRunner
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.agent_runner import AgentRunnerFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
logger = logging.getLogger(__name__)
class AgentApplicationRunner(AppRunner):
"""
Agent Application Runner
"""
def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Run agent application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=None,
memory=memory
)
# Create MessageChain
message_chain = self._init_message_chain(
message=message,
query=query
)
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_config=app_orchestration_config.model_config,
message=message,
queue_manager=queue_manager,
message_chain=message_chain
)
# init LLM Callback
agent_llm_callback = AgentLLMCallback(
agent_callback=agent_callback
)
agent_runner = AgentRunnerFeature(
tenant_id=application_generate_entity.tenant_id,
app_orchestration_config=app_orchestration_config,
model_config=app_orchestration_config.model_config,
config=app_orchestration_config.agent,
queue_manager=queue_manager,
message=message,
user_id=application_generate_entity.user_id,
agent_llm_callback=agent_llm_callback,
callback=agent_callback,
memory=memory
)
# agent run
result = agent_runner.run(
query=query,
invoke_from=application_generate_entity.invoke_from
)
if result:
self._save_message_chain(
message_chain=message_chain,
output_text=result
)
if (result
and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
and app_orchestration_config.prompt_template.simple_prompt_template
):
# Direct output if agent result exists and has pre prompt
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
stream=application_generate_entity.stream,
text=result,
usage=self._get_usage_of_all_agent_thoughts(
model_config=app_orchestration_config.model_config,
message=message
)
)
else:
# As normal LLM run, agent result as context
context = result
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
memory=memory
)
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recale_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
)
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == message.id).all())
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model,
model_config.credentials,
all_message_tokens,
all_answer_tokens
)

View File

@ -0,0 +1,267 @@
import time
from typing import cast, Optional, List, Tuple, Generator, Union
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from models.model import App
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileObj],
query: Optional[str] = None) -> int:
"""
Pre-calculate the remaining number of tokens available for the prompt
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
# get prompt messages without memory and context
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query
)
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max tokens, or switch to an LLM with a larger context window.")
return rest_tokens
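# Illustrative only (not part of the runner): a hypothetical worked example of the
# rest-token arithmetic above, using assumed numbers.
#   model_context_tokens = 4096, max_tokens = 512, prompt_tokens = 3000
#   rest_tokens = 4096 - 512 - 3000 = 584  -> returned for later use (e.g. sizing dataset context)
#   with prompt_tokens = 3700 instead: 4096 - 512 - 3700 = -116 -> InvokeBadRequestError is raised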
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
prompt_messages: List[PromptMessage]):
# recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages
)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
model_config.parameters[parameter_rule.name] = max_tokens
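# Illustrative only: how the max_tokens clamp above behaves with assumed numbers.
#   model_context_tokens = 4096, prompt_tokens = 3900, configured max_tokens = 512
#   3900 + 512 > 4096, so max_tokens becomes max(4096 - 3900, 16) = 196 and is written back
#   into model_config.parameters; with prompt_tokens = 4090 the floor keeps max(6, 16) = 16.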
def originze_prompt_messages(self, app_record: App,
model_config: ModelConfigEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileObj],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> Tuple[List[PromptMessage], Optional[List[str]]]:
"""
Organize prompt messages
:param context:
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:param memory: memory
:return:
"""
prompt_transform = PromptTransform()
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=app_record.mode,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
files=files,
context=context,
memory=memory,
model_config=model_config
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=app_record.mode,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config
)
stop = model_config.stop
return prompt_messages, stop
def direct_output(self, queue_manager: ApplicationQueueManager,
app_orchestration_config: AppOrchestrationConfigEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param app_orchestration_config: app orchestration config
:param prompt_messages: prompt messages
:param text: text
:param stream: stream
:param usage: usage
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish_chunk_message(LLMResultChunk(
model=app_orchestration_config.model_config.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=token)
)
))
index += 1
time.sleep(0.01)
queue_manager.publish_message_end(
llm_result=LLMResult(
model=app_orchestration_config.model_config.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
)
)
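# Illustrative only: with stream=True and text="Hi!", the loop above publishes three
# LLMResultChunk events (index 0..2, one character each, ~10 ms apart via time.sleep(0.01)),
# then a single message-end event carrying the full text and the given usage (or empty usage).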
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
queue_manager: ApplicationQueueManager,
stream: bool) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param stream: stream
:return:
"""
if not stream:
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager
)
else:
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager
)
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
queue_manager: ApplicationQueueManager) -> None:
"""
Handle invoke result directly (blocking, non-streaming)
:param invoke_result: invoke result
:param queue_manager: application queue manager
:return:
"""
queue_manager.publish_message_end(
llm_result=invoke_result
)
def _handle_invoke_result_stream(self, invoke_result: Generator,
queue_manager: ApplicationQueueManager) -> None:
"""
Handle invoke result stream
:param invoke_result: invoke result
:param queue_manager: application queue manager
:return:
"""
model = None
prompt_messages = []
text = ''
usage = None
for result in invoke_result:
queue_manager.publish_chunk_message(result)
text += result.delta.message.content
if not model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
usage = result.delta.usage
llm_result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
)
queue_manager.publish_message_end(
llm_result=llm_result
)
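# Illustrative only: for streamed chunks whose delta contents are "Hel", "lo" and "!", the
# loop above republishes each chunk, accumulates text == "Hello!", captures the model name and
# prompt messages from the first chunk, takes usage from the first chunk that carries it
# (typically only the final chunk does), and finally publishes one merged LLMResult.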

View File

@ -0,0 +1,363 @@
import logging
from typing import Tuple, Optional
from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.moderation.base import ModerationException
from core.prompt.prompt_transform import AppMode
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageAnnotation
logger = logging.getLogger(__name__)
class BasicApplicationRunner(AppRunner):
"""
Basic Application Runner
"""
def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the number of tokens remaining under the model context size and max token limits.
# If there are not enough tokens remaining, raise an exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
memory=memory
)
# moderation
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=application_generate_entity.tenant_id,
app_orchestration_config_entity=app_orchestration_config,
inputs=inputs,
query=query,
)
except ModerationException as e:
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
)
return
if query:
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
)
if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id
)
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
)
return
# fill in variable inputs from external data tools if any exist
external_data_tools = app_orchestration_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get context from datasets
context = None
if app_orchestration_config.dataset:
context = self.retrieve_dataset_context(
tenant_id=app_record.tenant_id,
app_record=app_record,
queue_manager=queue_manager,
model_config=app_orchestration_config.model_config,
show_retrieve_source=app_orchestration_config.show_retrieve_source,
dataset_config=app_orchestration_config.dataset,
message=message,
inputs=inputs,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
memory=memory
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
memory=memory
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
)
if hosting_moderation_result:
return
# Recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
self.recale_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
)
def moderation_for_inputs(self, app_id: str,
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_orchestration_config_entity: app orchestration config entity
:param inputs: inputs
:param query: query
:return:
"""
moderation_feature = ModerationFeature()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_orchestration_config_entity=app_orchestration_config_entity,
inputs=inputs,
query=query,
)
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record,
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
)
def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
"""
Fill in variable inputs from external data tools if any exist.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
external_data_fetch_feature = ExternalDataFetchFeature()
return external_data_fetch_feature.fetch(
tenant_id=tenant_id,
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
def retrieve_dataset_context(self, tenant_id: str,
app_record: App,
queue_manager: ApplicationQueueManager,
model_config: ModelConfigEntity,
dataset_config: DatasetEntity,
show_retrieve_source: bool,
message: Message,
inputs: dict,
query: str,
user_id: str,
invoke_from: InvokeFrom,
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
"""
Retrieve dataset context
:param tenant_id: tenant id
:param app_record: app record
:param queue_manager: queue manager
:param model_config: model config
:param dataset_config: dataset config
:param show_retrieve_source: show retrieve source
:param message: message
:param inputs: inputs
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:param memory: memory
:return:
"""
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
user_id,
invoke_from
)
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
and dataset_config.retrieve_config.query_variable):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrievalFeature()
return dataset_retrieval.retrieve(
tenant_id=tenant_id,
model_config=model_config,
config=dataset_config,
query=query,
invoke_from=invoke_from,
show_retrieve_source=show_retrieve_source,
hit_callback=hit_callback,
memory=memory
)
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param prompt_messages: prompt messages
:return:
"""
hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity,
prompt_messages=prompt_messages
)
if moderation_result:
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream
)
return moderation_result

View File

@ -0,0 +1,483 @@
import json
import logging
import time
from typing import Union, Generator, cast, Optional
from pydantic import BaseModel
from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Message, Conversation, MessageAgentThought
from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__)
class TaskState(BaseModel):
"""
TaskState entity
"""
llm_result: LLMResult
metadata: dict = {}
class GenerateTaskPipeline:
"""
GenerateTaskPipeline generates streaming output and manages state for an application run.
"""
def __init__(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._conversation = conversation
self._message = message
self._task_state = TaskState(
llm_result=LLMResult(
model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
prompt_messages=[],
message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage()
)
)
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
def process(self, stream: bool) -> Union[dict, Generator]:
"""
Process generate task pipeline.
:return:
"""
if stream:
return self._process_stream_response()
else:
return self._process_blocking_response()
def _process_blocking_response(self) -> dict:
"""
Process blocking response.
:return:
"""
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueueErrorEvent):
raise self._handle_error(event)
elif isinstance(event, QueueRetrieverResourcesEvent):
self._task_state.metadata['retriever_resources'] = event.retriever_resources
elif isinstance(event, AnnotationReplyEvent):
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata['annotation_reply'] = {
'id': annotation.id,
'account': {
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
}
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
self._task_state.llm_result.prompt_messages
)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
[self._task_state.llm_result.message]
)
credentials = model_config.credentials
# transform usage
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
prompt_tokens,
completion_tokens
)
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
completion=self._task_state.llm_result.message.content,
public_event=False
)
# Save message
self._save_message(event.llm_result)
response = {
'event': 'message',
'task_id': self._application_generate_entity.task_id,
'id': self._message.id,
'mode': self._conversation.mode,
'answer': event.llm_result.message.content,
'metadata': {},
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
if self._task_state.metadata:
response['metadata'] = self._task_state.metadata
return response
else:
continue
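# Illustrative only (field values are hypothetical): a blocking call resolves to a single dict
# such as
#   {"event": "message", "task_id": "...", "id": "<message id>", "mode": "chat",
#    "answer": "<full completion>", "metadata": {"retriever_resources": [...]},
#    "created_at": 1704200000, "conversation_id": "<conversation id>"}
# where "conversation_id" is only present in chat mode and "metadata" is filled from task state.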
def _process_stream_response(self) -> Generator:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
event = message.event
if isinstance(event, QueueErrorEvent):
raise self._handle_error(event)
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
self._task_state.llm_result.prompt_messages
)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
[self._task_state.llm_result.message]
)
credentials = model_config.credentials
# transform usage
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
prompt_tokens,
completion_tokens
)
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
completion=self._task_state.llm_result.message.content,
public_event=False
)
self._output_moderation_handler = None
replace_response = {
'event': 'message_replace',
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'answer': self._task_state.llm_result.message.content,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
replace_response['conversation_id'] = self._conversation.id
yield self._yield_response(replace_response)
# Save message
self._save_message(self._task_state.llm_result)
response = {
'event': 'message_end',
'task_id': self._application_generate_entity.task_id,
'id': self._message.id,
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
if self._task_state.metadata:
response['metadata'] = self._task_state.metadata
yield self._yield_response(response)
elif isinstance(event, QueueRetrieverResourcesEvent):
self._task_state.metadata['retriever_resources'] = event.retriever_resources
elif isinstance(event, AnnotationReplyEvent):
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata['annotation_reply'] = {
'id': annotation.id,
'account': {
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
}
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
agent_thought = (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
)
if agent_thought:
response = {
'event': 'agent_thought',
'id': agent_thought.id,
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'position': agent_thought.position,
'thought': agent_thought.thought,
'tool': agent_thought.tool,
'tool_input': agent_thought.tool_input,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
yield self._yield_response(response)
elif isinstance(event, QueueMessageEvent):
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:
continue
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
# stop consuming new tokens when output moderation requires direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._queue_manager.publish_chunk_message(LLMResultChunk(
model=self._task_state.llm_result.model,
prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
))
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
continue
else:
self._output_moderation_handler.append_new_token(delta_text)
self._task_state.llm_result.message.content += delta_text
response = self._handle_chunk(delta_text)
yield self._yield_response(response)
elif isinstance(event, QueueMessageReplaceEvent):
response = {
'event': 'message_replace',
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'answer': event.text,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
yield self._yield_response(response)
elif isinstance(event, QueuePingEvent):
yield "event: ping\n\n"
else:
continue
def _save_message(self, llm_result: LLMResult) -> None:
"""
Save message.
:param llm_result: llm result
:return:
"""
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
if llm_result.message.content else ''
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price
db.session.commit()
message_was_created.send(
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
)
def _handle_chunk(self, text: str) -> dict:
"""
Handle a message chunk event.
:param text: text
:return:
"""
response = {
'event': 'message',
'id': self._message.id,
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'answer': text,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
return response
def _handle_error(self, event: QueueErrorEvent) -> Exception:
"""
Handle error event.
:param event: event
:return:
"""
logger.debug("error: %s", event.error)
e = event.error
if isinstance(e, InvokeAuthorizationError):
return InvokeAuthorizationError('Incorrect API key provided')
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
return e
else:
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
def _yield_response(self, response: dict) -> str:
"""
Yield response.
:param response: response
:return:
"""
return "data: " + json.dumps(response) + "\n\n"
def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
"""
Convert prompt messages into a serializable prompt structure for saving.
:param prompt_messages: prompt messages
:return:
"""
prompts = []
if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER:
role = 'user'
elif prompt_message.role == PromptMessageRole.ASSISTANT:
role = 'assistant'
elif prompt_message.role == PromptMessageRole.SYSTEM:
role = 'system'
else:
continue
text = ''
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
else:
text = prompt_message.content
prompts.append({
"role": role,
"text": text,
"files": files
})
else:
prompts.append({
"role": 'user',
"text": prompt_messages[0].content
})
return prompts
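# Illustrative only: in chat mode a vision message is saved with its image payload truncated,
# e.g. {"type": "image", "data": "iVBORw0KGg...[TRUNCATED]...AAASUVORK5", "detail": "high"},
# keeping only the first and last 10 characters of the original data (values are hypothetical).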
def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
"""
Init output moderation.
:return:
"""
app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
if sensitive_word_avoidance:
return OutputModerationHandler(
tenant_id=self._application_generate_entity.tenant_id,
app_id=self._application_generate_entity.app_id,
rule=ModerationRule(
type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config
),
on_message_replace_func=self._queue_manager.publish_message_replace
)

View File

@ -0,0 +1,138 @@
import logging
import threading
import time
from typing import Any, Optional, Dict
from flask import current_app, Flask
from pydantic import BaseModel
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory
logger = logging.getLogger(__name__)
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
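# Illustrative only: with the default buffer_size of 300, the worker above wakes roughly once
# per second and skips moderation until at least 300 new characters have accumulated since the
# last check (or the final chunk arrives), then moderates the whole buffer. A DIRECT_OUTPUT
# verdict sets final_output (the stream switches to the preset response) and ends the worker;
# any other flagged verdict publishes a message-replace event and scanning continues.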
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logger.error("Moderation Output error: %s", e)
return None

View File

@ -0,0 +1,655 @@
import json
import logging
import threading
import uuid
from typing import cast, Optional, Any, Union, Generator, Tuple
from flask import Flask, current_app
from pydantic import ValidationError
from core.app_runner.agent_app_runner import AgentApplicationRunner
from core.app_runner.basic_app_runner import BasicApplicationRunner
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
from core.entities.model_entities import ModelStatus
from core.file.file_obj import FileObj
from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, Conversation, Message, MessageFile, App
logger = logging.getLogger(__name__)
class ApplicationManager:
"""
This class is responsible for managing application generation.
"""
def generate(self, tenant_id: str,
app_id: str,
app_model_config_id: str,
app_model_config_dict: dict,
app_model_config_override: bool,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
inputs: dict[str, str],
query: Optional[str] = None,
files: Optional[list[FileObj]] = None,
conversation: Optional[Conversation] = None,
stream: bool = False,
extras: Optional[dict[str, Any]] = None) \
-> Union[dict, Generator]:
"""
Generate App response.
:param tenant_id: workspace ID
:param app_id: app ID
:param app_model_config_id: app model config id
:param app_model_config_dict: app model config dict
:param app_model_config_override: app model config override
:param user: account or end user
:param invoke_from: invoke from source
:param inputs: inputs
:param query: query
:param files: file obj list
:param conversation: conversation
:param stream: is stream
:param extras: extras
"""
# init task id
task_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = ApplicationGenerateEntity(
task_id=task_id,
tenant_id=tenant_id,
app_id=app_id,
app_model_config_id=app_model_config_id,
app_model_config_dict=app_model_config_dict,
app_orchestration_config_entity=self._convert_from_app_model_config_dict(
tenant_id=tenant_id,
app_model_config_dict=app_model_config_dict
),
app_model_config_override=app_model_config_override,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else inputs,
query=query.replace('\x00', '') if query else None,
files=files if files else [],
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = ApplicationQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if application_generate_entity.app_orchestration_config_entity.agent:
# agent app
runner = AgentApplicationRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
else:
# basic app
runner = BasicApplicationRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e)
finally:
db.session.remove()
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message,
stream: bool = False) -> Union[dict, Generator]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = GenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
try:
return generate_task_pipeline.process(stream=stream)
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise ConversationTaskStoppedException()
else:
logger.exception(e)
raise e
finally:
db.session.remove()
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity:
"""
Convert app model config dict to entity.
:param tenant_id: tenant ID
:param app_model_config_dict: app model config dict
:raises ProviderTokenNotInitError: provider token not init error
:return: app orchestration config entity
"""
properties = {}
copy_app_model_config_dict = app_model_config_dict.copy()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=copy_app_model_config_dict['model']['provider'],
model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
model_name = copy_app_model_config_dict['model']['name']
model_type_instance = provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=copy_app_model_config_dict['model']['name']
)
if model_credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=copy_app_model_config_dict['model']['name'],
model_type=ModelType.LLM
)
if provider_model is None:
model_name = copy_app_model_config_dict['model']['name']
raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = copy_app_model_config_dict['model'].get('completion_params')
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = copy_app_model_config_dict['model'].get('mode')
if not model_mode:
mode_enum = model_type_instance.get_model_mode(
model=copy_app_model_config_dict['model']['name'],
credentials=model_credentials
)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(
copy_app_model_config_dict['model']['name'],
model_credentials
)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
properties['model_config'] = ModelConfigEntity(
provider=copy_app_model_config_dict['model']['provider'],
model=copy_app_model_config_dict['model']['name'],
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
# prompt template
prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
properties['prompt_template'] = PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
else:
advanced_chat_prompt_template = None
chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_completion_prompt_template = None
completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
}
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
**completion_prompt_template_params
)
properties['prompt_template'] = PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
)
# external data variables
properties['external_data_variables'] = []
external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
for external_data_tool in external_data_tools:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
continue
properties['external_data_variables'].append(
ExternalDataVariableEntity(
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
)
)
# show retrieve source
show_retrieve_source = False
retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
if retriever_resource_dict:
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
show_retrieve_source = True
properties['show_retrieve_source'] = show_retrieve_source
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
'enabled']:
agent_dict = copy_app_model_config_dict.get('agent_mode')
if agent_dict['strategy'] in ['router', 'react_router']:
dataset_ids = []
for tool in agent_dict.get('tools', []):
key = list(tool.keys())[0]
if key != 'dataset':
continue
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item['id']
dataset_ids.append(dataset_id)
dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
query_variable = copy_app_model_config_dict.get('dataset_query_variable')
if dataset_configs['retrieval_model'] == 'single':
properties['dataset'] = DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
single_strategy=agent_dict['strategy']
)
)
else:
properties['dataset'] = DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get('top_k'),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model')
)
)
else:
if agent_dict['strategy'] == 'react':
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
strategy = AgentEntity.Strategy.FUNCTION_CALLING
agent_tools = []
for tool in agent_dict.get('tools', []):
key = list(tool.keys())[0]
tool_item = tool[key]
agent_tool_properties = {
"tool_id": key
}
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
agent_tool_properties["config"] = tool_item
agent_tools.append(AgentToolEntity(**agent_tool_properties))
properties['agent'] = AgentEntity(
provider=properties['model_config'].provider,
model=properties['model_config'].model,
strategy=strategy,
tools=agent_tools
)
# file upload
file_upload_dict = copy_app_model_config_dict.get('file_upload')
if file_upload_dict:
if 'image' in file_upload_dict and file_upload_dict['image']:
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
properties['file_upload'] = FileUploadEntity(
image_config={
'number_limits': file_upload_dict['image']['number_limits'],
'detail': file_upload_dict['image']['detail'],
'transfer_methods': file_upload_dict['image']['transfer_methods']
}
)
# opening statement
properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
# suggested questions after answer
suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
if suggested_questions_after_answer_dict:
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
properties['suggested_questions_after_answer'] = True
# more like this
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
if more_like_this_dict:
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
properties['more_like_this'] = True
# speech to text
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
if speech_to_text_dict:
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
properties['speech_to_text'] = True
# sensitive word avoidance
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
if sensitive_word_avoidance_dict:
if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
)
return AppOrchestrationConfigEntity(**properties)
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-> Tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
:return:
"""
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_schema = model_type_instance.get_model_schema(
model=app_orchestration_config_entity.model_config.model,
credentials=app_orchestration_config_entity.model_config.credentials
)
app_record = (db.session.query(App)
.filter(App.id == application_generate_entity.app_id).first())
app_mode = app_record.mode
# get from source
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
from_source = 'api'
end_user_id = application_generate_entity.user_id
else:
from_source = 'console'
account_id = application_generate_entity.user_id
override_model_configs = None
if application_generate_entity.app_model_config_override:
override_model_configs = application_generate_entity.app_model_config_dict
introduction = ''
if app_mode == 'chat':
# get conversation introduction
introduction = self._get_conversation_introduction(application_generate_entity)
if not application_generate_entity.conversation_id:
conversation = Conversation(
app_id=app_record.id,
app_model_config_id=application_generate_entity.app_model_config_id,
model_provider=app_orchestration_config_entity.model_config.provider,
model_id=app_orchestration_config_entity.model_config.model,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_mode,
name='New conversation',
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status='normal',
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.commit()
else:
conversation = (
db.session.query(Conversation)
.filter(
Conversation.id == application_generate_entity.conversation_id,
Conversation.app_id == app_record.id
).first()
)
currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
message = Message(
app_id=app_record.id,
model_provider=app_orchestration_config_entity.model_config.provider,
model_id=app_orchestration_config_entity.model_config.model,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query or "",
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency=currency,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
agent_based=app_orchestration_config_entity.agent is not None
)
db.session.add(message)
db.session.commit()
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if account_id else 'end_user'),
created_by=account_id or end_user_id,
)
db.session.add(message_file)
db.session.commit()
return conversation, message
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
"""
Get conversation introduction
:param application_generate_entity: application generate entity
:return: conversation introduction
"""
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
introduction = app_orchestration_config_entity.opening_statement
if introduction:
try:
inputs = application_generate_entity.inputs
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
return introduction
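For context, the introduction formatting above substitutes conversation inputs into the opening statement. A standalone sketch of that substitution, assuming the `{{variable}}` placeholder syntax used by Dify prompt templates (the function name is illustrative, not the real PromptTemplateParser API):

```python
import re

VARIABLE_PATTERN = re.compile(r'\{\{(\w+)\}\}')


def format_introduction(template: str, inputs: dict) -> str:
    """Fill {{variable}} placeholders from inputs, leaving unknown ones untouched."""
    def replace(match: re.Match) -> str:
        key = match.group(1)
        return str(inputs.get(key, match.group(0)))

    return VARIABLE_PATTERN.sub(replace, template)


print(format_introduction('Hi {{name}}, ask me anything about {{topic}}.',
                          {'name': 'Ada', 'topic': 'model runtime'}))
# -> Hi Ada, ask me anything about model runtime.
```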
def _get_conversation(self, conversation_id: str) -> Conversation:
"""
Get conversation by conversation id
:param conversation_id: conversation id
:return: conversation
"""
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
return conversation
def _get_message(self, message_id: str) -> Message:
"""
Get message by message id
:param message_id: message id
:return: message
"""
message = (
db.session.query(Message)
.filter(Message.id == message_id)
.first()
)
return message

View File

@ -0,0 +1,228 @@
import queue
import time
from typing import Generator, Any
from sqlalchemy.orm import DeclarativeMeta
from core.entities.application_entities import InvokeFrom
from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
QueueMessageEvent, QueueMessage, AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought
class ApplicationQueueManager:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
if not user_id:
raise ValueError("user is required")
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
q = queue.Queue()
self._q = q
def listen(self) -> Generator:
"""
Listen to queue
:return:
"""
# listen for at most 10 minutes before forcing a stop
listen_timeout = 600
start_time = time.time()
last_ping_time = 0
while True:
try:
message = self._q.get(timeout=1)
if message is None:
break
yield message
except queue.Empty:
continue
finally:
elapsed_time = time.time() - start_time
if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client receives the stop signal
# and stops listening once the stop signal has been processed
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
self.stop_listen()
if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent())
last_ping_time = elapsed_time // 10
def stop_listen(self) -> None:
"""
Stop listening to the queue
:return:
"""
self._q.put(None)
def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
"""
Publish chunk message to channel
:param chunk: chunk
:return:
"""
self.publish(QueueMessageEvent(
chunk=chunk
))
def publish_message_replace(self, text: str) -> None:
"""
Publish message replace
:param text: text
:return:
"""
self.publish(QueueMessageReplaceEvent(
text=text
))
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
"""
Publish retriever resources
:return:
"""
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
def publish_annotation_reply(self, message_annotation_id: str) -> None:
"""
Publish annotation reply
:param message_annotation_id: message annotation id
:return:
"""
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
def publish_message_end(self, llm_result: LLMResult) -> None:
"""
Publish message end
:param llm_result: llm result
:return:
"""
self.publish(QueueMessageEndEvent(llm_result=llm_result))
self.stop_listen()
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
"""
Publish agent thought
:param message_agent_thought: message agent thought
:return:
"""
self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id
))
def publish_error(self, e) -> None:
"""
Publish error
:param e: error
:return:
"""
self.publish(QueueErrorEvent(
error=e
))
self.stop_listen()
def publish(self, event: AppQueueEvent) -> None:
"""
Publish event to queue
:param event:
:return:
"""
self._check_for_sqlalchemy_models(event.dict())
message = QueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
)
self._q.put(message)
if isinstance(event, QueueStopEvent):
self.stop_listen()
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
"""
Set task stop flag
:return:
"""
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
def _is_stopped(self) -> bool:
"""
Check if task is stopped
:return:
"""
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key)
if result is not None:
redis_client.delete(stopped_cache_key)
return True
return False
@classmethod
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
"""
Generate task belong cache key
:param task_id: task id
:return:
"""
return f"generate_task_belong:{task_id}"
@classmethod
def _generate_stopped_cache_key(cls, task_id: str) -> str:
"""
Generate stopped cache key
:param task_id: task id
:return:
"""
return f"generate_task_stopped:{task_id}"
def _check_for_sqlalchemy_models(self, data: Any):
# data comes from event.dict(), so walk nested dicts and lists recursively
if isinstance(data, dict):
for key, value in data.items():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed.")
class ConversationTaskStoppedException(Exception):
pass
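The manager above decouples the generation worker from the streaming response: the worker publishes events while the HTTP layer iterates `listen()`, which also emits a ping roughly every 10 seconds and ends on a timeout or a stop sentinel. A simplified, self-contained sketch of that pattern (not the Dify API; the Redis stop-flag bookkeeping and event classes are omitted):

```python
import queue
import threading
import time


class SimpleStreamQueue:
    """Toy stand-in for the queue manager: a worker publishes events, the caller streams them."""

    def __init__(self, listen_timeout: float = 30.0) -> None:
        self._q: queue.Queue = queue.Queue()
        self._listen_timeout = listen_timeout

    def publish(self, event: dict) -> None:
        self._q.put(event)

    def stop_listen(self) -> None:
        # None is the stop sentinel, mirroring the real manager
        self._q.put(None)

    def listen(self):
        start_time = time.time()
        last_ping_time = 0
        while True:
            try:
                event = self._q.get(timeout=1)
                if event is None:
                    break
                yield event
            except queue.Empty:
                continue
            finally:
                elapsed_time = time.time() - start_time
                if elapsed_time >= self._listen_timeout:
                    # let the consumer see a stop event, then end the stream
                    self.publish({'event': 'stop'})
                    self.stop_listen()
                if elapsed_time // 10 > last_ping_time:
                    self.publish({'event': 'ping'})
                    last_ping_time = elapsed_time // 10


def worker(q: SimpleStreamQueue) -> None:
    for token in ['Hello', ',', ' world']:
        q.publish({'event': 'chunk', 'text': token})
        time.sleep(0.1)
    q.publish({'event': 'message_end'})
    q.stop_listen()


if __name__ == '__main__':
    q = SimpleStreamQueue()
    threading.Thread(target=worker, args=(q,), daemon=True).start()
    for event in q.listen():
        print(event)
```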

View File

@ -2,30 +2,40 @@ import json
import logging import logging
import time import time
from typing import Any, Dict, List, Union, Optional from typing import Any, Dict, List, Union, Optional, cast
from langchain.agents import openai_functions_agent, openai_functions_multi_agent from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask from core.entities.application_entities import ModelConfigEntity
from core.model_providers.models.entity.message import PromptMessage from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_providers.models.llm.base import BaseLLM from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import MessageChain, MessageAgentThought, Message
class AgentLoopGatherCallbackHandler(BaseCallbackHandler): class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out.""" """Callback Handler that prints to std out."""
raise_error: bool = True raise_error: bool = True
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None: def __init__(self, model_config: ModelConfigEntity,
queue_manager: ApplicationQueueManager,
message: Message,
message_chain: MessageChain) -> None:
"""Initialize callback handler.""" """Initialize callback handler."""
self.model_instance = model_instance self.model_config = model_config
self.conversation_message_task = conversation_message_task self.queue_manager = queue_manager
self.message = message
self.message_chain = message_chain
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._agent_loops = [] self._agent_loops = []
self._current_loop = None self._current_loop = None
self._message_agent_thought = None self._message_agent_thought = None
self.current_chain = None
@property @property
def agent_loops(self) -> List[AgentLoop]: def agent_loops(self) -> List[AgentLoop]:
@ -46,65 +56,60 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Whether to ignore chain callbacks.""" """Whether to ignore chain callbacks."""
return True return True
def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
if not self._current_loop:
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
if result.usage:
self._current_loop.prompt_tokens = result.usage.prompt_tokens
else:
self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
model=self.model_config.model,
credentials=self.model_config.credentials,
prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
)
completion_message = result.message
if completion_message.tool_calls:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.tool_calls})
else:
self._current_loop.completion = completion_message.content
if result.usage:
self._current_loop.completion_tokens = result.usage.completion_tokens
else:
self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
model=self.model_config.model,
credentials=self.model_config.credentials,
prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
)
def on_chat_model_start( def on_chat_model_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
**kwargs: Any **kwargs: Any
) -> Any: ) -> Any:
if not self._current_loop: pass
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([message.content for message in messages[0]]),
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
"""Print out the prompts.""" pass
# serialized={'name': 'OpenAI'}
# prompts=['Answer the following questions...\nThought:']
# kwargs={}
if not self._current_loop:
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt=prompts[0],
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing.""" """Do nothing."""
# kwargs={} pass
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
if 'function_call' in completion_message.additional_kwargs:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
else:
self._current_loop.completion = response.generations[0][0].text
else:
self._current_loop.completion = completion_generation.text
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
def on_llm_error( def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if completion is not None: if completion is not None:
self._current_loop.completion = completion self._current_loop.completion = completion
self._message_agent_thought = self.conversation_message_task.on_agent_start( self._message_agent_thought = self._init_agent_thought()
self.current_chain,
self._current_loop
)
def on_tool_end( def on_tool_end(
self, self,
@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter() self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end( self._complete_agent_thought(self._message_agent_thought)
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop) self._agent_loops.append(self._current_loop)
self._current_loop = None self._current_loop = None
@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter() self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._current_loop.thought = '[DONE]' self._current_loop.thought = '[DONE]'
self._message_agent_thought = self.conversation_message_task.on_agent_start( self._message_agent_thought = self._init_agent_thought()
self.current_chain,
self._current_loop
)
self.conversation_message_task.on_agent_end( self._complete_agent_thought(self._message_agent_thought)
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop) self._agent_loops.append(self._current_loop)
self._current_loop = None self._current_loop = None
self._message_agent_thought = None self._message_agent_thought = None
elif not self._current_loop and self._agent_loops: elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish' self._agent_loops[-1].status = 'agent_finish'
def _init_agent_thought(self) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=self.message_chain.id,
position=self._current_loop.position,
thought=self._current_loop.thought,
tool=self._current_loop.tool_name,
tool_input=self._current_loop.tool_input,
message=self._current_loop.prompt,
message_price_unit=0,
answer=self._current_loop.completion,
answer_price_unit=0,
created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
created_by=(self.message.from_account_id
if self.message.from_source == 'console' else self.message.from_end_user_id)
)
db.session.add(message_agent_thought)
db.session.commit()
self.queue_manager.publish_agent_thought(message_agent_thought)
return message_agent_thought
def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
loop_message_tokens = self._current_loop.prompt_tokens
loop_answer_tokens = self._current_loop.completion_tokens
# transform usage
llm_usage = self.model_type_instance._calc_response_usage(
self.model_config.model,
self.model_config.credentials,
loop_message_tokens,
loop_answer_tokens
)
message_agent_thought.observation = self._current_loop.tool_output
message_agent_thought.tool_process_data = ''  # currently not supported
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
message_agent_thought.latency = self._current_loop.latency
message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
message_agent_thought.total_price = llm_usage.total_price
message_agent_thought.currency = llm_usage.currency
db.session.commit()

View File

@ -1,74 +0,0 @@
import json
import logging
from json import JSONDecodeError
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask
class DatasetToolCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.queries = []
self.conversation_message_task = conversation_message_task
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return True
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return True
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
tool_name: str = serialized.get('name')
dataset_id = tool_name.removeprefix('dataset-')
try:
input_dict = json.loads(input_str.replace("'", "\""))
query = input_dict.get('query')
except JSONDecodeError:
query = input_str
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.debug("Dataset tool on_llm_error: %s", error)

View File

@ -1,16 +0,0 @@
from pydantic import BaseModel
class ChainResult(BaseModel):
type: str = None
prompt: dict = None
completion: dict = None
status: str = 'chain_started'
completed: bool = False
started_at: float = None
completed_at: float = None
agent_result: dict = None
"""only when type is 'AgentExecutor'"""

View File

@ -1,6 +0,0 @@
from pydantic import BaseModel
class DatasetQueryObj(BaseModel):
dataset_id: str = None
query: str = None

View File

@ -1,8 +0,0 @@
from pydantic import BaseModel
class LLMMessage(BaseModel):
prompt: str = ''
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0

View File

@ -1,17 +1,44 @@
from typing import List from typing import List, Union
from langchain.schema import Document from langchain.schema import Document
from core.conversation_message_task import ConversationMessageTask from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment from models.dataset import DocumentSegment, DatasetQuery
from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool.""" """Callback handler for dataset tool."""
def __init__(self, conversation_message_task: ConversationMessageTask) -> None: def __init__(self, queue_manager: ApplicationQueueManager,
self.conversation_message_task = conversation_message_task app_id: str,
message_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
self._user_id = user_id
self._invoke_from = invoke_from
def on_query(self, query: str, dataset_id: str) -> None:
"""
Handle query.
"""
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source_app_id=self._app_id,
created_by_role=('account'
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=self._user_id
)
db.session.add(dataset_query)
db.session.commit()
def on_tool_end(self, documents: List[Document]) -> None: def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end.""" """Handle tool end."""
@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: List): def return_retriever_resource_info(self, resource: List):
"""Handle return_retriever_resource_info.""" """Handle return_retriever_resource_info."""
self.conversation_message_task.on_dataset_query_finish(resource) if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' in item else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self._user_id
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish_retriever_resources(resource)

View File

@ -1,284 +0,0 @@
import logging
import threading
import time
from typing import Any, Dict, List, Union, Optional
from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, model_instance: BaseLLM,
conversation_message_task: ConversationMessageTask):
self.model_instance = model_instance
self.llm_message = LLMMessage()
self.start_at = None
self.conversation_message_task = conversation_message_task
self.output_moderation_handler = None
self.init_output_moderation()
def init_output_moderation(self):
app_model_config = self.conversation_message_task.app_model_config
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
self.output_moderation_handler = OutputModerationHandler(
tenant_id=self.conversation_message_task.tenant_id,
app_id=self.conversation_message_task.app.id,
rule=ModerationRule(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config")
),
on_message_replace_func=self.conversation_message_task.on_message_replace
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
real_prompts = []
for message in messages[0]:
if message.type == 'human':
role = 'user'
elif message.type == 'ai':
role = 'assistant'
else:
role = 'system'
real_prompts.append({
"role": role,
"text": message.content,
"files": [{
"type": file.type.value,
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
completion=response.generations[0][0].text,
public_event=True if self.conversation_message_task.streaming else False
)
else:
self.llm_message.completion = response.generations[0][0].text
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(self.llm_message.completion)
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
if 'completion_tokens' in response.llm_output['token_usage']:
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
ex = ConversationTaskInterruptException()
self.on_llm_error(error=ex)
raise ex
try:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
if self.output_moderation_handler:
self.output_moderation_handler.append_new_token(token)
except ConversationTaskStoppedException as ex:
self.on_llm_error(error=ex)
raise ex
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
if isinstance(error, ConversationTaskInterruptException):
self.llm_message.completion = self.output_moderation_handler.get_final_output()
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message)
else:
logging.debug("on_llm_error: %s", error)
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logging.error("Moderation Output error: %s", e)
return None

View File

@ -1,76 +0,0 @@
import logging
import time
from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
from core.conversation_message_task import ConversationMessageTask
class MainChainGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self._current_chain_result = None
self._current_chain_message = None
self.conversation_message_task = conversation_message_task
self.agent_callback = None
def clear_chain_results(self) -> None:
self._current_chain_result = None
self._current_chain_message = None
if self.agent_callback:
self.agent_callback.current_chain = None
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return True
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return True
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
if not self._current_chain_result:
chain_type = serialized['id'][-1]
if chain_type:
self._current_chain_result = ChainResult(
type=chain_type,
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
if self.agent_callback:
self.agent_callback.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
if self._current_chain_result and self._current_chain_result.status == 'chain_started':
self._current_chain_result.status = 'chain_ended'
self._current_chain_result.completion = outputs
self._current_chain_result.completed = True
self._current_chain_result.completed_at = time.perf_counter()
self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
self.clear_chain_results()
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.debug("Dataset tool on_chain_error: %s", error)
self.clear_chain_results()

View File

@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Run on agent action.""" """Run on agent action."""
tool = action.tool tool = action.tool
tool_input = action.tool_input tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 try:
thought = action.log[:action_name_position].strip() if action.log else '' action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
except ValueError:
thought = ''
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}" log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
print_text("\n[on_agent_action]\n" + log + "\n", color='green') print_text("\n[on_agent_action]\n" + log + "\n", color='green')

View File

@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from core.model_providers.models.entity.message import to_prompt_messages from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.model_providers.models.llm.base import BaseLLM from core.entities.application_entities import ModelConfigEntity
from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.third_party.langchain.llms.fake import FakeLLM from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain): class LLMChain(LCLLMChain):
model_instance: BaseLLM model_config: ModelConfigEntity
"""The language model instance to use.""" """The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="") llm: BaseLanguageModel = FakeLLM(response="")
parameters: Dict[str, Any] = {}
agent_llm_callback: Optional[AgentLLMCallback] = None
def generate( def generate(
self, self,
@ -23,14 +27,23 @@ class LLMChain(LCLLMChain):
"""Generate LLM result from inputs.""" """Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages() messages = prompts[0].to_messages()
prompt_messages = to_prompt_messages(messages) prompt_messages = lc_messages_to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages, model_instance = ModelInstance(
stop=stop provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
stream=False,
stop=stop,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
model_parameters=self.parameters
) )
generations = [ generations = [
[Generation(text=result.content)] [Generation(text=result.message.content)]
] ]
return LLMResult(generations=generations) return LLMResult(generations=generations)

View File

@ -1,501 +0,0 @@
import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple
from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.embedding.cached_embedding import CacheEmbedding
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.dataset import Dataset
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
auto_generate_name: bool = True, from_source: str = 'console'):
"""
errors: ProviderTokenNotInitError
"""
query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
# get memory of conversation (read-only)
memory = cls.get_memory_from_conversation(
tenant_id=app.tenant_id,
app_model_config=app_model_config,
conversation=conversation,
return_messages=False
)
inputs = conversation.inputs
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
conversation=conversation,
is_override=is_override,
inputs=inputs,
query=query,
files=files,
streaming=streaming,
model_instance=final_model_instance,
auto_generate_name=auto_generate_name
)
prompt_message_files = [file.prompt_message_file for file in files]
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
model_instance=final_model_instance,
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files
)
# init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
tenant_id=app.tenant_id,
app_model_config=app_model_config
)
try:
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
try:
# process sensitive_word_avoidance
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
except ModerationException as e:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=str(e)
)
return
# check annotation reply
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
if annotation_reply:
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
inputs = cls.fill_in_inputs_from_external_data_tools(
tenant_id=app.tenant_id,
app_id=app.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
tenant_id=app.tenant_id,
retriever_from=retriever_from
)
query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
# run agent executor
agent_execute_result = None
if query_for_agent and agent_executor:
should_use_agent = agent_executor.should_use_agent(query_for_agent)
if should_use_agent:
agent_execute_result = agent_executor.run(query_for_agent)
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# run the final llm
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=fake_response
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
return
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id,
app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
return inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDED:
inputs = moderation_result.inputs
query = moderation_result.query
return inputs, query
@classmethod
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
inputs: dict, query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
# Group tools by type and config
grouped_tools = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
grouped_tools.setdefault(tool_key, []).append(tool)
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
future = executor.submit(
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
inputs, query
)
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_variable, result = future.result()
results[tool_variable] = result
inputs.update(results)
return inputs
@classmethod
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
with flask_app.app_context():
tool_variable = external_data_tool.get("variable")
tool_type = external_data_tool.get("type")
tool_config = external_data_tool.get("config")
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
return tool_variable, result
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
inputs: dict,
files: List[PromptMessageFile],
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
prompt_transform = PromptTransform()
# get llm prompt
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = prompt_transform.get_prompt(
app_mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
prompt_messages=prompt_messages,
)
response = model_instance.run(
messages=prompt_messages,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
return response
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
@classmethod
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
from_source: str) -> bool:
"""Get memory messages."""
app_model_config = conversation_message_task.app_model_config
app = conversation_message_task.app
annotation_reply = app_model_config.annotation_reply_dict
if annotation_reply['enabled']:
try:
score_threshold = annotation_reply.get('score_threshold', 1)
embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
tenant_id=app.tenant_id,
model_provider_name=embedding_provider_name,
model_name=embedding_model_name
)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
'annotation'
)
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
documents = vector_index.search(
conversation_message_task.query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
annotation_id = documents[0].metadata['annotation_id']
score = documents[0].metadata['score']
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
# insert annotation history
AppAnnotationService.add_annotation_history(annotation.id,
app.id,
annotation.question,
annotation.content,
conversation_message_task.query,
conversation_message_task.user.id,
conversation_message_task.message.id,
from_source,
score)
return True
except Exception as e:
logging.warning(f'Query annotation failed, exception: {str(e)}.')
return False
return False
@classmethod
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
conversation: Conversation,
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
# only for calc token in memory
memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=tenant_id,
model_config=app_model_config.model_dict
)
# use llm config from conversation
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
conversation=conversation,
model_instance=memory_model_instance,
max_token_limit=kwargs.get("max_token_limit", 2048),
memory_key=kwargs.get("memory_key", "chat_history"),
return_messages=kwargs.get("return_messages", True),
input_key=kwargs.get("input_key", "input"),
output_key=kwargs.get("output_key", "output"),
message_limit=kwargs.get("message_limit", 10),
)
return memory
@classmethod
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
if model_limited_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_transform = PromptTransform()
# get prompt without memory and context
if app_model_config.prompt_type == 'simple':
prompt_messages, _ = prompt_transform.get_prompt(
app_mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance
)
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"shrink the max tokens, or switch to an LLM with a larger token limit.")
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
# recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
if model_limited_tokens is None:
return
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
# update model instance max tokens
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
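A quick worked example of the recalculation above: with a 4096-token context limit, a 3900-token prompt and max_tokens=512 would overflow, so max_tokens is shrunk to 196 (floored at 16). A minimal sketch (function name illustrative):

```python
def recalc_max_tokens(model_limit: int, prompt_tokens: int, max_tokens: int) -> int:
    """Shrink max_tokens so prompt + completion still fits the model's context limit."""
    if prompt_tokens + max_tokens > model_limit:
        return max(model_limit - prompt_tokens, 16)
    return max_tokens


print(recalc_max_tokens(4096, 3900, 512))  # 196 -> only 196 tokens left for the answer
print(recalc_max_tokens(4096, 1000, 512))  # 512 -> already fits, unchanged
```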

View File

@ -1,517 +0,0 @@
import json
import time
from typing import Optional, Union, List
from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
MessageChain, DatasetRetrieverResource, MessageFile
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, files: List[FileObj], streaming: bool,
model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
auto_generate_name: bool = True):
self.start_at = time.perf_counter()
self.task_id = task_id
self.app = app
self.tenant_id = app.tenant_id
self.app_model_config = app_model_config
self.is_override = is_override
self.user = user
self.inputs = inputs
self.query = query
self.files = files
self.streaming = streaming
self.conversation = conversation
self.is_new_conversation = False
self.model_instance = model_instance
self.message = None
self.retriever_resource = None
self.auto_generate_name = auto_generate_name
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name')
self.mode = app.mode
self.init()
self._pub_handler = PubHandler(
user=self.user,
task_id=self.task_id,
message=self.message,
conversation=self.conversation,
chain_pub=False, # disabled currently
agent_thought_pub=True
)
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = self.app_model_config.to_dict()
introduction = ''
system_instruction = ''
system_instruction_tokens = 0
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
if self.app_model_config.pre_prompt:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=self.provider_name,
model_name=self.model_name
)
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app.id,
app_model_config_id=self.app_model_config.id,
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
name='New conversation',
inputs=self.inputs,
introduction=introduction,
system_instruction=system_instruction,
system_instruction_tokens=system_instruction_tokens,
status='normal',
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
)
db.session.add(self.conversation)
db.session.commit()
self.message = Message(
app_id=self.app.id,
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id,
inputs=self.inputs,
query=self.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency=self.model_instance.get_currency(),
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
)
db.session.add(self.message)
db.session.commit()
for file in self.files:
message_file = MessageFile(
message_id=self.message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_file)
db.session.commit()
def append_message_text(self, text: str):
if text is not None:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = time.perf_counter() - self.start_at
self.message.total_price = total_price
db.session.commit()
message_was_created.send(
self.message,
conversation=self.conversation,
is_first_message=self.is_new_conversation,
auto_generate_name=self.auto_generate_name
)
if not by_stopped:
self.end()
def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain(
message_id=self.message.id,
type=chain_result.type,
input=json.dumps(chain_result.prompt),
output=''
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
message_chain.output = json.dumps(chain_result.completion)
db.session.commit()
self._pub_handler.pub_chain(message_chain)
def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
message=agent_loop.prompt,
message_price_unit=0,
answer=agent_loop.completion,
answer_price_unit=0,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_thought)
db.session.commit()
self._pub_handler.pub_agent_thought(message_agent_thought)
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = agent_message_unit_price
message_agent_thought.message_price_unit = agent_message_price_unit
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = agent_answer_unit_price
message_agent_thought.answer_price_unit = agent_answer_price_unit
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.commit()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
dataset_query = DatasetQuery(
dataset_id=dataset_query_obj.dataset_id,
content=dataset_query_obj.query,
source='app',
source_app_id=self.app.id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(dataset_query)
db.session.commit()
def on_dataset_query_finish(self, resource: List):
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self.message.id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self.user.id
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self.retriever_resource = resource
def on_message_replace(self, text: str):
if text is not None:
self._pub_handler.pub_message_replace(text)
def message_end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
def end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end()
def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
self._pub_handler.pub_end()
class PubHandler:
def __init__(self, user: Union[Account, EndUser], task_id: str,
message: Message, conversation: Conversation,
chain_pub: bool = False, agent_thought_pub: bool = False):
self._channel = PubHandler.generate_channel_name(user, task_id)
self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
self._task_id = task_id
self._message = message
self._conversation = conversation
self._chain_pub = chain_pub
self._agent_thought_pub = agent_thought_pub
@classmethod
def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
if not user:
raise ValueError("user is required")
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result:{}-{}".format(user_str, task_id)
@classmethod
def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result_stopped:{}-{}".format(user_str, task_id)
def pub_text(self, text: str):
content = {
'event': 'message',
'data': {
'task_id': self._task_id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': str(self._conversation.id)
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_replace(self, text: str):
content = {
'event': 'message_replace',
'data': {
'task_id': self._task_id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': str(self._conversation.id)
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_chain(self, message_chain: MessageChain):
if self._chain_pub:
content = {
'event': 'chain',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_chain.id,
'type': message_chain.type,
'input': json.loads(message_chain.input),
'output': json.loads(message_chain.output),
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
if self._agent_thought_pub:
content = {
'event': 'agent_thought',
'data': {
'id': message_agent_thought.id,
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_agent_thought.message_chain_id,
'position': message_agent_thought.position,
'thought': message_agent_thought.thought,
'tool': message_agent_thought.tool,
'tool_input': message_agent_thought.tool_input,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_end(self, retriever_resource: List):
content = {
'event': 'message_end',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id,
}
}
if retriever_resource:
content['data']['retriever_resources'] = retriever_resource
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
content = {
'event': 'annotation',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id,
'text': text,
'annotation_id': annotation_id,
'annotation_author_name': annotation_author_name
}
}
self._message.answer = text
self._message.provider_response_latency = time.perf_counter() - start_at
db.session.commit()
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_end(self):
content = {
'event': 'end',
}
redis_client.publish(self._channel, json.dumps(content))
@classmethod
def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
content = {
'error': type(e).__name__,
'description': e.description if getattr(e, 'description', None) is not None else str(e)
}
channel = cls.generate_channel_name(user, task_id)
redis_client.publish(channel, json.dumps(content))
def _is_stopped(self):
return redis_client.get(self._stopped_cache_key) is not None
@classmethod
def ping(cls, user: Union[Account, EndUser], task_id: str):
content = {
'event': 'ping'
}
channel = cls.generate_channel_name(user, task_id)
redis_client.publish(channel, json.dumps(content))
@classmethod
def stop(cls, user: Union[Account, EndUser], task_id: str):
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
redis_client.setex(stopped_cache_key, 600, 1)
class ConversationTaskStoppedException(Exception):
pass
class ConversationTaskInterruptException(Exception):
pass
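For context, a minimal sketch of how a consumer could read the events that PubHandler publishes over Redis pub/sub; the connection settings and the account/task ids are placeholders, and only the channel naming follows PubHandler.generate_channel_name above:

import json
import redis

# Placeholder connection and ids; the channel name follows PubHandler.generate_channel_name.
r = redis.Redis(host='localhost', port=6379, db=0)
channel = 'generate_result:account-{}-{}'.format('some-account-id', 'some-task-id')

pubsub = r.pubsub()
pubsub.subscribe(channel)
for raw in pubsub.listen():
    if raw['type'] != 'message':
        continue
    event = json.loads(raw['data'])
    if event.get('event') == 'message':
        print(event['data']['text'], end='')  # streamed text chunk
    elif event.get('event') == 'end':
        break  # PubHandler.pub_end signals the stream is finished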

View File

@ -1,9 +1,11 @@
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Dict, Optional, Sequence, cast
from langchain.schema import Document
from sqlalchemy import func
-from core.model_providers.model_factory import ModelFactory
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
@ -69,10 +71,12 @@ class DatasetDocumentStore:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == 'high_quality':
-embedding_model = ModelFactory.get_embedding_model(
+model_manager = ModelManager()
+embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
-model_provider_name=self._dataset.embedding_model_provider,
-model_name=self._dataset.embedding_model
+provider=self._dataset.embedding_model_provider,
+model_type=ModelType.TEXT_EMBEDDING,
+model=self._dataset.embedding_model
)
for doc in docs:
@ -89,7 +93,16 @@ class DatasetDocumentStore:
)
# calc embedding use tokens
-tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
+if embedding_model:
+model_type_instance = embedding_model.model_type_instance
+model_type_instance = cast(TextEmbeddingModel, model_type_instance)
+tokens = model_type_instance.get_num_tokens(
+model=embedding_model.model,
+credentials=embedding_model.credentials,
+texts=[doc.page_content]
+)
+else:
+tokens = 0
if not segment_document:
max_position += 1

View File

@ -1,19 +1,22 @@
import logging
-from typing import List
+from typing import List, Optional
import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
-from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_manager import ModelInstance
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
+logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
-def __init__(self, embeddings: BaseEmbedding):
-self._embeddings = embeddings
+def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
+self._model_instance = model_instance
+self._user = user
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
@ -22,7 +25,7 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
-embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
if embedding:
text_embeddings[i] = embedding.get_embedding()
else:
@ -30,15 +33,21 @@ class CacheEmbedding(Embeddings):
if embedding_queue_indices:
try:
-embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
+embedding_result = self._model_instance.invoke_text_embedding(
+texts=[texts[i] for i in embedding_queue_indices],
+user=self._user
+)
+embedding_results = embedding_result.embeddings
except Exception as ex:
-raise self._embeddings.handle_exceptions(ex)
+logger.error('Failed to embed documents: ', ex)
+raise ex
for i, indice in enumerate(embedding_queue_indices):
hash = helper.generate_text_hash(texts[indice])
try:
-embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+embedding = Embedding(model_name=self._model_instance.model, hash=hash)
vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
text_embeddings[indice] = normalized_embedding
@ -58,18 +67,23 @@ class CacheEmbedding(Embeddings):
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
-embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
if embedding:
return embedding.get_embedding()
try:
-embedding_results = self._embeddings.client.embed_query(text)
+embedding_result = self._model_instance.invoke_text_embedding(
+texts=[text],
+user=self._user
+)
+embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex:
-raise self._embeddings.handle_exceptions(ex)
+raise ex
try:
-embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+embedding = Embedding(model_name=self._model_instance.model, hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
@ -79,4 +93,3 @@ class CacheEmbedding(Embeddings):
logging.exception('Failed to add embedding to db')
return embedding_results
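A minimal usage sketch of the refactored CacheEmbedding, assuming a ModelInstance obtained through ModelManager as in the DatasetDocumentStore change above; the tenant id, provider, and model names are placeholders, and the CacheEmbedding import path is an assumption not visible in this diff:

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.embedding.cached_embedding import CacheEmbedding  # assumed module path

model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
    tenant_id='your-tenant-id',        # placeholder
    provider='openai',                 # placeholder
    model_type=ModelType.TEXT_EMBEDDING,
    model='text-embedding-ada-002'     # placeholder
)

# Embeddings are normalized and cached in the Embedding table keyed by (model name, text hash).
embeddings = CacheEmbedding(model_instance, user='end-user-id')
document_vectors = embeddings.embed_documents(['hello world', 'dify model runtime'])
query_vector = embeddings.embed_query('hello world')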

View File

@ -0,0 +1,265 @@
from enum import Enum
from typing import Optional, Any, cast
from pydantic import BaseModel
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import AIModelEntity
class ModelConfigEntity(BaseModel):
"""
Model Config Entity.
"""
provider: str
model: str
model_schema: AIModelEntity
mode: str
provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {}
stop: list[str] = []
class AdvancedChatMessageEntity(BaseModel):
"""
Advanced Chat Message Entity.
"""
text: str
role: PromptMessageRole
class AdvancedChatPromptTemplateEntity(BaseModel):
"""
Advanced Chat Prompt Template Entity.
"""
messages: list[AdvancedChatMessageEntity]
class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Advanced Completion Prompt Template Entity.
"""
class RolePrefixEntity(BaseModel):
"""
Role Prefix Entity.
"""
user: str
assistant: str
prompt: str
role_prefix: Optional[RolePrefixEntity] = None
class PromptTemplateEntity(BaseModel):
"""
Prompt Template Entity.
"""
class PromptType(Enum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = 'simple'
ADVANCED = 'advanced'
@classmethod
def value_of(cls, value: str) -> 'PromptType':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt type value {value}')
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
"""
variable: str
type: str
config: dict[str, Any] = {}
class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = 'single'
MULTIPLE = 'multiple'
@classmethod
def value_of(cls, value: str) -> 'RetrieveStrategy':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid retrieve strategy value {value}')
query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
single_strategy: Optional[str] = None # for temp
top_k: Optional[int] = None
score_threshold: Optional[float] = None
reranking_model: Optional[dict] = None
class DatasetEntity(BaseModel):
"""
Dataset Config Entity.
"""
dataset_ids: list[str]
retrieve_config: DatasetRetrieveConfigEntity
class SensitiveWordAvoidanceEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
type: str
config: dict[str, Any] = {}
class FileUploadEntity(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
tool_id: str
config: dict[str, Any] = {}
class AgentEntity(BaseModel):
"""
Agent Entity.
"""
class Strategy(Enum):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
provider: str
model: str
strategy: Strategy
tools: list[AgentToolEntity] = []
class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
"""
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
external_data_variables: list[ExternalDataVariableEntity] = []
agent: Optional[AgentEntity] = None
# features
dataset: Optional[DatasetEntity] = None
file_upload: Optional[FileUploadEntity] = None
opening_statement: Optional[str] = None
suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
class InvokeFrom(Enum):
"""
Invoke From.
"""
SERVICE_API = 'service-api'
WEB_APP = 'web-app'
EXPLORE = 'explore'
DEBUGGER = 'debugger'
@classmethod
def value_of(cls, value: str) -> 'InvokeFrom':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid invoke from value {value}')
def to_source(self) -> str:
"""
Get source of invoke from.
:return: source
"""
if self == InvokeFrom.WEB_APP:
return 'web_app'
elif self == InvokeFrom.DEBUGGER:
return 'dev'
elif self == InvokeFrom.EXPLORE:
return 'explore_app'
elif self == InvokeFrom.SERVICE_API:
return 'api'
return 'dev'
class ApplicationGenerateEntity(BaseModel):
"""
Application Generate Entity.
"""
task_id: str
tenant_id: str
app_id: str
app_model_config_id: str
# for save
app_model_config_dict: dict
app_model_config_override: bool
# Converted from app_model_config to an Entity object, or overridden directly by external input
app_orchestration_config_entity: AppOrchestrationConfigEntity
conversation_id: Optional[str] = None
inputs: dict[str, str]
query: Optional[str] = None
files: list[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
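A small sketch of the value_of/to_source helpers defined above, using the import path referenced elsewhere in this commit:

from core.entities.application_entities import InvokeFrom, PromptTemplateEntity

# Parse wire values back into enums; unknown values raise ValueError.
invoke_from = InvokeFrom.value_of('web-app')
assert invoke_from == InvokeFrom.WEB_APP
assert invoke_from.to_source() == 'web_app'  # source label stored with messages

prompt_type = PromptTemplateEntity.PromptType.value_of('simple')
assert prompt_type == PromptTemplateEntity.PromptType.SIMPLE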

View File

@ -0,0 +1,128 @@
import enum
from typing import Any, cast
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \
ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage
class PromptMessageFileType(enum.Enum):
IMAGE = 'image'
@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW
class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, List[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]
def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
if isinstance(message, LCHumanMessageWithFiles):
file_prompt_message_contents = []
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
file_prompt_message_contents.append(ImagePromptMessageContent(
data=file.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
))
prompt_message_contents = [TextPromptMessageContent(data=message.content)]
prompt_message_contents.extend(file_prompt_message_contents)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=message.content))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content
}
if 'function_call' in message.additional_kwargs:
message_kwargs['tool_calls'] = [
AssistantPromptMessage.ToolCall(
id=message.additional_kwargs['function_call']['id'],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.additional_kwargs['function_call']['name'],
arguments=message.additional_kwargs['function_call']['arguments']
)
)
]
prompt_messages.append(AssistantPromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(SystemPromptMessage(content=message.content))
elif isinstance(message, FunctionMessage):
prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
return prompt_messages
def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, str):
messages.append(HumanMessage(content=prompt_message.content))
else:
message_contents = []
for content in prompt_message.content:
if isinstance(content, TextPromptMessageContent):
message_contents.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
message_contents.append({
'type': 'image',
'data': content.data,
'detail': content.detail.value
})
messages.append(HumanMessage(content=message_contents))
elif isinstance(prompt_message, AssistantPromptMessage):
message_kwargs = {
'content': prompt_message.content
}
if prompt_message.tool_calls:
message_kwargs['additional_kwargs'] = {
'function_call': {
'id': prompt_message.tool_calls[0].id,
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
}
messages.append(AIMessage(**message_kwargs))
elif isinstance(prompt_message, SystemPromptMessage):
messages.append(SystemMessage(content=prompt_message.content))
elif isinstance(prompt_message, ToolPromptMessage):
messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
return messages
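A minimal round-trip sketch of the two converters above; the module path in the import is an assumption, since the file path is not visible in this extract:

from langchain.schema import HumanMessage, SystemMessage

from core.entities.message_entities import (  # assumed module path
    lc_messages_to_prompt_messages,
    prompt_messages_to_lc_messages,
)

lc_messages = [
    SystemMessage(content='You are a helpful assistant.'),
    HumanMessage(content='What is Dify?'),
]

# LangChain messages -> model-runtime prompt messages -> back to LangChain messages.
prompt_messages = lc_messages_to_prompt_messages(lc_messages)
restored = prompt_messages_to_lc_messages(prompt_messages)
assert restored[1].content == 'What is Dify?'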

View File

@ -0,0 +1,71 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ProviderModel, ModelType
from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderEntity
class ModelStatus(Enum):
"""
Enum class for model status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
class SimpleModelProviderEntity(BaseModel):
"""
Simple provider.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None:
"""
Init simple provider.
:param provider_entity: provider entity
"""
super().__init__(
provider=provider_entity.provider,
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types
)
class ModelWithProviderEntity(ProviderModel):
"""
Model with provider entity.
"""
provider: SimpleModelProviderEntity
status: ModelStatus
class DefaultModelProviderEntity(BaseModel):
"""
Default model provider entity.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
class DefaultModelEntity(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: DefaultModelProviderEntity

View File

@ -0,0 +1,657 @@
import datetime
import json
import time
from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator
from pydantic import BaseModel
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.utils import encoders
from extensions.ext_database import db
from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
class ProviderConfiguration(BaseModel):
"""
Model class for provider configuration.
"""
tenant_id: str
provider: ProviderEntity
preferred_provider_type: ProviderType
using_provider_type: ProviderType
system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
"""
Get current credentials.
:param model_type: model type
:param model: model name
:return:
"""
if self.using_provider_type == ProviderType.SYSTEM:
return self.system_configuration.credentials
else:
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
return model_configuration.credentials
if self.custom_configuration.provider:
return self.custom_configuration.provider.credentials
else:
return None
def get_system_configuration_status(self) -> SystemConfigurationStatus:
"""
Get system configuration status.
:return:
"""
if self.system_configuration.enabled is False:
return SystemConfigurationStatus.UNSUPPORTED
current_quota_type = self.system_configuration.current_quota_type
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
None
)
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
SystemConfigurationStatus.QUOTA_EXCEEDED
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0)
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Get custom credentials.
:param obfuscated: obfuscated secret data in credentials
:return:
"""
if self.custom_configuration.provider is None:
return None
credentials = self.custom_configuration.provider.credentials
if not obfuscated:
return credentials
# Obfuscate credentials
return self._obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
if provider_record:
try:
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory.provider_credentials_validate(
self.provider.provider,
credentials
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_record, credentials
def add_or_update_custom_credentials(self, credentials: dict) -> None:
"""
Add or update custom provider credentials.
:param credentials:
:return:
"""
# validate custom provider config
provider_record, credentials = self.custom_credentials_validate(credentials)
# save provider
# Note: Do not switch the preferred provider, which allows users to use quotas first
if provider_record:
provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True
provider_record.updated_at = datetime.datetime.utcnow()
db.session.commit()
else:
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_record)
db.session.commit()
self.switch_preferred_provider_type(ProviderType.CUSTOM)
def delete_custom_credentials(self) -> None:
"""
Delete custom provider credentials.
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# delete provider
if provider_record:
self.switch_preferred_provider_type(ProviderType.SYSTEM)
db.session.delete(provider_record)
db.session.commit()
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
"""
Get custom model credentials.
:param model_type: model type
:param model: model name
:param obfuscated: obfuscated secret data in credentials
:return:
"""
if not self.custom_configuration.models:
return None
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials
if not obfuscated:
return credentials
# Obfuscate credentials
return self._obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> Tuple[ProviderModel, dict]:
"""
Validate custom model credentials.
:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
if provider_model_record:
try:
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
)
model_schema = (
model_provider_factory.get_provider_instance(self.provider.provider)
.get_model_instance(model_type)._get_customizable_model_schema(
model=model,
credentials=credentials
)
)
if model_schema:
credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_model_record, credentials
def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
"""
Add or update custom model credentials.
:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# validate custom model config
provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
# save provider model
# Note: Do not switch the preferred provider, which allows users to use quotas first
if provider_model_record:
provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True
provider_model_record.updated_at = datetime.datetime.utcnow()
db.session.commit()
else:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_model_record)
db.session.commit()
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
"""
Delete custom model credentials.
:param model_type: model type
:param model: model name
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
# delete provider model
if provider_model_record:
db.session.delete(provider_model_record)
db.session.commit()
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
:return:
"""
return model_provider_factory.get_provider_instance(self.provider.provider)
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
"""
Get current model type instance.
:param model_type: model type
:return:
"""
# Get provider instance
provider_instance = self.get_provider_instance()
# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
"""
Switch preferred provider type.
:param provider_type:
:return:
"""
if provider_type == self.preferred_provider_type:
return
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return
# get preferred provider
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider
).first()
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value
)
db.session.add(preferred_model_provider)
db.session.commit()
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
"""
Extract secret input form variables.
:param credential_form_schemas:
:return:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
"""
Obfuscated credentials.
:param credentials: credentials
:param credential_form_schemas: credential form schemas
:return:
"""
# Get provider credential secret variables
credential_secret_variables = self._extract_secret_variables(
credential_form_schemas
)
# Obfuscate provider credentials
copy_credentials = credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
return copy_credentials
def get_provider_model(self, model_type: ModelType,
model: str,
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
"""
Get provider model.
:param model_type: model type
:param model: model name
:param only_active: return active model only
:return:
"""
provider_models = self.get_provider_models(model_type, only_active)
for provider_model in provider_models:
if provider_model.model == model:
return provider_model
return None
def get_provider_models(self, model_type: Optional[ModelType] = None,
only_active: bool = False) -> list[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
:param only_active: only active models
:return:
"""
provider_instance = self.get_provider_instance()
model_types = []
if model_type:
model_types.append(model_type)
else:
model_types = provider_instance.get_provider_schema().supported_model_types
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
if only_active:
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_system_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
:param model_types: model types
:param provider_instance: provider instance
:return:
"""
provider_models = []
for model_type in model_types:
provider_models.extend(
[
ModelWithProviderEntity(
**m.dict(),
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
for m in provider_instance.models(model_type)
]
)
for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue
restrict_llms = quota_configuration.restrict_llms
if not restrict_llms:
break
# if llm name not in restricted llm list, remove it
for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_llms:
m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
:param model_types: model types
:param provider_instance: provider instance
:return:
"""
provider_models = []
credentials = None
if self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
for model_type in model_types:
if model_type not in self.provider.supported_model_types:
continue
models = provider_instance.models(model_type)
for m in models:
provider_models.append(
ModelWithProviderEntity(
**m.dict(),
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
)
)
# custom models
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
)
if not custom_model_schema:
continue
provider_models.append(
ModelWithProviderEntity(
**custom_model_schema.dict(),
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
return provider_models
class ProviderConfigurations(BaseModel):
"""
Model class for provider configuration dict.
"""
tenant_id: str
configurations: Dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False) \
-> list[ModelWithProviderEntity]:
"""
Get available models.
If preferred provider type is `system`:
Get the current **system mode** if provider supported,
if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
If there is no model configured in custom mode, it is treated as no_configure.
system > custom > no_configure
If preferred provider type is `custom`:
If custom credentials are configured, it is treated as custom mode.
Otherwise, get the current **system mode** if supported,
If all system modes are not available (no quota), it is treated as no_configure.
custom > system > no_configure
If real mode is `system`, use system credentials to get models,
paid quotas > provider free quotas > system free quotas
include pre-defined models (exclude GPT-4, status marked as `no_permission`).
If real mode is `custom`, use workspace custom credentials to get models,
include pre-defined models, custom models(manual append).
If real mode is `no_configure`, only return pre-defined models from `model runtime`.
(model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
model status marked as `active` is available.
:param provider: provider name
:param model_type: model type
:param only_active: only active models
:return:
"""
all_models = []
for provider_configuration in self.values():
if provider and provider_configuration.provider.provider != provider:
continue
all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
return all_models
def to_list(self) -> List[ProviderConfiguration]:
"""
Convert to list.
:return:
"""
return list(self.values())
def __getitem__(self, key):
return self.configurations[key]
def __setitem__(self, key, value):
self.configurations[key] = value
def __iter__(self):
return iter(self.configurations)
def values(self) -> Iterator[ProviderConfiguration]:
return self.configurations.values()
def get(self, key, default=None):
return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):
"""
Provider model bundle.
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
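A minimal sketch of the dict-like ProviderConfigurations container; in practice the configurations mapping is filled by the code that loads provider records, so it starts empty here, and the tenant id is a placeholder:

from core.entities.provider_configuration import ProviderConfigurations
from core.model_runtime.entities.model_entities import ModelType

configurations = ProviderConfigurations(tenant_id='your-tenant-id')  # placeholder tenant

# Dict-like access by provider name, plus aggregate queries across all providers.
openai_configuration = configurations.get('openai')        # None until populated
active_llms = configurations.get_models(model_type=ModelType.LLM, only_active=True)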

View File

@ -0,0 +1,67 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
class QuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
class SystemConfigurationStatus(Enum):
"""
Enum class for system configuration status.
"""
ACTIVE = 'active'
QUOTA_EXCEEDED = 'quota-exceeded'
UNSUPPORTED = 'unsupported'
class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
"""
quota_type: ProviderQuotaType
quota_unit: QuotaUnit
quota_limit: int
quota_used: int
is_valid: bool
restrict_llms: list[str] = []
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
credentials: Optional[dict] = None
class CustomProviderConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
credentials: dict
class CustomModelConfiguration(BaseModel):
"""
Model class for provider custom model configuration.
"""
model: str
model_type: ModelType
credentials: dict
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []

View File

@ -0,0 +1,118 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
class QueueEvent(Enum):
"""
QueueEvent enum
"""
MESSAGE = "message"
MESSAGE_REPLACE = "message-replace"
MESSAGE_END = "message-end"
RETRIEVER_RESOURCES = "retriever-resources"
ANNOTATION_REPLY = "annotation-reply"
AGENT_THOUGHT = "agent-thought"
ERROR = "error"
PING = "ping"
STOP = "stop"
class AppQueueEvent(BaseModel):
"""
QueueEvent entity
"""
event: QueueEvent
class QueueMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
event = QueueEvent.MESSAGE
chunk: LLMResultChunk
class QueueMessageReplaceEvent(AppQueueEvent):
"""
QueueMessageReplaceEvent entity
"""
event = QueueEvent.MESSAGE_REPLACE
text: str
class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
QueueRetrieverResourcesEvent entity
"""
event = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
class AnnotationReplyEvent(AppQueueEvent):
"""
AnnotationReplyEvent entity
"""
event = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str
class QueueMessageEndEvent(AppQueueEvent):
"""
QueueMessageEndEvent entity
"""
event = QueueEvent.MESSAGE_END
llm_result: LLMResult
class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event = QueueEvent.AGENT_THOUGHT
agent_thought_id: str
class QueueErrorEvent(AppQueueEvent):
"""
QueueErrorEvent entity
"""
event = QueueEvent.ERROR
error: Any
class QueuePingEvent(AppQueueEvent):
"""
QueuePingEvent entity
"""
event = QueueEvent.PING
class QueueStopEvent(AppQueueEvent):
"""
QueueStopEvent entity
"""
class StopBy(Enum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
event = QueueEvent.STOP
stopped_by: StopBy
class QueueMessage(BaseModel):
"""
QueueMessage entity
"""
task_id: str
message_id: str
conversation_id: str
app_mode: str
event: AppQueueEvent
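A short sketch of constructing one of the queue events above and wrapping it in the QueueMessage envelope; the ids are placeholders and the import path is an assumption:

from core.entities.queue_entities import (  # assumed module path
    QueueEvent,
    QueueMessageReplaceEvent,
    QueueMessage,
)

# The event field defaults to the matching QueueEvent member.
event = QueueMessageReplaceEvent(text='(content moderated)')
assert event.event == QueueEvent.MESSAGE_REPLACE

message = QueueMessage(
    task_id='task-1',                 # placeholder ids
    message_id='message-1',
    conversation_id='conversation-1',
    app_mode='chat',
    event=event,
)
print(message.json())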

View File

@ -14,26 +14,6 @@ class LLMBadRequestError(LLMError):
description = "Bad Request" description = "Bad Request"
class LLMAPIConnectionError(LLMError):
"""Raised when the LLM returns API connection error."""
description = "API Connection Error"
class LLMAPIUnavailableError(LLMError):
"""Raised when the LLM returns API unavailable error."""
description = "API Unavailable Error"
class LLMRateLimitError(LLMError):
"""Raised when the LLM returns rate limit error."""
description = "Rate Limit Error"
class LLMAuthorizationError(LLMError):
"""Raised when the LLM returns authorization error."""
description = "Authorization Error"
class ProviderTokenNotInitError(Exception): class ProviderTokenNotInitError(Exception):
""" """
Custom exception raised when the provider token is not initialized. Custom exception raised when the provider token is not initialized.

View File

@ -0,0 +1,35 @@
{
"label": {
"en-US": "Weather Search",
"zh-Hans": "天气查询"
},
"form_schema": [
{
"type": "select",
"label": {
"en-US": "Temperature Unit",
"zh-Hans": "温度单位"
},
"variable": "temperature_unit",
"required": true,
"options": [
{
"label": {
"en-US": "Fahrenheit",
"zh-Hans": "华氏度"
},
"value": "fahrenheit"
},
{
"label": {
"en-US": "Centigrade",
"zh-Hans": "摄氏度"
},
"value": "centigrade"
}
],
"default": "centigrade",
"placeholder": "Please select temperature unit"
}
]
}

View File

@ -0,0 +1,45 @@
from typing import Optional
from core.external_data_tool.base import ExternalDataTool
class WeatherSearch(ExternalDataTool):
"""
The name of the custom type must be unique and must match the directory and file name.
"""
name: str = "weather_search"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
schema.json validation. It will be called when the user saves the config.
Example:
.. code-block:: python
config = {
"temperature_unit": "centigrade"
}
:param tenant_id: the id of workspace
:param config: the variables of form config
:return:
"""
if not config.get('temperature_unit'):
raise ValueError('temperature unit is required')
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
city = inputs.get('city')
temperature_unit = self.config.get('temperature_unit')
if temperature_unit == 'fahrenheit':
return f'Weather in {city} is 32°F'
else:
return f'Weather in {city} is 0°C'
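A minimal sketch of exercising the example tool above. The module path and the keyword arguments accepted by the ExternalDataTool base constructor are assumptions, since the base class is not shown in this extract:

from core.external_data_tool.weather_search.weather_search import WeatherSearch  # assumed path

config = {'temperature_unit': 'fahrenheit'}

# Classmethod with the signature shown above; raises ValueError if the unit is missing.
WeatherSearch.validate_config(tenant_id='your-tenant-id', config=config)

tool = WeatherSearch(tenant_id='your-tenant-id', config=config)  # assumed constructor kwargs
print(tool.query(inputs={'city': 'London'}))  # -> "Weather in London is 32°F"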

View File

@ -0,0 +1,325 @@
import logging
from typing import cast, Optional, List
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, WikipediaQueryRun, Tool
from pydantic import BaseModel, Field
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_tools(self, tool_configs: list[AgentToolEntity],
invoke_from: InvokeFrom,
callbacks: list[BaseCallbackHandler]) \
-> Optional[List[BaseTool]]:
"""
Convert tool configs to tools
:param tool_configs: tool configs
:param invoke_from: invoke from
:param callbacks: callbacks
"""
tools = []
for tool_config in tool_configs:
tool = None
if tool_config.tool_id == "dataset":
tool = self.to_dataset_retriever_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "web_reader":
tool = self.to_web_reader_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "google_search":
tool = self.to_google_search_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "wikipedia":
tool = self.to_wikipedia_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "current_datetime":
tool = self.to_current_datetime_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
if tool:
if tool.callbacks is not None:
tool.callbacks.extend(callbacks)
else:
tool.callbacks = callbacks
tools.append(tool)
return tools
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
# pass if dataset is not available
if dataset and dataset.available_document_count == 0:
return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool
def to_web_reader_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for reading web pages
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
model_parameters = {
"temperature": 0,
"max_tokens": 500
}
tool = WebReaderTool(
model_config=self.model_config,
model_parameters=model_parameters,
max_chunk_length=4000,
continue_reading=True
)
return tool
def to_google_search_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for performing a Google search and extracting snippets and webpages
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
return None
tool = Tool(
name="google_search",
description="A tool for performing a Google search and extracting snippets and webpages "
"when you need to search for something you don't know or when your information "
"is not up to date. "
"Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
args_schema=OptimizedSerpAPIInput
)
return tool
def to_current_datetime_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for getting the current date and time
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
return DatetimeTool()
def to_wikipedia_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for searching Wikipedia
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
return WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput
)
View File
@ -0,0 +1,119 @@
import logging
from typing import Optional
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Message, AppAnnotationSetting, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_record.id).first()
if not annotation_setting:
return None
collection_binding_detail = annotation_setting.collection_binding_detail
try:
score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name
embedding_model_name = collection_binding_detail.model_name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=app_record.tenant_id,
provider=embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name
)
# get embedding model
embeddings = CacheEmbedding(model_instance)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
'annotation'
)
dataset = Dataset(
id=app_record.id,
tenant_id=app_record.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
documents = vector_index.search(
query=query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
annotation_id = documents[0].metadata['annotation_id']
score = documents[0].metadata['score']
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
from_source = 'api'
else:
from_source = 'console'
# insert annotation history
AppAnnotationService.add_annotation_history(annotation.id,
app_record.id,
annotation.question,
annotation.content,
query,
user_id,
message.id,
from_source,
score)
return annotation
except Exception as e:
logger.warning(f'Query annotation failed, exception: {str(e)}.')
return None
return None
View File
@ -0,0 +1,181 @@
from typing import cast, Optional, List
from langchain.tools import BaseTool
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetEntity, ModelConfigEntity, InvokeFrom, DatasetRetrieveConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
class DatasetRetrievalFeature:
def retrieve(self, tenant_id: str,
model_config: ModelConfigEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
"""
Retrieve dataset.
:param tenant_id: tenant id
:param model_config: model config
:param config: dataset config
:param query: query
:param invoke_from: invoke from
:param show_retrieve_source: show retrieve source
:param hit_callback: hit callback
:param memory: memory
:return:
"""
dataset_ids = config.dataset_ids
retrieve_config = config.retrieve_config
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
dataset_retriever_tools = self.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
return_resource=show_retrieve_source,
invoke_from=invoke_from,
hit_callback=hit_callback
)
if len(dataset_retriever_tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=model_config,
tools=dataset_retriever_tools,
memory=memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate"
)
agent_executor = AgentExecutor(agent_configuration)
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
def to_dataset_retriever_tool(self, tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[List[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id
:param dataset_ids: dataset ids
:param retrieve_config: retrieve config
:param return_resource: return resource
:param invoke_from: invoke from
:param hit_callback: hit callback
"""
tools = []
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
# pass if dataset is not available
if not dataset:
continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0:
continue
available_datasets.append(dataset)
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source()
)
tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=(retrieve_config.score_threshold or 0.5)
if retrieve_config.reranking_model.get('score_threshold_enabled', False) else None,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
)
tools.append(tool)
return tools
View File
@ -0,0 +1,96 @@
import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Optional
from flask import current_app, Flask
from core.entities.application_entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
logger = logging.getLogger(__name__)
class ExternalDataFetchFeature:
def fetch(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
"""
Fill in variable inputs from external data tools if any are configured.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
# Group tools by type and config
grouped_tools = {}
for tool in external_data_tools:
tool_key = (tool.type, json.dumps(tool.config, sort_keys=True))
grouped_tools.setdefault(tool_key, []).append(tool)
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tool in external_data_tools:
future = executor.submit(
self._query_external_data_tool,
current_app._get_current_object(),
tenant_id,
app_id,
tool,
inputs,
query
)
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_variable, result = future.result()
results[tool_variable] = result
inputs.update(results)
return inputs
def _query_external_data_tool(self, flask_app: Flask,
tenant_id: str,
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
query: str) -> Tuple[Optional[str], Optional[str]]:
"""
Query external data tool.
:param flask_app: flask app
:param tenant_id: tenant id
:param app_id: app id
:param external_data_tool: external data tool
:param inputs: inputs
:param query: query
:return:
"""
with flask_app.app_context():
tool_variable = external_data_tool.variable
tool_type = external_data_tool.type
tool_config = external_data_tool.config
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
return tool_variable, result
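A hedged sketch (not part of the commit) of the calling contract: fetch() fans the configured tools out to a thread pool and merges each tool's result back into the inputs under its variable name. The ExternalDataVariableEntity keyword arguments follow the fields read above (variable, type, config); the concrete values are invented.

from core.entities.application_entities import ExternalDataVariableEntity

feature = ExternalDataFetchFeature()
filled_inputs = feature.fetch(
    tenant_id="tenant-123",   # hypothetical workspace id
    app_id="app-456",         # hypothetical app id
    external_data_tools=[
        ExternalDataVariableEntity(
            variable="weather",
            type="weather_search",
            config={"temperature_unit": "centigrade"}
        )
    ],
    inputs={"city": "London"},
    query="What should I wear today?"
)
# filled_inputs now contains the original inputs plus {"weather": "<tool result>"}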
View File
@ -0,0 +1,32 @@
import logging
from core.entities.application_entities import ApplicationGenerateEntity
from core.helper import moderation
from core.model_runtime.entities.message_entities import PromptMessage
logger = logging.getLogger(__name__)
class HostingModerationFeature:
def check(self, application_generate_entity: ApplicationGenerateEntity,
prompt_messages: list[PromptMessage]) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
:param prompt_messages: prompt messages
:return:
"""
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
model_config = app_orchestration_config.model_config
text = ""
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation(
model_config,
text
)
return moderation_result
View File
@ -0,0 +1,50 @@
import logging
from typing import Tuple
from core.entities.application_entities import AppOrchestrationConfigEntity
from core.moderation.base import ModerationAction, ModerationException
from core.moderation.factory import ModerationFactory
logger = logging.getLogger(__name__)
class ModerationFeature:
def check(self, app_id: str,
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_orchestration_config_entity: app orchestration config entity
:param inputs: inputs
:param query: query
:return:
"""
if not app_orchestration_config_entity.sensitive_word_avoidance:
return False, inputs, query
sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance
moderation_type = sensitive_word_avoidance_config.type
moderation_factory = ModerationFactory(
name=moderation_type,
app_id=app_id,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config.config
)
moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
return False, inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDED:
inputs = moderation_result.inputs
query = moderation_result.query
return True, inputs, query
View File
@ -4,7 +4,7 @@ from typing import Optional
 from pydantic import BaseModel
 
 from core.file.upload_file_parser import UploadFileParser
-from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from extensions.ext_database import db
 from models.model import UploadFile

@ -50,14 +50,14 @@ class FileObj(BaseModel):
         return self._get_data(force_url=True)
 
     @property
-    def prompt_message_file(self) -> PromptMessageFile:
+    def prompt_message_content(self) -> ImagePromptMessageContent:
         if self.type == FileType.IMAGE:
             image_config = self.file_config.get('image')
 
-            return ImagePromptMessageFile(
+            return ImagePromptMessageContent(
                 data=self.data,
-                detail=ImagePromptMessageFile.DETAIL.HIGH
-                if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
+                detail=ImagePromptMessageContent.DETAIL.HIGH
+                if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
             )
 
     def _get_data(self, force_url: bool = False) -> Optional[str]:
View File
@ -3,10 +3,10 @@ import logging
 from langchain.schema import OutputParserException
 
-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, MessageType
-from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_manager import ModelManager
+from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser

@ -26,17 +26,22 @@ class LLMGenerator:
         prompt += query + "\n"
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=1,
-                max_tokens=100
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [PromptMessage(content=prompt)]
-        response = model_instance.run(prompts)
-        answer = response.content
+        prompts = [UserPromptMessage(content=prompt)]
+        response = model_instance.invoke_llm(
+            prompt_messages=prompts,
+            model_parameters={
+                "max_tokens": 100,
+                "temperature": 1
+            },
+            stream=False
+        )
+        answer = response.message.content
 
         result_dict = json.loads(answer)
         answer = result_dict['Your Output']

@ -62,22 +67,28 @@ class LLMGenerator:
         })
 
         try:
-            model_instance = ModelFactory.get_text_generation_model(
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
                 tenant_id=tenant_id,
-                model_kwargs=ModelKwargs(
-                    max_tokens=256,
-                    temperature=0
-                )
+                model_type=ModelType.LLM,
             )
-        except ProviderTokenNotInitError:
+        except InvokeAuthorizationError:
             return []
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompt_messages)
-            questions = output_parser.parse(output.content)
-        except LLMError:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 256,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            questions = output_parser.parse(response.message.content)
+        except InvokeError:
             questions = []
         except Exception as e:
             logging.exception(e)

@ -105,20 +116,26 @@ class LLMGenerator:
             remove_template_variables=False
         )
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=512,
-                temperature=0
-            )
+            model_type=ModelType.LLM,
        )
 
-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]
 
        try:
-            output = model_instance.run(prompt_messages)
-            rule_config = output_parser.parse(output.content)
-        except LLMError as e:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 512,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            rule_config = output_parser.parse(response.message.content)
+        except InvokeError as e:
             raise e
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')

@ -136,18 +153,24 @@ class LLMGenerator:
     def generate_qa_document(cls, tenant_id: str, query, document_language: str):
         prompt = GENERATOR_QA_PROMPT.format(language=document_language)
 
-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=2000
-            )
+            model_type=ModelType.LLM,
         )
 
-        prompts = [
-            PromptMessage(content=prompt, type=MessageType.SYSTEM),
-            PromptMessage(content=query)
+        prompt_messages = [
+            SystemPromptMessage(content=prompt),
+            UserPromptMessage(content=query)
         ]
 
-        response = model_instance.run(prompts)
-        answer = response.content
+        response = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={
+                "max_tokens": 2000
+            },
+            stream=False
+        )
+
+        answer = response.message.content
 
         return answer.strip()
View File
@ -18,3 +18,17 @@ def encrypt_token(tenant_id: str, token: str):
 def decrypt_token(tenant_id: str, token: str):
     return rsa.decrypt(base64.b64decode(token), tenant_id)
+
+
+def batch_decrypt_token(tenant_id: str, tokens: list[str]):
+    rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)
+
+    return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens]
+
+
+def get_decrypt_decoding(tenant_id: str):
+    return rsa.get_decrypt_decoding(tenant_id)
+
+
+def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
+    return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)
View File
@ -0,0 +1,22 @@
from collections import OrderedDict
from typing import Any
class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity
def get(self, key: Any) -> Any:
if key not in self.cache:
return None
else:
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
return self.cache[key]
def put(self, key: Any, value: Any) -> None:
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
if len(self.cache) > self.capacity:
self.cache.popitem(last=False) # pop the first item
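A quick illustration of the eviction behaviour of the LRUCache above; the keys and values are arbitrary, and a capacity of 2 keeps only the two most recently used entries.

cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")        # "a" becomes the most recently used key
cache.put("c", 3)     # evicts "b", the least recently used key
assert cache.get("b") is None
assert cache.get("a") == 1 and cache.get("c") == 3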
View File
@ -1,18 +1,27 @@
 import logging
 import random
 
-import openai
-
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.providers.base import BaseModelProvider
-from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
+from core.entities.application_entities import ModelConfigEntity
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
+from extensions.ext_hosting_provider import hosting_configuration
 from models.provider import ProviderType
 
+logger = logging.getLogger(__name__)
+
 
-def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
-    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
-        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
-                and model_provider.provider_name in hosted_config.moderation.providers:
+def check_moderation(model_config: ModelConfigEntity, text: str) -> bool:
+    moderation_config = hosting_configuration.moderation_config
+    if (moderation_config and moderation_config.enabled is True
+            and 'openai' in hosting_configuration.provider_map
+            and hosting_configuration.provider_map['openai'].enabled is True
+    ):
+        using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
+        provider_name = model_config.provider
+        if using_provider_type == ProviderType.SYSTEM \
+                and provider_name in moderation_config.providers:
+            hosting_openai_config = hosting_configuration.provider_map['openai']
+
             # 2000 text per chunk
             length = 2000
             text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

@ -23,14 +32,17 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
             text_chunk = random.choice(text_chunks)
 
             try:
-                moderation_result = openai.Moderation.create(input=text_chunk,
-                                                              api_key=hosted_model_providers.openai.api_key)
+                model_type_instance = OpenAIModerationModel()
+                moderation_result = model_type_instance.invoke(
+                    model='text-moderation-stable',
+                    credentials=hosting_openai_config.credentials,
+                    text=text_chunk
+                )
+
+                if moderation_result is True:
+                    return True
             except Exception as ex:
-                logging.exception(ex)
-                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+                logger.exception(ex)
+                raise InvokeBadRequestError('Rate limit exceeded, please try again later.')
 
-        for result in moderation_result.results:
-            if result['flagged'] is True:
-                return False
-
-        return True
+    return False
View File
@ -0,0 +1,213 @@
import os
from typing import Optional
from flask import Flask
from pydantic import BaseModel
from core.entities.provider_entities import QuotaUnit
from models.provider import ProviderQuotaType
class HostingQuota(BaseModel):
quota_type: ProviderQuotaType
restrict_llms: list[str] = []
class TrialHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
quota_limit: int = 0
"""Quota limit for the hosting provider models. -1 means unlimited."""
class PaidHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.PAID
stripe_price_id: str = None
increase_quota: int = 1
min_quantity: int = 20
max_quantity: int = 100
class FreeHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.FREE
class HostingProvider(BaseModel):
enabled: bool = False
credentials: Optional[dict] = None
quota_unit: Optional[QuotaUnit] = None
quotas: list[HostingQuota] = []
class HostedModerationConfig(BaseModel):
enabled: bool = False
providers: list[str] = []
class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None
def init_app(self, app: Flask):
if app.config.get('EDITION') != 'CLOUD':
return
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
self.provider_map["spark"] = self.init_spark()
self.provider_map["zhipuai"] = self.init_zhipuai()
self.moderation_config = self.init_moderation_config()
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
credentials = {
"openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"),
}
if os.environ.get("HOSTED_OPENAI_API_BASE"):
credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE")
if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"):
credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION")
quotas = []
hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_llms=[
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-16k",
"text-davinci-003"
]
)
quotas.append(trial_quota)
if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get(
"HOSTED_OPENAI_PAID_ENABLED").lower() == 'true':
paid_quota = PaidHostingQuota(
stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
)
quotas.append(paid_quota)
return HostingProvider(
enabled=True,
credentials=credentials,
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_anthropic(self) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true':
credentials = {
"anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"),
}
if os.environ.get("HOSTED_ANTHROPIC_API_BASE"):
credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE")
quotas = []
hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit
)
quotas.append(trial_quota)
if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get(
"HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true':
paid_quota = PaidHostingQuota(
stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
)
quotas.append(paid_quota)
return HostingProvider(
enabled=True,
credentials=credentials,
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_minimax(self) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true':
quotas = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
credentials=None, # use credentials from the provider
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_spark(self) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true':
quotas = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
credentials=None, # use credentials from the provider
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_zhipuai(self) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true':
quotas = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
credentials=None, # use credentials from the provider
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_moderation_config(self) -> HostedModerationConfig:
if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \
and os.environ.get("HOSTED_MODERATION_PROVIDERS"):
return HostedModerationConfig(
enabled=True,
providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',')
)
return HostedModerationConfig(
enabled=False
)
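A hedged sketch (not part of the commit) of how the configuration above is driven entirely by environment variables; the variable names come from the init_* methods, while the Flask app and key values are placeholders.

import os
from flask import Flask

os.environ["HOSTED_OPENAI_ENABLED"] = "true"
os.environ["HOSTED_OPENAI_API_KEY"] = "sk-placeholder"         # placeholder, not a real key
os.environ["HOSTED_OPENAI_QUOTA_LIMIT"] = "200"
os.environ["HOSTED_MODERATION_ENABLED"] = "true"
os.environ["HOSTED_MODERATION_PROVIDERS"] = "openai"

app = Flask(__name__)
app.config["EDITION"] = "CLOUD"        # init_app is a no-op for any other edition

hosting_configuration = HostingConfiguration()
hosting_configuration.init_app(app)

openai_hosting = hosting_configuration.provider_map["openai"]
print(openai_hosting.enabled)                                   # True
print([q.quota_type for q in openai_hosting.quotas])            # [ProviderQuotaType.TRIAL]
print(hosting_configuration.moderation_config.providers)        # ['openai']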
View File
@ -1,18 +1,12 @@
-import json
-
 from flask import current_app
 from langchain.embeddings import OpenAIEmbeddings
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargs
-from core.model_providers.models.llm.openai_model import OpenAIModel
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from models.dataset import Dataset
-from models.provider import Provider, ProviderType
 
 
 class IndexBuilder:

@ -22,10 +16,12 @@ class IndexBuilder:
         if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
             return None
 
-        embedding_model = ModelFactory.get_embedding_model(
+        model_manager = ModelManager()
+        embedding_model = model_manager.get_model_instance(
             tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
+            model_type=ModelType.TEXT_EMBEDDING,
+            provider=dataset.embedding_model_provider,
+            model=dataset.embedding_model
         )
 
         embeddings = CacheEmbedding(embedding_model)
View File
@ -18,9 +18,11 @@ from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.model_providers.error import ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import MessageType
+from core.model_manager import ModelManager
+from core.errors.error import ProviderTokenNotInitError
+from core.model_runtime.entities.model_entities import ModelType, PriceType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client

@ -36,6 +38,7 @@ class IndexingRunner:

     def __init__(self):
         self.storage = storage
+        self.model_manager = ModelManager()
 
     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
@ -210,7 +213,7 @@ class IndexingRunner:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
embedding_model = None embedding_model_instance = None
if dataset_id: if dataset_id:
dataset = Dataset.query.filter_by( dataset = Dataset.query.filter_by(
id=dataset_id id=dataset_id
@ -218,15 +221,17 @@ class IndexingRunner:
if not dataset: if not dataset:
raise ValueError('Dataset not found.') raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id, tenant_id=tenant_id,
model_provider_name=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
) )
else: else:
if indexing_technique == 'high_quality': if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
) )
tokens = 0 tokens = 0
preview_texts = [] preview_texts = []
@ -255,32 +260,56 @@ class IndexingRunner:
for document in documents: for document in documents:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(document.page_content) preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model: if indexing_technique == 'high_quality' or embedding_model_instance:
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content)) embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[self.filter_string(document.page_content)]
)
if doc_form and doc_form == 'qa_model': if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model( model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id tenant_id=tenant_id,
model_type=ModelType.LLM
) )
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0: if len(preview_texts) > 0:
# qa model document # qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language) doc_language)
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return { return {
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(price_info.total_amount),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), "currency": price_info.currency,
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts
} }
if embedding_model_instance:
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
return { return {
"total_segments": total_segments, "total_segments": total_segments,
"tokens": tokens, "tokens": tokens,
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0, "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
"currency": embedding_model.get_currency() if embedding_model else 'USD', "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
"preview": preview_texts "preview": preview_texts
} }
@ -290,7 +319,7 @@ class IndexingRunner:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
embedding_model = None embedding_model_instance = None
if dataset_id: if dataset_id:
dataset = Dataset.query.filter_by( dataset = Dataset.query.filter_by(
id=dataset_id id=dataset_id
@ -298,15 +327,17 @@ class IndexingRunner:
if not dataset: if not dataset:
raise ValueError('Dataset not found.') raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id, tenant_id=tenant_id,
model_provider_name=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
) )
else: else:
if indexing_technique == 'high_quality': if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING
) )
# load data from notion # load data from notion
tokens = 0 tokens = 0
@ -349,35 +380,63 @@ class IndexingRunner:
processing_rule=processing_rule processing_rule=processing_rule
) )
total_segments += len(documents) total_segments += len(documents)
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
for document in documents: for document in documents:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(document.page_content) preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model: if indexing_technique == 'high_quality' or embedding_model_instance:
tokens += embedding_model.get_num_tokens(document.page_content) tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[document.page_content]
)
if doc_form and doc_form == 'qa_model': if doc_form and doc_form == 'qa_model':
text_generation_model = ModelFactory.get_text_generation_model( model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id tenant_id=tenant_id,
model_type=ModelType.LLM
) )
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0: if len(preview_texts) > 0:
# qa model document # qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language) doc_language)
document_qa_list = self.format_split_text(response) document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return { return {
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(price_info.total_amount),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), "currency": price_info.currency,
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts
} }
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
return { return {
"total_segments": total_segments, "total_segments": total_segments,
"tokens": tokens, "tokens": tokens,
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0, "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
"currency": embedding_model.get_currency() if embedding_model else 'USD', "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
"preview": preview_texts "preview": preview_texts
} }
@ -656,25 +715,36 @@ class IndexingRunner:
""" """
vector_index = IndexBuilder.get_index(dataset, 'high_quality') vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy') keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = None embedding_model_instance = None
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
) )
# chunk nodes by chunk size # chunk nodes by chunk size
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
chunk_size = 100 chunk_size = 100
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
for i in range(0, len(documents), chunk_size): for i in range(0, len(documents), chunk_size):
# check document is paused # check document is paused
self._check_document_paused_status(dataset_document.id) self._check_document_paused_status(dataset_document.id)
chunk_documents = documents[i:i + chunk_size] chunk_documents = documents[i:i + chunk_size]
if dataset.indexing_technique == 'high_quality' or embedding_model: if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
tokens += sum( tokens += sum(
embedding_model.get_num_tokens(document.page_content) embedding_model_type_instance.get_num_tokens(
embedding_model_instance.model,
embedding_model_instance.credentials,
[document.page_content]
)
for document in chunk_documents for document in chunk_documents
) )

View File
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage
from core.file.message_file_parser import MessageFileParser
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
model_instance: BaseLLM
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@property
def buffer(self) -> List[BaseMessage]:
"""String buffer of memory."""
app_model = self.conversation.app
# fetch limited messages desc, and return reversed
messages = db.session.query(Message).filter(
Message.conversation_id == self.conversation.id,
Message.answer != ''
).order_by(Message.created_at.desc()).limit(self.message_limit).all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
chat_messages: List[PromptMessage] = []
for message in messages:
files = message.message_files
if files:
file_objs = message_file_parser.transform_message_files(
files, message.app_model_config
)
prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
chat_messages.append(PromptMessage(
content=message.query,
type=MessageType.USER,
files=prompt_message_files
))
else:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:
return []
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
return to_lc_messages(chat_messages)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
buffer: Any = self.buffer
if self.return_messages:
final_buffer: Any = buffer
else:
final_buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: final_buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass
View File
@ -1,36 +0,0 @@
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
memory: ReadOnlyConversationTokenDBBufferSharedMemory
@property
def memory_variables(self) -> List[str]:
"""Return memory variables."""
return self.memory.memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Load memory variables from memory."""
buffer: Any = self.memory.buffer
final_buffer = get_buffer_string(
buffer,
human_prefix=self.memory.human_prefix,
ai_prefix=self.memory.ai_prefix,
)
return {self.memory.memory_key: final_buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass
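This wrapper only flattens the shared memory's message objects into a plain transcript via LangChain's get_buffer_string, using the prefixes carried by the wrapped memory. A hedged illustration of that flattening (the example contents are made up, not taken from this diff):
from langchain.schema import AIMessage, HumanMessage, get_buffer_string
transcript = get_buffer_string(
    [HumanMessage(content="Hi"), AIMessage(content="Hello!")],
    human_prefix="Human",
    ai_prefix="Assistant",
)
# transcript == "Human: Hi\nAssistant: Hello!"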

View File

@ -0,0 +1,109 @@
from core.file.message_file_parser import MessageFileParser
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, TextPromptMessageContent, UserPromptMessage, \
AssistantPromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import model_provider_factory
from extensions.ext_database import db
from models.model import Conversation, Message
class TokenBufferMemory:
def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
self.conversation = conversation
self.model_instance = model_instance
def get_history_prompt_messages(self, max_token_limit: int = 2000,
message_limit: int = 10) -> list[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
:param message_limit: message limit
"""
app_record = self.conversation.app
# fetch the most recent messages in descending order, then reverse back to chronological order
messages = db.session.query(Message).filter(
Message.conversation_id == self.conversation.id,
Message.answer != ''
).order_by(Message.created_at.desc()).limit(message_limit).all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(
tenant_id=app_record.tenant_id,
app_id=app_record.id
)
prompt_messages = []
for message in messages:
files = message.message_files
if files:
file_objs = message_file_parser.transform_message_files(
files, message.app_model_config
)
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=message.query))
prompt_messages.append(AssistantPromptMessage(content=message.answer))
if not prompt_messages:
return []
# prune the chat messages if they exceed the max token limit
provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
curr_message_tokens = model_type_instance.get_num_tokens(
self.model_instance.model,
self.model_instance.credentials,
prompt_messages
)
if curr_message_tokens > max_token_limit:
pruned_memory = []
while curr_message_tokens > max_token_limit and prompt_messages:
pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = model_type_instance.get_num_tokens(
self.model_instance.model,
self.model_instance.credentials,
prompt_messages
)
return prompt_messages
def get_history_prompt_text(self, human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int = 10) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix
:param ai_prefix: ai prefix
:param max_token_limit: max token limit
:param message_limit: message limit
:return:
"""
prompt_messages = self.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit
)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

api/core/model_manager.py Normal file (209 lines added)
View File

@ -0,0 +1,209 @@
from typing import Optional, Union, Generator, cast, List, IO
from core.entities.provider_configuration import ProviderModelBundle
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.provider_manager import ProviderManager
class ModelInstance:
"""
Model instance class
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
self._provider_model_bundle = provider_model_bundle
self.model = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self._provider_model_bundle.model_type_instance
def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
"""
Fetch credentials from provider model bundle
:param provider_model_bundle: provider model bundle
:param model: model name
:return:
"""
credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=provider_model_bundle.model_type_instance.model_type,
model=model
)
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
return credentials
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception(f"Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception(f"Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
texts=texts,
user=user
)
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception(f"Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user
)
def invoke_moderation(self, text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception(f"Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
text=text,
user=user
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke speech-to-text model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception(f"Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
file=file,
user=user
)
class ModelManager:
def __init__(self) -> None:
self._provider_manager = ProviderManager()
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
"""
Get model instance
:param tenant_id: tenant id
:param provider: provider name
:param model_type: model type
:param model: model name
:return:
"""
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=provider,
model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
:param tenant_id: tenant id
:param model_type: model type
:return:
"""
default_model_entity = self._provider_manager.get_default_model(
tenant_id=tenant_id,
model_type=model_type
)
if not default_model_entity:
raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
return self.get_model_instance(
tenant_id=tenant_id,
provider=default_model_entity.provider.provider,
model_type=model_type,
model=default_model_entity.model
)
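A rough invocation sketch for ModelManager and ModelInstance (tenant id, provider, model name and model parameters are placeholders; the exact fields of LLMResult are not shown in this excerpt):
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
model_instance = ModelManager().get_model_instance(
    tenant_id="tenant-id",
    provider="openai",
    model_type=ModelType.LLM,
    model="gpt-3.5-turbo",
)
result = model_instance.invoke_llm(
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7, "max_tokens": 256},
    stream=False,
)
# With stream=False this returns an LLMResult; with stream=True it returns a
# generator of response chunks instead.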

View File

@ -1,335 +0,0 @@
from typing import Optional
from langchain.callbacks.base import Callbacks
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
class ModelFactory:
@classmethod
def get_text_generation_model_from_model_config(cls, tenant_id: str,
model_config: dict,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
provider_name = model_config.get("provider")
model_name = model_config.get("name")
completion_params = model_config.get("completion_params", {})
return cls.get_text_generation_model(
tenant_id=tenant_id,
model_provider_name=provider_name,
model_name=model_name,
model_kwargs=ModelKwargs(
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1)
),
streaming=streaming,
callbacks=callbacks
)
@classmethod
def get_text_generation_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None,
model_kwargs: Optional[ModelKwargs] = None,
streaming: bool = False,
callbacks: Callbacks = None,
deduct_quota: bool = True) -> Optional[BaseLLM]:
"""
get text generation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:param model_kwargs:
:param streaming:
:param callbacks:
:param deduct_quota:
:return:
"""
is_default_model = False
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default System Reasoning Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
is_default_model = True
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init text generation model
model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
try:
model_instance = model_class(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming,
callbacks=callbacks
)
except LLMBadRequestError as e:
if is_default_model:
raise LLMBadRequestError(f"Default model {model_name} is not available. "
f"Please check your model provider credentials.")
else:
raise e
if is_default_model or not deduct_quota:
model_instance.deduct_quota = False
return model_instance
@classmethod
def get_embedding_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
"""
get embedding model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Embedding Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init embedding model
model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_reranking_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseReranking]:
"""
get reranking model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if (model_provider_name is None or len(model_provider_name) == 0) and (model_name is None or len(model_name) == 0):
default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Reranking Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init reranking model
model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_speech2text_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
"""
get speech to text model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Speech-to-Text Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init speech to text model
model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseModeration]:
"""
get moderation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init moderation model
model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
"""
get default model of model type.
:param tenant_id:
:param model_type:
:return:
"""
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if not default_model:
model_provider_rules = ModelProviderFactory.get_provider_rules()
for model_provider_name, model_provider_rule in model_provider_rules.items():
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
continue
model_list = model_provider.get_supported_model_list(model_type)
if model_list:
model_info = model_list[0]
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=model_provider_name,
model_name=model_info['id']
)
db.session.add(default_model)
db.session.commit()
break
return default_model
@classmethod
def update_default_model(cls,
tenant_id: str,
model_type: ModelType,
provider_name: str,
model_name: str) -> TenantDefaultModel:
"""
update default model of model type.
:param tenant_id:
:param model_type:
:param provider_name:
:param model_name:
:return:
"""
model_provider_name = ModelProviderFactory.get_provider_names()
if provider_name not in model_provider_name:
raise ValueError(f'Invalid provider name: {provider_name}')
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
model_list = model_provider.get_supported_model_list(model_type)
model_ids = [model['id'] for model in model_list]
if model_name not in model_ids:
raise ValueError(f'Invalid model name: {model_name}')
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if default_model:
# update default model
default_model.provider_name = provider_name
default_model.model_name = model_name
db.session.commit()
else:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=provider_name,
model_name=model_name,
)
db.session.add(default_model)
db.session.commit()
return default_model
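For contrast with the new ModelManager introduced above, the legacy call pattern removed here looked roughly like this (the module path and tenant id are illustrative; the ModelKwargs fields follow the deleted code above):
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs
llm = ModelFactory.get_text_generation_model(
    tenant_id="tenant-id",
    model_provider_name="openai",
    model_name="gpt-3.5-turbo",
    model_kwargs=ModelKwargs(temperature=0, max_tokens=256),
    streaming=False,
)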

View File

@ -1,276 +0,0 @@
from typing import Type
from sqlalchemy.exc import IntegrityError
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
DEFAULT_MODELS = {
ModelType.TEXT_GENERATION.value: {
'provider_name': 'openai',
'model_name': 'gpt-3.5-turbo',
},
ModelType.EMBEDDINGS.value: {
'provider_name': 'openai',
'model_name': 'text-embedding-ada-002',
},
ModelType.SPEECH_TO_TEXT.value: {
'provider_name': 'openai',
'model_name': 'whisper-1',
}
}
class ModelProviderFactory:
@classmethod
def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
if provider_name == 'openai':
from core.model_providers.providers.openai_provider import OpenAIProvider
return OpenAIProvider
elif provider_name == 'anthropic':
from core.model_providers.providers.anthropic_provider import AnthropicProvider
return AnthropicProvider
elif provider_name == 'minimax':
from core.model_providers.providers.minimax_provider import MinimaxProvider
return MinimaxProvider
elif provider_name == 'spark':
from core.model_providers.providers.spark_provider import SparkProvider
return SparkProvider
elif provider_name == 'tongyi':
from core.model_providers.providers.tongyi_provider import TongyiProvider
return TongyiProvider
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'zhipuai':
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
return ZhipuAIProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider
elif provider_name == 'baichuan':
from core.model_providers.providers.baichuan_provider import BaichuanProvider
return BaichuanProvider
elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider
elif provider_name == 'replicate':
from core.model_providers.providers.replicate_provider import ReplicateProvider
return ReplicateProvider
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
elif provider_name == 'xinference':
from core.model_providers.providers.xinference_provider import XinferenceProvider
return XinferenceProvider
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
elif provider_name == 'localai':
from core.model_providers.providers.localai_provider import LocalAIProvider
return LocalAIProvider
elif provider_name == 'cohere':
from core.model_providers.providers.cohere_provider import CohereProvider
return CohereProvider
elif provider_name == 'jina':
from core.model_providers.providers.jina_provider import JinaProvider
return JinaProvider
else:
raise NotImplementedError
@classmethod
def get_provider_names(cls):
"""
Returns a list of provider names.
"""
return list(provider_rules.keys())
@classmethod
def get_provider_rules(cls):
"""
Returns a list of provider rules.
:return:
"""
return provider_rules
@classmethod
def get_provider_rule(cls, provider_name: str):
"""
Returns provider rule.
"""
return provider_rules[provider_name]
@classmethod
def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred model provider.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:return:
"""
# get preferred provider
preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
if not preferred_provider or not preferred_provider.is_valid:
return None
# init model provider
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
return model_provider_class(provider=preferred_provider)
@classmethod
def get_preferred_type_by_preferred_model_provider(cls,
tenant_id: str,
model_provider_name: str,
preferred_model_provider: TenantPreferredModelProvider):
"""
get preferred provider type by preferred model provider.
:param model_provider_name:
:param preferred_model_provider:
:return:
"""
if not preferred_model_provider:
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
support_provider_types = model_provider_rules['support_provider_types']
if ProviderType.CUSTOM.value in support_provider_types:
custom_provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.is_valid == True
).first()
if custom_provider:
return ProviderType.CUSTOM.value
model_provider = cls.get_model_provider_class(model_provider_name)
if ProviderType.SYSTEM.value in support_provider_types \
and model_provider.is_provider_type_system_supported():
return ProviderType.SYSTEM.value
elif ProviderType.CUSTOM.value in support_provider_types:
return ProviderType.CUSTOM.value
else:
return preferred_model_provider.preferred_provider_type
@classmethod
def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
# get preferred provider type
preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
# get providers by preferred provider type
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == preferred_provider_type
).all()
no_system_provider = False
if preferred_provider_type == ProviderType.SYSTEM.value:
quota_type_to_provider_dict = {}
for provider in providers:
quota_type_to_provider_dict[provider.quota_type] = provider
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
for quota_type_enum in ProviderQuotaType:
quota_type = quota_type_enum.value
if quota_type in model_provider_rules['system_config']['supported_quota_types']:
if quota_type in quota_type_to_provider_dict.keys():
provider = quota_type_to_provider_dict[quota_type]
if provider.is_valid and provider.quota_limit > provider.quota_used:
return provider
elif quota_type == ProviderQuotaType.TRIAL.value:
try:
provider = Provider(
tenant_id=tenant_id,
provider_name=model_provider_name,
provider_type=ProviderType.SYSTEM.value,
is_valid=True,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=model_provider_rules['system_config']['quota_limit'],
quota_used=0
)
db.session.add(provider)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value
).first()
if provider.quota_limit == 0:
return None
return provider
no_system_provider = True
if no_system_provider:
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
if providers:
return providers[0]
else:
try:
provider = Provider(
tenant_id=tenant_id,
provider_name=model_provider_name,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(provider)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
return provider
return None
@classmethod
def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider type of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name == model_provider_name
).first()
return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
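The provider-resolution logic above is long; a condensed, database-free sketch of its system-quota branch follows (the dict and quota-type list are stand-ins for the real Provider rows and provider rules, not this module's API):
from typing import Optional
def pick_system_provider(providers_by_quota_type: dict,
                         supported_quota_types: list) -> Optional[object]:
    # Walk quota types in the order declared by the provider rule and return the
    # first valid provider that still has quota left; in the real code a trial
    # provider row is created lazily when it is missing.
    for quota_type in supported_quota_types:
        provider = providers_by_quota_type.get(quota_type)
        if provider and provider.is_valid and provider.quota_limit > provider.quota_used:
            return provider
    return None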

Some files were not shown because too many files have changed in this diff.