Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 03:32:23 +08:00)

Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>

Parent: e91dd28a76
Commit: d069c668f8

.github/workflows/api-model-runtime-tests.yml (new file, 58 additions)

@@ -0,0 +1,58 @@
name: Run Pytest

on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - deploy/dev
      - feat/model-runtime

jobs:
  test:
    runs-on: ubuntu-latest

    env:
      OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
      AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
      AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
      ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
      CHATGLM_API_BASE: http://a.abc.com:11451
      XINFERENCE_SERVER_URL: http://a.abc.com:11451
      XINFERENCE_GENERATION_MODEL_UID: generate
      XINFERENCE_CHAT_MODEL_UID: chat
      XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
      XINFERENCE_RERANK_MODEL_UID: rerank
      GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
      HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
      HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
      HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
      HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
      MOCK_SWITCH: true

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'

      - name: Cache pip dependencies
        uses: actions/cache@v2
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
          restore-keys: ${{ runner.os }}-pip-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          pip install -r api/requirements.txt

      - name: Run pytest
        run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
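
The MOCK_SWITCH flag and the obviously fake keys above suggest the provider calls are stubbed in CI. Assuming that flag behaves the same locally, a rough way to reproduce one of these jobs from the repository root might look like this minimal sketch (the chosen test subset and the dummy key value are illustrative only):

# Hypothetical local reproduction of the CI job above; the env var names come
# from the workflow, the dummy key value is just a placeholder.
import os
import subprocess

env = dict(os.environ)
env["MOCK_SWITCH"] = "true"               # ask the test suite to stub provider calls
env["OPENAI_API_KEY"] = "sk-placeholder"  # mocked, so any non-empty value should do

subprocess.run(
    ["pytest", "api/tests/integration_tests/model_runtime/openai"],
    env=env,
    check=True,
)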

.github/workflows/api-unit-tests.yml (deleted file, 38 deletions)

@@ -1,38 +0,0 @@
name: Run Pytest

on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - deploy/dev

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'

      - name: Cache pip dependencies
        uses: actions/cache@v2
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
          restore-keys: ${{ runner.os }}-pip-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          pip install -r api/requirements.txt

      - name: Run pytest
        run: pytest api/tests/unit_tests

@@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull request?
 
 Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
 
+### Provider Integrations
+
+If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.
+
+
 ### i18n (Internationalization) Support
 
 We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.

api/.vscode/launch.json (15 additions)

@@ -4,6 +4,21 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
+        {
+            "name": "Python: Celery",
+            "type": "python",
+            "request": "launch",
+            "module": "celery",
+            "justMyCode": true,
+            "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+            "envFile": "${workspaceFolder}/.env",
+            "env": {
+                "FLASK_APP": "app.py",
+                "FLASK_DEBUG": "1",
+                "GEVENT_SUPPORT": "True"
+            },
+            "console": "integratedTerminal"
+        },
         {
             "name": "Python: Flask",
             "type": "python",
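
For reference, the launch configuration above corresponds roughly to starting the worker by hand with the same arguments; a minimal sketch, assuming it is run from the api/ directory and that the variables from .env are already exported some other way:

# Rough equivalent of the "Python: Celery" launch config above.
import subprocess

subprocess.run(
    [
        "celery", "-A", "app.celery", "worker",
        "-P", "gevent", "-c", "1",
        "--loglevel", "info",
        "-Q", "dataset,generation,mail",
    ],
    check=True,
)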

@@ -34,9 +34,6 @@ RUN apt-get update \
 COPY --from=base /pkg /usr/local
 COPY . /app/api/
 
-RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
-ENV TRANSFORMERS_OFFLINE true
-
 COPY docker/entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh

api/app.py (34 changed lines)

@@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized
 if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
     from gevent import monkey
     monkey.patch_all()
-    if os.environ.get("VECTOR_STORE") == 'milvus':
-        import grpc.experimental.gevent
-        grpc.experimental.gevent.init_gevent()
+    # if os.environ.get("VECTOR_STORE") == 'milvus':
+    import grpc.experimental.gevent
+    grpc.experimental.gevent.init_gevent()
 
+import langchain
+langchain.verbose = True
+
 import time
 import logging

@@ -18,9 +21,8 @@ import threading
 from flask import Flask, request, Response
 from flask_cors import CORS
 
-from core.model_providers.providers import hosted
 from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail, ext_code_based_extension
+    ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 

@@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
     register_blueprints(app)
     register_commands(app)
 
-    hosted.init_app(app)
-
     return app
 
 

@@ -95,6 +95,7 @@ def initialize_extensions(app):
     ext_celery.init_app(app)
     ext_login.init_app(app)
     ext_mail.init_app(app)
+    ext_hosting_provider.init_app(app)
     ext_sentry.init_app(app)
 
 

@@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
     if request.blueprint == 'console':
         # Check if the user_id contains a dot, indicating the old format
         auth_header = request.headers.get('Authorization', '')
-        if ' ' not in auth_header:
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-        auth_scheme, auth_token = auth_header.split(None, 1)
-        auth_scheme = auth_scheme.lower()
-        if auth_scheme != 'bearer':
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if not auth_header:
+            auth_token = request.args.get('_token')
+            if not auth_token:
+                raise Unauthorized('Invalid Authorization token.')
+        else:
+            if ' ' not in auth_header:
+                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+            auth_scheme, auth_token = auth_header.split(None, 1)
+            auth_scheme = auth_scheme.lower()
+            if auth_scheme != 'bearer':
+                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
 
         decoded = PassportService().verify(auth_token)
         user_id = decoded.get('user_id')
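
With the api/app.py change above, a console request can apparently authenticate either with a Bearer header or, when no header is present at all, with a `_token` query parameter. A client-side sketch for illustration only; the endpoint path, port, and token value are placeholders, not confirmed by this diff:

# Two ways a console client could pass the same token after the change above.
import requests

token = "<console-jwt>"                              # placeholder
base = "http://localhost:5001/console/api/apps"      # placeholder endpoint

# 1) Standard Authorization header
r1 = requests.get(base, headers={"Authorization": f"Bearer {token}"})

# 2) Fallback query parameter, accepted only when the header is absent
r2 = requests.get(base, params={"_token": token})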

@@ -12,16 +12,12 @@ import qdrant_client
 from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
 from tqdm import tqdm
 from flask import current_app, Flask
-from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.index import IndexBuilder
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelType
-from core.model_providers.providers.hosted import hosted_model_providers
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db

@@ -327,6 +323,8 @@ def create_qdrant_indexes():
         except NotFound:
             break
 
+        model_manager = ModelManager()
+
         page += 1
         for dataset in datasets:
             if dataset.index_struct_dict:

@@ -334,19 +332,23 @@ def create_qdrant_indexes():
             try:
                 click.echo('Create dataset qdrant index: {}'.format(dataset.id))
                 try:
-                    embedding_model = ModelFactory.get_embedding_model(
+                    embedding_model = model_manager.get_model_instance(
                         tenant_id=dataset.tenant_id,
-                        model_provider_name=dataset.embedding_model_provider,
-                        model_name=dataset.embedding_model
+                        provider=dataset.embedding_model_provider,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=dataset.embedding_model
                     )
                 except Exception:
                     try:
-                        embedding_model = ModelFactory.get_embedding_model(
-                            tenant_id=dataset.tenant_id
+                        embedding_model = model_manager.get_default_model_instance(
+                            tenant_id=dataset.tenant_id,
+                            model_type=ModelType.TEXT_EMBEDDING,
                         )
-                        dataset.embedding_model = embedding_model.name
-                        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+                        dataset.embedding_model = embedding_model.model
+                        dataset.embedding_model_provider = embedding_model.provider
                     except Exception:
                         provider = Provider(
                             id='provider_id',
                             tenant_id=dataset.tenant_id,
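
The hunks above swap ModelFactory for the new ModelManager. Based only on the calls visible in this diff, the new lookup pattern looks roughly like the following sketch; the tenant id, provider name, and model name are placeholders, and the return type is not shown in the diff:

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

model_manager = ModelManager()

# Explicit provider + model, as stored on a dataset record
embedding_model = model_manager.get_model_instance(
    tenant_id="tenant-id",                  # placeholder
    provider="openai",                      # placeholder
    model_type=ModelType.TEXT_EMBEDDING,
    model="text-embedding-ada-002",         # placeholder
)

# Workspace default for a model type, used as the fallback path above
default_embedding = model_manager.get_default_model_instance(
    tenant_id="tenant-id",
    model_type=ModelType.TEXT_EMBEDDING,
)
print(default_embedding.provider, default_embedding.model)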

@@ -87,7 +87,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.3.34"
+        self.CURRENT_VERSION = "0.4.0"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')

@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
 
 # Import workspace controllers
-from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
+from .workspace import workspace, members, model_providers, account, tool_providers, models
 
 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

@@ -4,6 +4,10 @@ import logging
 from datetime import datetime
 
 from flask_login import current_user
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal_with, abort, inputs
 from werkzeug.exceptions import Forbidden

@@ -13,9 +17,7 @@ from controllers.console import api
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
 from events.app_event import app_was_created, app_was_deleted
 from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
     app_detail_fields_with_site

@@ -73,39 +75,41 @@ class AppListApi(Resource):
             raise Forbidden()
 
         try:
-            default_model = ModelFactory.get_text_generation_model(
-                tenant_id=current_user.current_tenant_id
+            provider_manager = ProviderManager()
+            default_model_entity = provider_manager.get_default_model(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.LLM
             )
         except (ProviderTokenNotInitError, LLMBadRequestError):
-            default_model = None
+            default_model_entity = None
         except Exception as e:
             logging.exception(e)
-            default_model = None
+            default_model_entity = None
 
         if args['model_config'] is not None:
             # validate config
             model_config_dict = args['model_config']
 
             # get model provider
-            model_provider = ModelProviderFactory.get_preferred_model_provider(
-                current_user.current_tenant_id,
-                model_config_dict["model"]["provider"]
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.LLM
             )
 
-            if not model_provider:
-                if not default_model:
-                    raise ProviderNotInitializeError(
-                        f"No Default System Reasoning Model available. Please configure "
-                        f"in the Settings -> Model Provider.")
-                else:
-                    model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
-                    model_config_dict["model"]["name"] = default_model.name
+            if not model_instance:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+            else:
+                model_config_dict["model"]["provider"] = model_instance.provider
+                model_config_dict["model"]["name"] = model_instance.model
 
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
                 account=current_user,
                 config=model_config_dict,
-                mode=args['mode']
+                app_mode=args['mode']
             )
 
             app = App(

@@ -129,21 +133,27 @@ class AppListApi(Resource):
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
             # get model provider
-            model_provider = ModelProviderFactory.get_preferred_model_provider(
-                current_user.current_tenant_id,
-                app_model_config.model_dict["provider"]
-            )
+            model_manager = ModelManager()
 
-            if not model_provider:
-                if not default_model:
-                    raise ProviderNotInitializeError(
-                        f"No Default System Reasoning Model available. Please configure "
-                        f"in the Settings -> Model Provider.")
-                else:
-                    model_dict = app_model_config.model_dict
-                    model_dict['provider'] = default_model.model_provider.provider_name
-                    model_dict['name'] = default_model.name
-                    app_model_config.model = json.dumps(model_dict)
+            try:
+                model_instance = model_manager.get_default_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    model_type=ModelType.LLM
+                )
+            except ProviderTokenNotInitError:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+
+            if not model_instance:
+                raise ProviderNotInitializeError(
+                    f"No Default System Reasoning Model available. Please configure "
+                    f"in the Settings -> Model Provider.")
+            else:
+                model_dict = app_model_config.model_dict
+                model_dict['provider'] = model_instance.provider
+                model_dict['name'] = model_instance.model
+                app_model_config.model = json.dumps(model_dict)
 
         app.name = args['name']
         app.mode = args['mode']
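
Taken together, the controller changes above first look up the workspace's default LLM through ProviderManager, then resolve a concrete instance through ModelManager and copy its provider and model name into the app's model config. A condensed sketch of that flow, with placeholder identifiers and without the surrounding Flask plumbing:

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

tenant_id = "tenant-id"          # placeholder
model_config = {"model": {}}     # placeholder app model config

# Workspace-level default LLM entity (may be None if nothing is configured)
default_model_entity = ProviderManager().get_default_model(
    tenant_id=tenant_id,
    model_type=ModelType.LLM,
)

# Concrete default model instance used to fill in the app's config
model_instance = ModelManager().get_default_model_instance(
    tenant_id=tenant_id,
    model_type=ModelType.LLM,
)

if model_instance:
    model_config["model"]["provider"] = model_instance.provider
    model_config["model"]["name"] = model_instance.model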

@@ -2,6 +2,8 @@
 import logging
 
 from flask import request
+
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from werkzeug.exceptions import InternalServerError
 

@@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

@@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
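
The pattern repeated across these controllers replaces the old tuple of provider-specific LLM* exceptions with the single InvokeError base class from the new model runtime. A minimal sketch of the consolidated handler, with call_model() standing in for whatever service call each controller actually makes:

from controllers.console.app.error import CompletionRequestError, \
    ProviderModelCurrentlyNotSupportError, ProviderQuotaExceededError
from core.errors.error import ModelCurrentlyNotSupportError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError


def handle(call_model):
    """Run a model call and map runtime errors to the controller-level HTTP errors."""
    try:
        return call_model()
    except QuotaExceededError:
        raise ProviderQuotaExceededError()
    except ModelCurrentlyNotSupportError:
        raise ProviderModelCurrentlyNotSupportError()
    except InvokeError as e:  # single base class replaces the old LLM* tuple
        raise CompletionRequestError(str(e))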

@@ -5,6 +5,10 @@ from typing import Generator, Union
 
 import flask_login
 from flask import Response, stream_with_context
+
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from werkzeug.exceptions import InternalServerError, NotFound
 

@@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, \
     ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse

@@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
                 app_model=app_model,
                 user=account,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.DEBUGGER,
                 streaming=streaming,
                 is_model_config_override=True
             )

@@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource):
 
         account = flask_login.current_user
 
-        PubHandler.stop(account, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
         return {'result': 'success'}, 200
 

@@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
                 app_model=app_model,
                 user=account,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.DEBUGGER,
                 streaming=streaming,
                 is_model_config_override=True
             )

@@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
-                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
+            except InvokeError as e:
+                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
             except Exception:

@@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource):
 
         account = flask_login.current_user
 
-        PubHandler.stop(account, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
         return {'result': 'success'}, 200
 
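
Stopping a running generation also changes shape above: instead of PubHandler.stop(user, task_id), the controllers now record a stop flag keyed by task id, invocation source, and user id. A sketch based only on the calls visible in the diff; the task and user ids are placeholders:

from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom

ApplicationQueueManager.set_stop_flag(
    "task-id",              # the streaming task to interrupt (placeholder)
    InvokeFrom.DEBUGGER,    # or InvokeFrom.EXPLORE, depending on the surface
    "user-id",              # placeholder
)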

@@ -1,4 +1,6 @@
 from flask_login import current_user
+
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from flask_restful import Resource, reqparse
 

@@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
-    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 
 
 class RuleGenerateApi(Resource):

@@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
 
         return rules

@@ -14,8 +14,9 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from fields.conversation_fields import message_detail_fields, annotation_fields
 from libs.helper import uuid_value

@@ -208,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
         app_model = _get_app(app_id, 'completion')
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.DEBUGGER,
+                streaming=streaming
+            )
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")

@@ -220,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -249,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             except ModelCurrentlyNotSupportError:
                 yield "data: " + json.dumps(
                     api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, LLMAuthorizationError) as e:
+            except InvokeError as e:
                 yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
             except ValueError as e:
                 yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

@@ -290,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

@@ -31,7 +31,7 @@ class ModelConfigResource(Resource):
             tenant_id=current_user.current_tenant_id,
             account=current_user,
             config=request.json,
-            mode=app.mode
+            app_mode=app.mode
         )
 
         new_app_model_config = AppModelConfig(

@@ -4,6 +4,8 @@ from flask import request, current_app
 from flask_login import current_user
 
 from controllers.console.apikey import api_key_list, api_key_fields
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden

@@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.models.entity.model_params import ModelType
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 from fields.document_fields import document_status_fields

@@ -23,7 +24,6 @@ from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
 from models.model import UploadFile, ApiToken
 from services.dataset_service import DatasetService, DocumentService
-from services.provider_service import ProviderService
 
 
 def _validate_name(name):

@@ -55,16 +55,20 @@ class DatasetListApi(Resource):
             current_user.current_tenant_id, current_user)
 
         # check embedding setting
-        provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
-        # if len(valid_model_list) == 0:
-        #     raise ProviderNotInitializeError(
-        #         f"No Embedding Model available. Please configure a valid provider "
-        #         f"in the Settings -> Model Provider.")
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
 
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
             if item['indexing_technique'] == 'high_quality':

@@ -75,6 +79,7 @@ class DatasetListApi(Resource):
                     item['embedding_available'] = False
             else:
                 item['embedding_available'] = True
+
         response = {
             'data': data,
             'has_more': len(datasets) == limit,

@@ -130,13 +135,20 @@ class DatasetApi(Resource):
             raise Forbidden(str(e))
         data = marshal(dataset, dataset_detail_fields)
         # check embedding setting
-        provider_service = ProviderService()
-        # get valid model list
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
 
         if data['indexing_technique'] == 'high_quality':
             item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
             if item_model in model_names:
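
The dataset endpoints above now derive the list of available embedding models from the tenant's provider configurations rather than from ProviderService. A sketch of that lookup, using only calls shown in the diff and a placeholder tenant id:

from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

configurations = ProviderManager().get_configurations(tenant_id="tenant-id")  # placeholder id
embedding_models = configurations.get_models(
    model_type=ModelType.TEXT_EMBEDDING,
    only_active=True,
)

# "model:provider" labels, as built by the controllers above
model_names = [f"{m.model}:{m.provider.provider}" for m in embedding_models]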

@@ -2,8 +2,12 @@
 from datetime import datetime
 from typing import List
 
-from flask import request, current_app
+from flask import request
 from flask_login import current_user
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.login import login_required
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
 from sqlalchemy import desc, asc

@@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, \
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.indexing_runner import IndexingRunner
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from fields.document_fields import document_with_segments_fields, document_fields, \
     dataset_and_document_fields, document_status_fields

@@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
         args = parser.parse_args()
         if args['indexing_technique'] == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
-                    tenant_id=current_user.current_tenant_id
+                model_manager = ModelManager()
+                model_manager.get_default_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING
                 )
-            except LLMBadRequestError:
+            except InvokeAuthorizationError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")

@@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from libs.login import login_required
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client

@@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(

@@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
         # check embedding model setting
         if dataset.indexing_technique == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(

@@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
 from fields.hit_testing_fields import hit_testing_record_fields
 from services.dataset_service import DatasetService

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError

@@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

@@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=streaming
             )

@@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
         if app_model.mode != 'completion':
             raise NotCompletionAppError()

-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

         return {'result': 'success'}, 200

@@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=streaming
             )

@@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
         if app_model.mode != 'chat':
             raise NotChatAppError()

-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

         return {'result': 'success'}, 200

@@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
             yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
            yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
         except ValueError as e:
             yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

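Note: the recurring edit in the hunks above (and in the controllers that follow) is the error funnel: the provider-specific classes from core.model_providers.error are replaced by the single InvokeError exported from core.model_runtime.errors.invoke, so each controller keeps one catch-all branch instead of a tuple of five exceptions. A minimal, self-contained sketch of that handler shape, with local stand-in exception classes so the snippet runs on its own (they are not the repository's real imports):

    class InvokeError(Exception):              # stand-in for core.model_runtime.errors.invoke.InvokeError
        pass

    class CompletionRequestError(Exception):   # stand-in for the controller-level error class
        pass

    def invoke_model():
        # placeholder for the runtime call; any provider failure is assumed to surface as InvokeError
        raise InvokeError("rate limited by provider")

    try:
        invoke_model()
    except InvokeError as e:
        # one branch now covers what the old code enumerated as LLMBadRequestError,
        # LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError and LLMAuthorizationError
        print("would raise CompletionRequestError:", e)
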
@@ -5,7 +5,7 @@ from typing import Generator, Union

 from flask import stream_with_context, Response
 from flask_login import current_user
-from flask_restful import reqparse, fields, marshal_with
+from flask_restful import reqparse, marshal_with
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound, InternalServerError

@@ -13,12 +13,14 @@ import services
 from controllers.console import api
 from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
-from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
+from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
+    NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from fields.message_fields import message_infinite_scroll_pagination_fields
-from libs.helper import uuid_value, TimestampField
+from libs.helper import uuid_value
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError

@@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
         streaming = args['response_mode'] == 'streaming'

         try:
-            response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.EXPLORE,
+                streaming=streaming
+            )
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")

@@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
             yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
         except ValueError as e:
             yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

@@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError

@@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -12,9 +12,10 @@ from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
-    LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

@@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
                 app_model=app_model,
                 user=current_user,
                 args=args,
-                from_source='console',
+                invoke_from=InvokeFrom.EXPLORE,
                 streaming=True,
                 is_model_config_override=True,
             )

@@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e

@@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):

 class UniversalChatStopApi(UniversalChatResource):
     def post(self, universal_app, task_id):
-        PubHandler.stop(current_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

         return {'result': 'success'}, 200

@@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
             yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
         except ValueError as e:
             yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

@@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError

@@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")

@@ -1,16 +1,19 @@
+import io
+
+from flask import send_file
 from flask_login import current_user
-from libs.login import login_required
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden

 from controllers.console import api
-from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.providers.base import CredentialsValidateFailedError
-from services.provider_service import ProviderService
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
 from services.billing_service import BillingService
+from services.model_provider_service import ModelProviderService


 class ModelProviderListApi(Resource):

@@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
         tenant_id = current_user.current_tenant_id

         parser = reqparse.RequestParser()
-        parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
+        parser.add_argument('model_type', type=str, required=False, nullable=True,
+                            choices=[mt.value for mt in ModelType], location='args')
         args = parser.parse_args()

-        provider_service = ProviderService()
-        provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
+        model_provider_service = ModelProviderService()
+        provider_list = model_provider_service.get_provider_list(
+            tenant_id=tenant_id,
+            model_type=args.get('model_type')
+        )

-        return provider_list
+        return jsonable_encoder({"data": provider_list})
+
+
+class ModelProviderCredentialApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        credentials = model_provider_service.get_provider_credentials(
+            tenant_id=tenant_id,
+            provider=provider
+        )
+
+        return {
+            "credentials": credentials
+        }


 class ModelProviderValidateApi(Resource):

@@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):

         parser = reqparse.RequestParser()
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()

-        provider_service = ProviderService()
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()

         result = True
         error = None

         try:
-            provider_service.custom_provider_config_validate(
-                provider_name=provider_name,
-                config=args['config']
+            model_provider_service.provider_credentials_validate(
+                tenant_id=tenant_id,
+                provider=provider,
+                credentials=args['credentials']
             )
         except CredentialsValidateFailedError as ex:
             result = False

@@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
         return response


-class ModelProviderUpdateApi(Resource):
+class ModelProviderApi(Resource):

     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()

-        provider_service = ProviderService()
+        model_provider_service = ModelProviderService()

         try:
-            provider_service.save_custom_provider_config(
+            model_provider_service.save_provider_credentials(
                 tenant_id=current_user.current_tenant_id,
-                provider_name=provider_name,
-                config=args['config']
+                provider=provider,
+                credentials=args['credentials']
             )
         except CredentialsValidateFailedError as ex:
             raise ValueError(str(ex))

@@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def delete(self, provider_name: str):
+    def delete(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

-        provider_service = ProviderService()
-        provider_service.delete_custom_provider(
+        model_provider_service = ModelProviderService()
+        model_provider_service.remove_provider_credentials(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider=provider
         )

         return {'result': 'success'}, 204


-class ModelProviderModelValidateApi(Resource):
+class ModelProviderIconApi(Resource):
+    """
+    Get model provider icon
+    """

     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        result = True
-        error = None
-
-        try:
-            provider_service.custom_provider_model_config_validate(
-                provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type=args['model_type'],
-                config=args['config']
-            )
-        except CredentialsValidateFailedError as ex:
-            result = False
-            error = str(ex)
-
-        response = {'result': 'success' if result else 'error'}
-
-        if not result:
-            response['error'] = error
-
-        return response
-
-
-class ModelProviderModelUpdateApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider_name: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
-        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        try:
-            provider_service.add_or_save_custom_provider_model_config(
-                tenant_id=current_user.current_tenant_id,
-                provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type=args['model_type'],
-                config=args['config']
-            )
-        except CredentialsValidateFailedError as ex:
-            raise ValueError(str(ex))
-
-        return {'result': 'success'}, 200
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def delete(self, provider_name: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-        provider_service.delete_custom_provider_model(
-            tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
-            model_name=args['model_name'],
-            model_type=args['model_type']
+    def get(self, provider: str, icon_type: str, lang: str):
+        model_provider_service = ModelProviderService()
+        icon, mimetype = model_provider_service.get_model_provider_icon(
+            provider=provider,
+            icon_type=icon_type,
+            lang=lang
         )

-        return {'result': 'success'}, 204
+        return send_file(io.BytesIO(icon), mimetype=mimetype)


 class PreferredProviderTypeUpdateApi(Resource):

@@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
+    def post(self, provider: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

+        tenant_id = current_user.current_tenant_id
+
         parser = reqparse.RequestParser()
         parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
                             choices=['system', 'custom'], location='json')
         args = parser.parse_args()

-        provider_service = ProviderService()
-        provider_service.switch_preferred_provider(
-            tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
+        model_provider_service = ModelProviderService()
+        model_provider_service.switch_preferred_provider(
+            tenant_id=tenant_id,
+            provider=provider,
             preferred_provider_type=args['preferred_provider_type']
         )

         return {'result': 'success'}


-class ModelProviderModelParameterRuleApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self, provider_name: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
-        args = parser.parse_args()
-
-        provider_service = ProviderService()
-
-        try:
-            parameter_rules = provider_service.get_model_parameter_rules(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=provider_name,
-                model_name=args['model_name'],
-                model_type='text-generation'
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"Current Text Generation Model is invalid. Please switch to the available model.")
-
-        rules = {
-            k: {
-                'enabled': v.enabled,
-                'min': v.min,
-                'max': v.max,
-                'default': v.default,
-                'precision': v.precision
-            }
-            for k, v in vars(parameter_rules).items()
-        }
-
-        return rules
-
-
 class ModelProviderPaymentCheckoutUrlApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider_name: str):
-        if provider_name != 'anthropic':
-            raise ValueError(f'provider name {provider_name} is invalid')
+    def get(self, provider: str):
+        if provider != 'anthropic':
+            raise ValueError(f'provider name {provider} is invalid')

-        data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
+        data = BillingService.get_model_provider_payment_link(provider_name=provider,
                                                               tenant_id=current_user.current_tenant_id,
                                                               account_id=current_user.id)
         return data

@@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider_name: str):
-        provider_service = ProviderService()
-        result = provider_service.free_quota_submit(
+    def post(self, provider: str):
+        model_provider_service = ModelProviderService()
+        result = model_provider_service.free_quota_submit(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider=provider
         )

         return result

@@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, provider_name: str):
+    def get(self, provider: str):
         parser = reqparse.RequestParser()
         parser.add_argument('token', type=str, required=False, nullable=True, location='args')
         args = parser.parse_args()

-        provider_service = ProviderService()
-        result = provider_service.free_quota_qualification_verify(
+        model_provider_service = ModelProviderService()
+        result = model_provider_service.free_quota_qualification_verify(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name,
+            provider=provider,
             token=args['token']
         )

@@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):


 api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
-api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
-api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
-api.add_resource(ModelProviderModelValidateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models/validate')
-api.add_resource(ModelProviderModelUpdateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models')
+api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
+api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
+api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
+api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
+                                       '<string:icon_type>/<string:lang>')

 api.add_resource(PreferredProviderTypeUpdateApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
-api.add_resource(ModelProviderModelParameterRuleApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
+                 '/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
 api.add_resource(ModelProviderPaymentCheckoutUrlApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
+                 '/workspaces/current/model-providers/<string:provider>/checkout-url')
 api.add_resource(ModelProviderFreeQuotaSubmitApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
+                 '/workspaces/current/model-providers/<string:provider>/free-quota-submit')
 api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
-                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
+                 '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')

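The routes registered above replace the old per-provider config endpoints with credential-centric ones. A hypothetical client-side check against the new credentials/validate route; the base URL, API prefix, provider name, key value and auth header are illustrative placeholders, not part of this diff:

    import requests

    BASE = "http://localhost:5001/console/api"   # placeholder console API base, adjust to your deployment
    resp = requests.post(
        f"{BASE}/workspaces/current/model-providers/openai/credentials/validate",
        json={"credentials": {"openai_api_key": "sk-placeholder"}},          # placeholder credentials payload
        headers={"Authorization": "Bearer <console-session-token>"},         # placeholder auth header
    )
    # per the handler above, the body is {'result': 'success'} or {'result': 'error', 'error': '...'}
    print(resp.json())
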
@@ -1,16 +1,17 @@
 import logging

 from flask_login import current_user
-from libs.login import login_required
-from flask_restful import Resource, reqparse
+from flask_restful import reqparse, Resource
+from werkzeug.exceptions import Forbidden

 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.model_provider_factory import ModelProviderFactory
-from core.model_providers.models.entity.model_params import ModelType
-from models.provider import ProviderType
-from services.provider_service import ProviderService
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
+from services.model_provider_service import ModelProviderService


 class DefaultModelApi(Resource):

@@ -21,52 +22,20 @@ class DefaultModelApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
+                            choices=[mt.value for mt in ModelType], location='args')
         args = parser.parse_args()

         tenant_id = current_user.current_tenant_id

-        provider_service = ProviderService()
-        default_model = provider_service.get_default_model_of_model_type(
+        model_provider_service = ModelProviderService()
+        default_model_entity = model_provider_service.get_default_model_of_model_type(
             tenant_id=tenant_id,
             model_type=args['model_type']
         )

-        if not default_model:
-            return None
-
-        model_provider = ModelProviderFactory.get_preferred_model_provider(
-            tenant_id,
-            default_model.provider_name
-        )
-
-        if not model_provider:
-            return {
-                'model_name': default_model.model_name,
-                'model_type': default_model.model_type,
-                'model_provider': {
-                    'provider_name': default_model.provider_name
-                }
-            }
-
-        provider = model_provider.provider
-        rst = {
-            'model_name': default_model.model_name,
-            'model_type': default_model.model_type,
-            'model_provider': {
-                'provider_name': provider.provider_name,
-                'provider_type': provider.provider_type
-            }
-        }
-
-        model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            rst['model_provider']['quota_type'] = provider.quota_type
-            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
-            rst['model_provider']['quota_limit'] = provider.quota_limit
-            rst['model_provider']['quota_used'] = provider.quota_used
-
-        return rst
+        return jsonable_encoder({
+            "data": default_model_entity
+        })

     @setup_required
     @login_required

@@ -76,15 +45,26 @@ class DefaultModelApi(Resource):
         parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
         args = parser.parse_args()

-        provider_service = ProviderService()
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
         model_settings = args['model_settings']
         for model_setting in model_settings:
+            if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
+                raise ValueError('invalid model type')
+
+            if 'provider' not in model_setting:
+                continue
+
+            if 'model' not in model_setting:
+                raise ValueError('invalid model')
+
             try:
-                provider_service.update_default_model_of_model_type(
-                    tenant_id=current_user.current_tenant_id,
+                model_provider_service.update_default_model_of_model_type(
+                    tenant_id=tenant_id,
                     model_type=model_setting['model_type'],
-                    provider_name=model_setting['provider_name'],
-                    model_name=model_setting['model_name']
+                    provider=model_setting['provider'],
+                    model=model_setting['model']
                 )
             except Exception:
                 logging.warning(f"{model_setting['model_type']} save error")

@@ -92,22 +72,198 @@ class DefaultModelApi(Resource):
         return {'result': 'success'}


-class ValidModelApi(Resource):
+class ModelProviderModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_provider(
+            tenant_id=tenant_id,
+            provider=provider
+        )
+
+        return jsonable_encoder({
+            "data": models
+        })
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+
+        try:
+            model_provider_service.save_model_credentials(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        model_provider_service.remove_model_credentials(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}, 204
+
+
+class ModelProviderModelCredentialApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='args')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        credentials = model_provider_service.get_model_credentials(
+            tenant_id=tenant_id,
+            provider=provider,
+            model_type=args['model_type'],
+            model=args['model']
+        )
+
+        return {
+            "credentials": credentials
+        }
+
+
+class ModelProviderModelValidateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+
+        result = True
+        error = None
+
+        try:
+            model_provider_service.model_credentials_validate(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type'],
+                credentials=args['credentials']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ModelProviderModelParameterRuleApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
+        args = parser.parse_args()
+
+        tenant_id = current_user.current_tenant_id
+
+        model_provider_service = ModelProviderService()
+        parameter_rules = model_provider_service.get_model_parameter_rules(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model']
+        )
+
+        return jsonable_encoder({
+            "data": parameter_rules
+        })
+
+
+class ModelProviderAvailableModelApi(Resource):

     @setup_required
     @login_required
     @account_initialization_required
     def get(self, model_type):
-        ModelType.value_of(model_type)
+        tenant_id = current_user.current_tenant_id

-        provider_service = ProviderService()
-        valid_models = provider_service.get_valid_model_list(
-            tenant_id=current_user.current_tenant_id,
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_model_type(
+            tenant_id=tenant_id,
             model_type=model_type
         )

-        return valid_models
+        return jsonable_encoder({
+            "data": models
+        })
+
+
+api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
+api.add_resource(ModelProviderModelCredentialApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/credentials')
+api.add_resource(ModelProviderModelValidateApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
+
+api.add_resource(ModelProviderModelParameterRuleApi,
+                 '/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
+api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
 api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
-api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')

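Several of the new endpoints above validate model_type with choices=[mt.value for mt in ModelType]. A stand-alone sketch of that reqparse pattern, assuming a flask-restful app; the enum members and resource name here are illustrative stand-ins, not the runtime's own definitions:

    from enum import Enum

    from flask import Flask
    from flask_restful import Api, Resource, reqparse

    class ModelType(Enum):                 # stand-in enum; the real one lives in core.model_runtime
        LLM = "llm"
        TEXT_EMBEDDING = "text-embedding"

    class DefaultModelLookupApi(Resource):  # hypothetical resource for illustration only
        def get(self):
            parser = reqparse.RequestParser()
            parser.add_argument('model_type', type=str, required=True, nullable=False,
                                choices=[mt.value for mt in ModelType], location='args')
            args = parser.parse_args()      # flask-restful answers HTTP 400 automatically on a bad choice
            return {"model_type": args['model_type']}

    app = Flask(__name__)
    Api(app).add_resource(DefaultModelLookupApi, '/default-model')
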
@ -1,131 +0,0 @@
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
from flask_login import current_user
|
|
||||||
from libs.login import login_required
|
|
||||||
from flask_restful import Resource, reqparse
|
|
||||||
from werkzeug.exceptions import Forbidden
|
|
||||||
|
|
||||||
from controllers.console import api
|
|
||||||
from controllers.console.setup import setup_required
|
|
||||||
from controllers.console.wraps import account_initialization_required
|
|
||||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
|
||||||
from models.provider import ProviderType
|
|
||||||
from services.provider_service import ProviderService
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderListApi(Resource):
|
|
||||||
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
def get(self):
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
"""
|
|
||||||
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
|
|
||||||
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
|
|
||||||
rest is replaced by * and the last two bits are displayed in plaintext
|
|
||||||
|
|
||||||
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
|
|
||||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
|
||||||
"""
|
|
||||||
|
|
||||||
provider_service = ProviderService()
|
|
||||||
provider_info_list = provider_service.get_provider_list(tenant_id)
|
|
||||||
|
|
||||||
provider_list = [
|
            {
                'provider_name': p['provider_name'],
                'provider_type': p['provider_type'],
                'is_valid': p['is_valid'],
                'last_used': p['last_used'],
                'is_enabled': p['is_valid'],
                **({
                    'quota_type': p['quota_type'],
                    'quota_limit': p['quota_limit'],
                    'quota_used': p['quota_used']
                } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
                'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
                if p['config'] else None
            }
            for name, provider_info in provider_info_list.items()
            for p in provider_info['providers']
        ]

        return provider_list


class ProviderTokenApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('token', required=True, nullable=False, location='json')
        args = parser.parse_args()

        if provider == 'openai':
            args['token'] = {
                'openai_api_key': args['token']
            }

        provider_service = ProviderService()
        try:
            provider_service.save_custom_provider_config(
                tenant_id=current_user.current_tenant_id,
                provider_name=provider,
                config=args['token']
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))

        return {'result': 'success'}, 201


class ProviderTokenValidateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        parser = reqparse.RequestParser()
        parser.add_argument('token', required=True, nullable=False, location='json')
        args = parser.parse_args()

        provider_service = ProviderService()

        if provider == 'openai':
            args['token'] = {
                'openai_api_key': args['token']
            }

        result = True
        error = None

        try:
            provider_service.custom_provider_config_validate(
                provider_name=provider,
                config=args['token']
            )
        except CredentialsValidateFailedError as ex:
            result = False
            error = str(ex)

        response = {'result': 'success' if result else 'error'}

        if not result:
            response['error'] = error

        return response


api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token

api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
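A hypothetical client-side flow against these two endpoints, shown only as a sketch: the console base URL and the already-authenticated session are assumptions, not part of this diff.

import requests

BASE = "https://cloud.example.com/console/api"  # assumed mount point
session = requests.Session()  # assumed to already carry a logged-in console session

# validate an OpenAI key first, then persist it if the check passes
r = session.post(f"{BASE}/workspaces/current/providers/openai/token-validate",
                 json={"token": "sk-..."})
if r.json().get("result") == "success":
    session.post(f"{BASE}/workspaces/current/providers/openai/token",
                 json={"token": "sk-..."})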
@@ -34,7 +34,6 @@ tenant_fields {
     'status': fields.String,
     'created_at': TimestampField,
     'role': fields.String,
-    'providers': fields.List(fields.Nested(provider_fields)),
     'in_trial': fields.Boolean,
     'trial_end_reason': fields.String,
     'custom_config': fields.Raw(attribute='custom_config'),
@@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
     ProviderNotSupportSpeechToTextError
 from controllers.service_api.wraps import AppApiResource
-from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from models.model import App, AppModelConfig
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
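The same replacement repeats in every controller below: the five provider-specific exceptions collapse into the runtime's single InvokeError base class. A minimal sketch of the resulting pattern, with the actual service call left as a placeholder since it varies per controller:

from core.model_runtime.errors.invoke import InvokeError

try:
    result = invoke_model()  # placeholder for whatever service call runs the model
except InvokeError as e:
    # rate-limit, auth, connection and bad-request failures all derive from InvokeError
    raise CompletionRequestError(str(e))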
@@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
     ProviderModelCurrentlyNotSupportError
 from controllers.service_api.wraps import AppApiResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

@@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
             app_model=app_model,
             user=end_user,
             args=args,
-            from_source='api',
+            invoke_from=InvokeFrom.SERVICE_API,
             streaming=streaming,
         )

@@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
         if app_model.mode != 'completion':
             raise AppUnavailableError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

         return {'result': 'success'}, 200

@@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
             app_model=app_model,
             user=end_user,
             args=args,
-            from_source='api',
+            invoke_from=InvokeFrom.SERVICE_API,
             streaming=streaming
         )

@@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
         if app_model.mode != 'chat':
             raise NotChatAppError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

         return {'result': 'success'}, 200

@@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
             yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
         except ValueError as e:
             yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -4,11 +4,11 @@ import services.dataset_service
 from controllers.service_api import api
 from controllers.service_api.dataset.error import DatasetNameDuplicateError
 from controllers.service_api.wraps import DatasetApiResource
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import current_user
-from core.model_providers.models.entity.model_params import ModelType
 from fields.dataset_fields import dataset_detail_fields
 from services.dataset_service import DatasetService
-from services.provider_service import ProviderService


 def _validate_name(name):
@@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
         datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                       tenant_id, current_user)
         # check embedding setting
-        provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

         data = marshal(datasets, dataset_detail_fields)
         for item in data:
             if item['indexing_technique'] == 'high_quality':
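For reference, the entries produced by the new loop follow a model:provider shape; an active OpenAI embedding model would yield something like the following (illustrative value only):

# e.g. one entry of model_names
model_names = ["text-embedding-ada-002:openai"]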
@@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
     NoFileUploadedError, TooManyFilesError
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from libs.login import current_user
-from core.model_providers.error import ProviderTokenNotInitError
+from core.errors.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from fields.document_fields import document_fields, document_status_fields
 from models.dataset import Dataset, Document, DocumentSegment
@@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
 from controllers.service_api import api
 from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
-from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from fields.segment_fields import segment_fields
 from models.dataset import Dataset, DocumentSegment
@@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
         # check embedding model setting
         if dataset.indexing_technique == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
         # check embedding model setting
         if dataset.indexing_technique == 'high_quality':
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
         if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
-                ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
@@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.web.wraps import WebApiResource
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
     UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
     ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
-from core.conversation_message_task import PubHandler
-from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

@@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
             app_model=app_model,
             user=end_user,
             args=args,
-            from_source='api',
+            invoke_from=InvokeFrom.WEB_APP,
             streaming=streaming
         )

@@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
         if app_model.mode != 'completion':
             raise NotCompletionAppError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

         return {'result': 'success'}, 200

@@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
             app_model=app_model,
             user=end_user,
             args=args,
-            from_source='api',
+            invoke_from=InvokeFrom.WEB_APP,
             streaming=streaming
         )

@@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
         if app_model.mode != 'chat':
             raise NotChatAppError()

-        PubHandler.stop(end_user, task_id)
+        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

         return {'result': 'success'}, 200

@@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
             yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
         except ValueError as e:
             yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
     AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
-from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
-    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.entities.application_entities import InvokeFrom
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
@@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
         streaming = args['response_mode'] == 'streaming'

         try:
-            response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
+            response = CompletionService.generate_more_like_this(
+                app_model=app_model,
+                user=end_user,
+                message_id=message_id,
+                invoke_from=InvokeFrom.WEB_APP,
+                streaming=streaming
+            )
+
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
             yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
         except ValueError as e:
             yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except Exception:
             logging.exception("internal server error.")
101 api/core/agent/agent/agent_llm_callback.py Normal file
@@ -0,0 +1,101 @@
import logging
from typing import Optional, List

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.model_providers.__base.ai_model import AIModel

logger = logging.getLogger(__name__)


class AgentLLMCallback(Callback):

    def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
        self.agent_callback = agent_callback

    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
        """
        Before invoke callback

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_before_invoke(
            prompt_messages=prompt_messages
        )

    def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                     prompt_messages: list[PromptMessage], model_parameters: dict,
                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                     stream: bool = True, user: Optional[str] = None):
        """
        On new chunk callback

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        pass

    def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """
        After invoke callback

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_after_invoke(
            result=result
        )

    def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """
        Invoke error callback

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_error(
            error=ex
        )
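A rough sketch of how this callback is meant to be wired, inferred from the openai_function_call.py changes later in this diff; how the gather handler itself is constructed is an assumption, not shown here.

agent_callback = AgentLoopGatherCallbackHandler(...)  # assumed to be built by the agent runner
agent_llm_callback = AgentLLMCallback(agent_callback=agent_callback)

result = model_instance.invoke_llm(
    prompt_messages=prompt_messages,
    tools=tools,
    stream=False,
    callbacks=[agent_llm_callback],
    model_parameters={'temperature': 0.2, 'top_p': 0.3, 'max_tokens': 1500}
)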
@@ -1,28 +1,49 @@
-from typing import List
+from typing import List, cast

 from langchain.schema import BaseMessage

-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


 class CalcTokenMixin:

-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
-        return model_instance.get_num_tokens(to_prompt_messages(messages))
-
-    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
         """
         Got the rest tokens available for the model after excluding messages tokens and completion max tokens

-        :param llm:
+        :param model_config:
         :param messages:
         :return:
         """
-        llm_max_tokens = model_instance.model_rules.max_tokens.max
-        completion_max_tokens = model_instance.model_kwargs.max_tokens
-        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
-        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                max_tokens = (model_config.parameters.get(parameter_rule.name)
+                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        if model_context_tokens is None:
+            return 0
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_type_instance.get_num_tokens(
+            model_config.model,
+            model_config.credentials,
+            messages
+        )
+
+        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
+
         return rest_tokens
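As a worked example of the new budget calculation, with purely illustrative numbers: a 16,384-token context window, a configured max_tokens of 1,500 and 2,000 prompt tokens leave 12,884 tokens; callers below additionally subtract 20 as a safety margin.

# illustrative numbers only
model_context_tokens = 16384
max_tokens = 1500
prompt_tokens = 2000
rest_tokens = model_context_tokens - max_tokens - prompt_tokens  # 12884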
@@ -1,4 +1,3 @@
-import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast

 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
 from langchain.tools import BaseTool
 from pydantic import root_validator

-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.third_party.langchain.llms.fake import FakeLLM


@@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     """
     An Multi Dataset Retrieve Agent driven by Router.
     """
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity

     class Config:
         """Configuration for this pydantic object."""
@@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
                 agent_decision.return_values['output'] = ''
             return agent_decision
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e

     def real_plan(
         self,
@@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        tools = []
+        for function in self.functions:
+            tool = PromptMessageTool(
+                **function
+            )
+
+            tools.append(tool)
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=tools,
+            stream=False,
+            model_parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )

         ai_message = AIMessage(
-            content=result.content,
+            content=result.message.content or "",
             additional_kwargs={
-                'function_call': result.function_call
+                'function_call': {
+                    'id': result.message.tool_calls[0].id,
+                    **result.message.tool_calls[0].function.dict()
+                } if result.message.tool_calls else None
             }
         )
@@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     @classmethod
     def from_llm_and_tools(
         cls,
-        model_instance: BaseLLM,
+        model_config: ModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
         extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             system_message=system_message,
         )
         return cls(
-            model_instance=model_instance,
+            model_config=model_config,
             llm=FakeLLM(response=''),
             prompt=prompt,
             tools=tools,
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any, Union, Sequence, Optional
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast

 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
 from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
@@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
 from langchain.tools import BaseTool
 from pydantic import root_validator

+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.third_party.langchain.llms.fake import FakeLLM


 class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_model_instance: BaseLLM = None
-    model_instance: BaseLLM
+    summary_model_config: ModelConfigEntity = None
+    model_config: ModelConfigEntity
+    agent_llm_callback: Optional[AgentLLMCallback] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
     @classmethod
     def from_llm_and_tools(
         cls,
-        model_instance: BaseLLM,
+        model_config: ModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
         extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
+        agent_llm_callback: Optional[AgentLLMCallback] = None,
         **kwargs: Any,
     ) -> BaseSingleActionAgent:
         prompt = cls.create_prompt(
@@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
             system_message=system_message,
         )
         return cls(
-            model_instance=model_instance,
+            model_config=model_config,
             llm=FakeLLM(response=''),
             prompt=prompt,
             tools=tools,
             callback_manager=callback_manager,
+            agent_llm_callback=agent_llm_callback,
             **kwargs,
         )

@@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         :param query:
         :return:
         """
-        original_max_tokens = self.model_instance.model_kwargs.max_tokens
-        self.model_instance.model_kwargs.max_tokens = 40
+        original_max_tokens = 0
+        for parameter_rule in self.model_config.model_schema.parameter_rules:
+            if (parameter_rule.name == 'max_tokens'
+                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
+                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0
+
+        self.model_config.parameters['max_tokens'] = 40

         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()

         try:
-            prompt_messages = to_prompt_messages(messages)
-            result = self.model_instance.run(
-                messages=prompt_messages,
-                functions=self.functions,
-                callbacks=None
+            prompt_messages = lc_messages_to_prompt_messages(messages)
+            model_instance = ModelInstance(
+                provider_model_bundle=self.model_config.provider_model_bundle,
+                model=self.model_config.model,
+            )
+
+            tools = []
+            for function in self.functions:
+                tool = PromptMessageTool(
+                    **function
+                )
+
+                tools.append(tool)
+
+            result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                tools=tools,
+                stream=False,
+                model_parameters={
+                    'temperature': 0.2,
+                    'top_p': 0.3,
+                    'max_tokens': 1500
+                }
             )
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e

-        function_call = result.function_call
-
-        self.model_instance.model_kwargs.max_tokens = original_max_tokens
-
-        return True if function_call else False
+        self.model_config.parameters['max_tokens'] = original_max_tokens
+
+        return True if result.message.tool_calls else False

     def plan(
         self,
@@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()

+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
+            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))

-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        tools = []
+        for function in self.functions:
+            tool = PromptMessageTool(
+                **function
+            )
+
+            tools.append(tool)
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=tools,
+            stream=False,
+            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
+            model_parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )

         ai_message = AIMessage(
-            content=result.content,
+            content=result.message.content or "",
             additional_kwargs={
-                'function_call': result.function_call
+                'function_call': {
+                    'id': result.message.tool_calls[0].id,
+                    **result.message.tool_calls[0].function.dict()
+                } if result.message.tool_calls else None
             }
         )
         agent_decision = _parse_ai_message(ai_message)
@@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

-    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
+        rest_tokens = self.get_message_rest_tokens(
+            self.model_config,
+            messages,
+            **kwargs
+        )
+
         rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
         if rest_tokens >= 0:
             return messages
@@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
             ai_prefix="AI",
         )

-        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)

-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        if model_instance.model_provider.provider_name == 'azure_openai':
-            model = model_instance.base_model_name
+        if model_config.provider == 'azure_openai':
+            model = model_config.model
             model = model.replace("gpt-35", "gpt-3.5")
         else:
-            model = model_instance.base_model_name
+            model = model_config.credentials.get("base_model_name")

         tiktoken_ = _import_tiktoken()
         try:
@@ -1,158 +0,0 @@
-import json
-from typing import Tuple, List, Any, Union, Sequence, Optional, cast
-
-from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
-from core.third_party.langchain.llms.fake import FakeLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
-
-
-class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
-    """
-    An Multi Dataset Retrieve Agent driven by Router.
-    """
-    model_instance: BaseLLM
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        if len(self.tools) == 0:
-            return AgentFinish(return_values={"output": ''}, log='')
-        elif len(self.tools) == 1:
-            tool = next(iter(self.tools))
-            tool = cast(DatasetRetrieverTool, tool)
-            rst = tool.run(tool_input={'query': kwargs['input']})
-            # output = ''
-            # rst_json = json.loads(rst)
-            # for item in rst_json:
-            #     output += f'{item["content"]}\n'
-            return AgentFinish(return_values={"output": rst}, log=rst)
-
-        if intermediate_steps:
-            _, observation = intermediate_steps[-1]
-            return AgentFinish(return_values={"output": observation}, log=observation)
-
-        try:
-            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
-            if isinstance(agent_decision, AgentAction):
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-            else:
-                agent_decision.return_values['output'] = ''
-            return agent_decision
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
-
-    def real_plan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            functions=self.functions,
-        )
-
-        ai_message = AIMessage(
-            content=result.content,
-            additional_kwargs={
-                'function_call': result.function_call
-            }
-        )
-
-        agent_decision = _parse_ai_message(ai_message)
-        return agent_decision
-
-    async def aplan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        raise NotImplementedError()
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_instance: BaseLLM,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_instance=model_instance,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            **kwargs,
-        )
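The removed plan() above short-circuits routing: with no tools it finishes immediately, and with exactly one dataset tool it runs that tool directly instead of asking the model to choose. An illustrative, simplified sketch of that shortcut, with plain dicts standing in for LangChain's AgentFinish:

def route(tools: list, query: str):
    if len(tools) == 0:
        return {"output": ""}  # nothing to retrieve from
    if len(tools) == 1:
        # a single dataset tool is invoked directly, skipping the LLM router
        return {"output": tools[0].run(tool_input={"query": query})}
    return None  # more than one dataset: fall through to the LLM-driven plan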
@@ -12,9 +12,7 @@ from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.llm.base import BaseLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from core.entities.application_entities import ModelConfigEntity

 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         return True

     def plan(
             self,
             intermediate_steps: List[Tuple[AgentAction, str]],
             callbacks: Callbacks = None,
             **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.

@@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e

         try:
             agent_decision = self.output_parser.parse(full_output)
@@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         except OutputParserException:
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")

     @classmethod
     def create_prompt(
             cls,
@@ -182,7 +180,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)

     def _construct_scratchpad(
             self, intermediate_steps: List[Tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:
@@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
             raise ValueError("agent_scratchpad should be of type string.")
         if agent_scratchpad:
             llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+            if llm_chain.model_config.mode == "chat":
                 return (
                     f"This was your previous work "
                     f"(but I haven't seen any of it! I only see what "
@@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        if model_instance.model_mode == ModelMode.CHAT:
+        if model_config.mode == "chat":
             prompt = cls.create_prompt(
                 tools,
                 prefix=prefix,
@@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
                 format_instructions=format_instructions,
                 input_variables=input_variables
             )

         llm_chain = LLMChain(
-            model_instance=model_instance,
+            model_config=model_config,
             prompt=prompt,
             callback_manager=callback_manager,
+            parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser
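As the last hunk shows, the router agent's LLMChain is now built from a ModelConfigEntity plus explicit, conservative sampling parameters instead of a pre-built BaseLLM. A hedged sketch of that construction, assuming prompt, callback_manager and model_config are already prepared as in from_llm_and_tools:

from core.chain.llm_chain import LLMChain

llm_chain = LLMChain(
    model_config=model_config,       # ModelConfigEntity, replaces model_instance
    prompt=prompt,
    callback_manager=callback_manager,
    parameters={
        'temperature': 0.2,          # low temperature keeps routing decisions stable
        'top_p': 0.3,
        'max_tokens': 1500
    }
)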
@@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
 from core.chain.llm_chain import LLMChain
-from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import lc_messages_to_prompt_messages

 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -54,7 +55,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_model_instance: BaseLLM = None
+    summary_model_config: ModelConfigEntity = None

     class Config:
         """Configuration for this pydantic object."""
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):

         Args:
             intermediate_steps: Steps the LLM has taken to date,
-                along with observations
+                along with observatons
             callbacks: Callbacks to run.
             **kwargs: User inputs.

@@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()

-            rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
+            prompt_messages = lc_messages_to_prompt_messages(messages)

+            rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
             if rest_tokens < 0:
                 full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
-            raise new_exception
+            raise e

         try:
             agent_decision = self.output_parser.parse(full_output)
@@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                 "I don't know how to respond to that."}, "")

     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_model_instance:
+        if len(intermediate_steps) >= 2 and self.summary_model_config:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]
@@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             ai_prefix="AI",
         )

-        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)

     @classmethod
@@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
             raise ValueError("agent_scratchpad should be of type string.")
         if agent_scratchpad:
             llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+            if llm_chain.model_config.mode == "chat":
                 return (
                     f"This was your previous work "
                     f"(but I haven't seen any of it! I only see what "
@@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_instance: BaseLLM,
+            model_config: ModelConfigEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
             format_instructions: str = FORMAT_INSTRUCTIONS,
             input_variables: Optional[List[str]] = None,
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            agent_llm_callback: Optional[AgentLLMCallback] = None,
             **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        if model_instance.model_mode == ModelMode.CHAT:
+        if model_config.mode == "chat":
             prompt = cls.create_prompt(
                 tools,
                 prefix=prefix,
@@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
                 input_variables=input_variables,
             )
         llm_chain = LLMChain(
-            model_instance=model_instance,
+            model_config=model_config,
             prompt=prompt,
             callback_manager=callback_manager,
+            agent_llm_callback=agent_llm_callback,
+            parameters={
+                'temperature': 0.2,
+                'top_p': 0.3,
+                'max_tokens': 1500
+            }
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser
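The plan() changes above gate generation on a token budget: LangChain messages are converted to prompt messages, the remaining budget is computed from the model config, and older intermediate steps are summarized when the budget goes negative. A standalone sketch of that moving-summary idea with made-up sizes; count_tokens is a stand-in, not the repository's counter:

def compress_steps(steps: list[str], budget: int, count_tokens) -> list[str]:
    # fold the two oldest observations together until the budget fits,
    # keeping the most recent step verbatim
    while len(steps) > 1 and sum(count_tokens(s) for s in steps) > budget:
        merged = "Summary: " + steps[0][:200] + " ... " + steps[1][:200]
        steps = [merged] + steps[2:]
    return steps

compressed = compress_steps(["obs one", "obs two", "obs three"], budget=4,
                            count_tokens=lambda s: len(s.split()))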
@@ -4,10 +4,10 @@ from typing import Union, Optional

 from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra

+from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
@@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor

+from core.entities.application_entities import ModelConfigEntity
+from core.entities.message_entities import prompt_messages_to_lc_messages
 from core.helper import moderation
-from core.model_providers.error import LLMError
-from core.model_providers.models.llm.base import BaseLLM
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.errors.invoke import InvokeError
 from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool

@@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):

 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM = None
-    memory: Optional[BaseChatMemory] = None
+    summary_model_config: Optional[ModelConfigEntity] = None
+    memory: Optional[TokenBufferMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
     max_execution_time: Optional[float] = None
     early_stopping_method: str = "generate"
+    agent_llm_callback: Optional[AgentLLMCallback] = None
     # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

     class Config:
@@ -62,34 +65,42 @@ class AgentExecutor:
     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_model_instance=self.configuration.summary_model_instance
-                if self.configuration.summary_model_instance else None,
+                summary_model_config=self.configuration.summary_model_config
+                if self.configuration.summary_model_config else None,
+                agent_llm_callback=self.configuration.agent_llm_callback,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_instance=self.configuration.summary_model_instance
-                if self.configuration.summary_model_instance else None,
+                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
+                if self.configuration.memory else None,  # used for read chat histories memory
+                summary_model_config=self.configuration.summary_model_config
+                if self.configuration.summary_model_config else None,
+                agent_llm_callback=self.configuration.agent_llm_callback,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools
+                                        if isinstance(t, DatasetRetrieverTool)
+                                        or isinstance(t, DatasetMultiRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
+                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
+                if self.configuration.memory else None,
                 verbose=True
            )
         elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools
+                                        if isinstance(t, DatasetRetrieverTool)
+                                        or isinstance(t, DatasetMultiRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
+                model_config=self.configuration.model_config,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
                 verbose=True
@@ -104,11 +115,11 @@ class AgentExecutor:

     def run(self, query: str) -> AgentExecuteResult:
         moderation_result = moderation.check_moderation(
-            self.configuration.model_instance.model_provider,
+            self.configuration.model_config,
             query
         )

-        if not moderation_result:
+        if moderation_result:
             return AgentExecuteResult(
                 output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
                 strategy=self.configuration.strategy,
@@ -118,7 +129,6 @@ class AgentExecutor:
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
-            memory=self.configuration.memory,
             max_iterations=self.configuration.max_iterations,
             max_execution_time=self.configuration.max_execution_time,
             early_stopping_method=self.configuration.early_stopping_method,
@@ -126,8 +136,8 @@ class AgentExecutor:
         )

         try:
-            output = agent_executor.run(query)
-        except LLMError as ex:
+            output = agent_executor.run(input=query)
+        except InvokeError as ex:
             raise ex
         except Exception as ex:
             logging.exception("agent_executor run failed")
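The AgentConfiguration and _init_agent changes above swap BaseLLM and BaseChatMemory for ModelConfigEntity and TokenBufferMemory, and thread an optional AgentLLMCallback through to the chains. A hedged sketch of assembling the refactored configuration; the surrounding objects are assumed to exist and field names follow the diff:

config = AgentConfiguration(
    strategy=PlanningStrategy.REACT,
    model_config=model_config,                  # ModelConfigEntity, formerly model_instance
    tools=tools,
    summary_model_config=summary_model_config,  # optional model used for auto-summarization
    memory=token_buffer_memory,                 # TokenBufferMemory, formerly BaseChatMemory
    agent_llm_callback=agent_llm_callback,      # records each LLM call as an agent thought
    max_iterations=6,
)
# AgentExecutor consumes this configuration; its run(query) first passes the
# query through moderation.check_moderation(model_config, query) as shown above.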
api/core/app_runner/agent_app_runner.py (new file, 251 lines)
@@ -0,0 +1,251 @@
import json
import logging
from typing import cast

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.app_runner.app_runner import AppRunner
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.agent_runner import AgentRunnerFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought

logger = logging.getLogger(__name__)


class AgentApplicationRunner(AppRunner):
    """
    Agent Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run agent application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=None,
            memory=memory
        )

        # Create MessageChain
        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # add agent callback to record agent thoughts
        agent_callback = AgentLoopGatherCallbackHandler(
            model_config=app_orchestration_config.model_config,
            message=message,
            queue_manager=queue_manager,
            message_chain=message_chain
        )

        # init LLM Callback
        agent_llm_callback = AgentLLMCallback(
            agent_callback=agent_callback
        )

        agent_runner = AgentRunnerFeature(
            tenant_id=application_generate_entity.tenant_id,
            app_orchestration_config=app_orchestration_config,
            model_config=app_orchestration_config.model_config,
            config=app_orchestration_config.agent,
            queue_manager=queue_manager,
            message=message,
            user_id=application_generate_entity.user_id,
            agent_llm_callback=agent_llm_callback,
            callback=agent_callback,
            memory=memory
        )

        # agent run
        result = agent_runner.run(
            query=query,
            invoke_from=application_generate_entity.invoke_from
        )

        if result:
            self._save_message_chain(
                message_chain=message_chain,
                output_text=result
            )

        if (result
                and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
                and app_orchestration_config.prompt_template.simple_prompt_template
        ):
            # Direct output if agent result exists and has pre prompt
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                stream=application_generate_entity.stream,
                text=result,
                usage=self._get_usage_of_all_agent_thoughts(
                    model_config=app_orchestration_config.model_config,
                    message=message
                )
            )
        else:
            # As normal LLM run, agent result as context
            context = result

            # reorganize all inputs and template to prompt messages
            # Include: prompt template, inputs, query(optional), files(optional)
            #          memory(optional), external data, dataset context(optional)
            prompt_messages, stop = self.originze_prompt_messages(
                app_record=app_record,
                model_config=app_orchestration_config.model_config,
                prompt_template_entity=app_orchestration_config.prompt_template,
                inputs=inputs,
                files=files,
                query=query,
                context=context,
                memory=memory
            )

            # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
            self.recale_llm_max_tokens(
                model_config=app_orchestration_config.model_config,
                prompt_messages=prompt_messages
            )

            # Invoke model
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            invoke_result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                stop=stop,
                stream=application_generate_entity.stream,
                user=application_generate_entity.user_id,
            )

            # handle invoke result
            self._handle_invoke_result(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                stream=application_generate_entity.stream
            )

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
        :param model_config: model config
        :param message: message
        :return:
        """
        agent_thoughts = (db.session.query(MessageAgentThought)
                          .filter(MessageAgentThought.message_id == message.id).all())

        all_message_tokens = 0
        all_answer_tokens = 0
        for agent_thought in agent_thoughts:
            all_message_tokens += agent_thought.message_tokens
            all_answer_tokens += agent_thought.answer_tokens

        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        return model_type_instance._calc_response_usage(
            model_config.model,
            model_config.credentials,
            all_message_tokens,
            all_answer_tokens
        )
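_get_usage_of_all_agent_thoughts above rolls per-thought token counts into a single usage record for the whole agent run. A standalone illustration of the aggregation with hypothetical numbers:

thoughts = [
    {"message_tokens": 412, "answer_tokens": 96},  # hypothetical agent loop 1
    {"message_tokens": 530, "answer_tokens": 71},  # hypothetical agent loop 2
]
all_message_tokens = sum(t["message_tokens"] for t in thoughts)  # 942
all_answer_tokens = sum(t["answer_tokens"] for t in thoughts)    # 167
# these totals are then passed to the model's _calc_response_usage()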
api/core/app_runner/app_runner.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import time
from typing import cast, Optional, List, Tuple, Generator, Union

from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from models.model import App


class AppRunner:
    def get_pre_calculate_rest_tokens(self, app_record: App,
                                      model_config: ModelConfigEntity,
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict[str, str],
                                      files: list[FileObj],
                                      query: Optional[str] = None) -> int:
        """
        Get pre calculate rest tokens
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :return:
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt messages without memory and context
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=model_config,
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
            files=files,
            query=query
        )

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                        "or shrink the max token, or switch to a llm with a larger token limit size.")

        return rest_tokens

    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
                              prompt_messages: List[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        if prompt_tokens + max_tokens > model_context_tokens:
            max_tokens = max(model_context_tokens - prompt_tokens, 16)

            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    model_config.parameters[parameter_rule.name] = max_tokens

    def originze_prompt_messages(self, app_record: App,
                                 model_config: ModelConfigEntity,
                                 prompt_template_entity: PromptTemplateEntity,
                                 inputs: dict[str, str],
                                 files: list[FileObj],
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 memory: Optional[TokenBufferMemory] = None) \
            -> Tuple[List[PromptMessage], Optional[List[str]]]:
        """
        Organize prompt messages
        :param context:
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :param memory: memory
        :return:
        """
        prompt_transform = PromptTransform()

        # get prompt without memory and context
        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            prompt_messages, stop = prompt_transform.get_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
            stop = model_config.stop

        return prompt_messages, stop

    def direct_output(self, queue_manager: ApplicationQueueManager,
                      app_orchestration_config: AppOrchestrationConfigEntity,
                      prompt_messages: list,
                      text: str,
                      stream: bool,
                      usage: Optional[LLMUsage] = None) -> None:
        """
        Direct output
        :param queue_manager: application queue manager
        :param app_orchestration_config: app orchestration config
        :param prompt_messages: prompt messages
        :param text: text
        :param stream: stream
        :param usage: usage
        :return:
        """
        if stream:
            index = 0
            for token in text:
                queue_manager.publish_chunk_message(LLMResultChunk(
                    model=app_orchestration_config.model_config.model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=token)
                    )
                ))
                index += 1
                time.sleep(0.01)

        queue_manager.publish_message_end(
            llm_result=LLMResult(
                model=app_orchestration_config.model_config.model,
                prompt_messages=prompt_messages,
                message=AssistantPromptMessage(content=text),
                usage=usage if usage else LLMUsage.empty_usage()
            )
        )

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: ApplicationQueueManager,
                              stream: bool) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param stream: stream
        :return:
        """
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result direct
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        queue_manager.publish_message_end(
            llm_result=invoke_result
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        model = None
        prompt_messages = []
        text = ''
        usage = None
        for result in invoke_result:
            queue_manager.publish_chunk_message(result)

            text += result.delta.message.content

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages

            if not usage and result.delta.usage:
                usage = result.delta.usage

        llm_result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content=text),
            usage=usage
        )

        queue_manager.publish_message_end(
            llm_result=llm_result
        )
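recale_llm_max_tokens above shrinks max_tokens whenever the prompt plus the requested completion would overflow the model's context window, with a floor of 16 tokens. A worked example with made-up numbers:

model_context_tokens = 4096
prompt_tokens = 3900
max_tokens = 1500
if prompt_tokens + max_tokens > model_context_tokens:
    max_tokens = max(model_context_tokens - prompt_tokens, 16)
print(max_tokens)  # 196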
api/core/app_runner/basic_app_runner.py (new file, 363 lines; listing truncated below)
@@ -0,0 +1,363 @@
import logging
from typing import Tuple, Optional

from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
    AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.moderation.base import ModerationException
from core.prompt.prompt_transform import AppMode
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageAnnotation

logger = logging.getLogger(__name__)


class BasicApplicationRunner(AppRunner):
    """
    Basic Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from
            )

            if annotation_reply:
                queue_manager.publish_annotation_reply(
                    message_annotation_id=annotation_reply.id
                )
                self.direct_output(
                    queue_manager=queue_manager,
                    app_orchestration_config=app_orchestration_config,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # get context from datasets
        context = None
        if app_orchestration_config.dataset:
            context = self.retrieve_dataset_context(
                tenant_id=app_record.tenant_id,
                app_record=app_record,
                queue_manager=queue_manager,
                model_config=app_orchestration_config.model_config,
                show_retrieve_source=app_orchestration_config.show_retrieve_source,
                dataset_config=app_orchestration_config.dataset,
                message=message,
                inputs=inputs,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                memory=memory
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional), external data, dataset context(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            memory=memory
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
        self.recale_llm_max_tokens(
            model_config=app_orchestration_config.model_config,
            prompt_messages=prompt_messages
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_orchestration_config.model_config.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
                              inputs: dict,
                              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = ModerationFeature()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config_entity,
            inputs=inputs,
            query=query,
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetchFeature()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def retrieve_dataset_context(self, tenant_id: str,
                                 app_record: App,
                                 queue_manager: ApplicationQueueManager,
                                 model_config: ModelConfigEntity,
                                 dataset_config: DatasetEntity,
                                 show_retrieve_source: bool,
                                 message: Message,
                                 inputs: dict,
                                 query: str,
                                 user_id: str,
                                 invoke_from: InvokeFrom,
                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
        """
        Retrieve dataset context
|
:param tenant_id: tenant id
|
||||||
|
:param app_record: app record
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param model_config: model config
|
||||||
|
:param dataset_config: dataset config
|
||||||
|
:param show_retrieve_source: show retrieve source
|
||||||
|
:param message: message
|
||||||
|
:param inputs: inputs
|
||||||
|
:param query: query
|
||||||
|
:param user_id: user id
|
||||||
|
:param invoke_from: invoke from
|
||||||
|
:param memory: memory
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
hit_callback = DatasetIndexToolCallbackHandler(
|
||||||
|
queue_manager,
|
||||||
|
app_record.id,
|
||||||
|
message.id,
|
||||||
|
user_id,
|
||||||
|
invoke_from
|
||||||
|
)
|
||||||
|
|
||||||
|
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
|
||||||
|
and dataset_config.retrieve_config.query_variable):
|
||||||
|
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||||
|
|
||||||
|
dataset_retrieval = DatasetRetrievalFeature()
|
||||||
|
return dataset_retrieval.retrieve(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
model_config=model_config,
|
||||||
|
config=dataset_config,
|
||||||
|
query=query,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
show_retrieve_source=show_retrieve_source,
|
||||||
|
hit_callback=hit_callback,
|
||||||
|
memory=memory
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
|
queue_manager: ApplicationQueueManager,
|
||||||
|
prompt_messages: list[PromptMessage]) -> bool:
|
||||||
|
"""
|
||||||
|
Check hosting moderation
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
hosting_moderation_feature = HostingModerationFeature()
|
||||||
|
moderation_result = hosting_moderation_feature.check(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
prompt_messages=prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
if moderation_result:
|
||||||
|
self.direct_output(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
text="I apologize for any confusion, " \
|
||||||
|
"but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||||
|
stream=application_generate_entity.stream
|
||||||
|
)
|
||||||
|
|
||||||
|
return moderation_result
|
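# --- Illustrative sketch, not part of this commit ----------------------------
# The run() flow above short-circuits twice before the model is invoked: a
# stored annotation matching the query is streamed back directly, and a
# hosting-moderation hit streams a canned refusal instead. A minimal,
# hypothetical reduction of that control flow (stand-in callables only):
#
#     def run_flow(annotation_lookup, moderation_flagged, invoke_model):
#         cached = annotation_lookup()      # stored answer for this query?
#         if cached is not None:
#             return cached                 # direct_output(...), skip the LLM
#         if moderation_flagged():
#             return "canned refusal"       # direct_output(...), skip the LLM
#         return invoke_model()             # normal path: invoke the LLM
# ------------------------------------------------------------------------------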
api/core/app_runner/generate_task_pipeline.py (new file, 483 lines)
@@ -0,0 +1,483 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Union, Generator, cast, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
|
||||||
|
from core.entities.application_entities import ApplicationGenerateEntity
|
||||||
|
from core.application_queue_manager import ApplicationQueueManager
|
||||||
|
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
|
||||||
|
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
|
||||||
|
AnnotationReplyEvent
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
|
||||||
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
|
||||||
|
TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
|
||||||
|
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.prompt.prompt_template import PromptTemplateParser
|
||||||
|
from events.message_event import message_was_created
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import Message, Conversation, MessageAgentThought
|
||||||
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskState(BaseModel):
|
||||||
|
"""
|
||||||
|
TaskState entity
|
||||||
|
"""
|
||||||
|
llm_result: LLMResult
|
||||||
|
metadata: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateTaskPipeline:
|
||||||
|
"""
|
||||||
|
GenerateTaskPipeline generates the streamed output and manages task state for an application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
|
queue_manager: ApplicationQueueManager,
|
||||||
|
conversation: Conversation,
|
||||||
|
message: Message) -> None:
|
||||||
|
"""
|
||||||
|
Initialize GenerateTaskPipeline.
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param conversation: conversation
|
||||||
|
:param message: message
|
||||||
|
"""
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._queue_manager = queue_manager
|
||||||
|
self._conversation = conversation
|
||||||
|
self._message = message
|
||||||
|
self._task_state = TaskState(
|
||||||
|
llm_result=LLMResult(
|
||||||
|
model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
|
||||||
|
prompt_messages=[],
|
||||||
|
message=AssistantPromptMessage(content=""),
|
||||||
|
usage=LLMUsage.empty_usage()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._start_at = time.perf_counter()
|
||||||
|
self._output_moderation_handler = self._init_output_moderation()
|
||||||
|
|
||||||
|
def process(self, stream: bool) -> Union[dict, Generator]:
|
||||||
|
"""
|
||||||
|
Process generate task pipeline.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if stream:
|
||||||
|
return self._process_stream_response()
|
||||||
|
else:
|
||||||
|
return self._process_blocking_response()
|
||||||
|
|
||||||
|
def _process_blocking_response(self) -> dict:
|
||||||
|
"""
|
||||||
|
Process blocking response.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for queue_message in self._queue_manager.listen():
|
||||||
|
event = queue_message.event
|
||||||
|
|
||||||
|
if isinstance(event, QueueErrorEvent):
|
||||||
|
raise self._handle_error(event)
|
||||||
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
|
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||||
|
elif isinstance(event, AnnotationReplyEvent):
|
||||||
|
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||||
|
if annotation:
|
||||||
|
account = annotation.account
|
||||||
|
self._task_state.metadata['annotation_reply'] = {
|
||||||
|
'id': annotation.id,
|
||||||
|
'account': {
|
||||||
|
'id': annotation.account_id,
|
||||||
|
'name': account.name if account else 'Dify user'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self._task_state.llm_result.message.content = annotation.content
|
||||||
|
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||||
|
if isinstance(event, QueueMessageEndEvent):
|
||||||
|
self._task_state.llm_result = event.llm_result
|
||||||
|
else:
|
||||||
|
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
|
||||||
|
model = model_config.model
|
||||||
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
|
||||||
|
# calculate num tokens
|
||||||
|
prompt_tokens = 0
|
||||||
|
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
|
||||||
|
prompt_tokens = model_type_instance.get_num_tokens(
|
||||||
|
model,
|
||||||
|
model_config.credentials,
|
||||||
|
self._task_state.llm_result.prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_tokens = 0
|
||||||
|
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
|
||||||
|
completion_tokens = model_type_instance.get_num_tokens(
|
||||||
|
model,
|
||||||
|
model_config.credentials,
|
||||||
|
[self._task_state.llm_result.message]
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials = model_config.credentials
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
||||||
|
model,
|
||||||
|
credentials,
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# response moderation
|
||||||
|
if self._output_moderation_handler:
|
||||||
|
self._output_moderation_handler.stop_thread()
|
||||||
|
|
||||||
|
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
|
||||||
|
completion=self._task_state.llm_result.message.content,
|
||||||
|
public_event=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save message
|
||||||
|
self._save_message(event.llm_result)
|
||||||
|
|
||||||
|
response = {
|
||||||
|
'event': 'message',
|
||||||
|
'task_id': self._application_generate_entity.task_id,
|
||||||
|
'id': self._message.id,
|
||||||
|
'mode': self._conversation.mode,
|
||||||
|
'answer': event.llm_result.message.content,
|
||||||
|
'metadata': {},
|
||||||
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._conversation.mode == 'chat':
|
||||||
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
|
if self._task_state.metadata:
|
||||||
|
response['metadata'] = self._task_state.metadata
|
||||||
|
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
def _process_stream_response(self) -> Generator:
|
||||||
|
"""
|
||||||
|
Process stream response.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for message in self._queue_manager.listen():
|
||||||
|
event = message.event
|
||||||
|
|
||||||
|
if isinstance(event, QueueErrorEvent):
|
||||||
|
raise self._handle_error(event)
|
||||||
|
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||||
|
if isinstance(event, QueueMessageEndEvent):
|
||||||
|
self._task_state.llm_result = event.llm_result
|
||||||
|
else:
|
||||||
|
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
|
||||||
|
model = model_config.model
|
||||||
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
|
||||||
|
# calculate num tokens
|
||||||
|
prompt_tokens = 0
|
||||||
|
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
|
||||||
|
prompt_tokens = model_type_instance.get_num_tokens(
|
||||||
|
model,
|
||||||
|
model_config.credentials,
|
||||||
|
self._task_state.llm_result.prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_tokens = 0
|
||||||
|
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
|
||||||
|
completion_tokens = model_type_instance.get_num_tokens(
|
||||||
|
model,
|
||||||
|
model_config.credentials,
|
||||||
|
[self._task_state.llm_result.message]
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials = model_config.credentials
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
||||||
|
model,
|
||||||
|
credentials,
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# response moderation
|
||||||
|
if self._output_moderation_handler:
|
||||||
|
self._output_moderation_handler.stop_thread()
|
||||||
|
|
||||||
|
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
|
||||||
|
completion=self._task_state.llm_result.message.content,
|
||||||
|
public_event=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self._output_moderation_handler = None
|
||||||
|
|
||||||
|
replace_response = {
|
||||||
|
'event': 'message_replace',
|
||||||
|
'task_id': self._application_generate_entity.task_id,
|
||||||
|
'message_id': self._message.id,
|
||||||
|
'answer': self._task_state.llm_result.message.content,
|
||||||
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._conversation.mode == 'chat':
|
||||||
|
replace_response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
|
yield self._yield_response(replace_response)
|
||||||
|
|
||||||
|
# Save message
|
||||||
|
self._save_message(self._task_state.llm_result)
|
||||||
|
|
||||||
|
response = {
|
||||||
|
'event': 'message_end',
|
||||||
|
'task_id': self._application_generate_entity.task_id,
|
||||||
|
'id': self._message.id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._conversation.mode == 'chat':
|
||||||
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
|
if self._task_state.metadata:
|
||||||
|
response['metadata'] = self._task_state.metadata
|
||||||
|
|
||||||
|
yield self._yield_response(response)
|
||||||
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
|
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||||
|
elif isinstance(event, AnnotationReplyEvent):
|
||||||
|
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||||
|
if annotation:
|
||||||
|
account = annotation.account
|
||||||
|
self._task_state.metadata['annotation_reply'] = {
|
||||||
|
'id': annotation.id,
|
||||||
|
'account': {
|
||||||
|
'id': annotation.account_id,
|
||||||
|
'name': account.name if account else 'Dify user'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self._task_state.llm_result.message.content = annotation.content
|
||||||
|
elif isinstance(event, QueueAgentThoughtEvent):
|
||||||
|
agent_thought = (
|
||||||
|
db.session.query(MessageAgentThought)
|
||||||
|
.filter(MessageAgentThought.id == event.agent_thought_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_thought:
|
||||||
|
response = {
|
||||||
|
'event': 'agent_thought',
|
||||||
|
'id': agent_thought.id,
|
||||||
|
'task_id': self._application_generate_entity.task_id,
|
||||||
|
'message_id': self._message.id,
|
||||||
|
'position': agent_thought.position,
|
||||||
|
'thought': agent_thought.thought,
|
||||||
|
'tool': agent_thought.tool,
|
||||||
|
'tool_input': agent_thought.tool_input,
|
||||||
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._conversation.mode == 'chat':
|
||||||
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
|
yield self._yield_response(response)
|
||||||
|
elif isinstance(event, QueueMessageEvent):
|
||||||
|
chunk = event.chunk
|
||||||
|
delta_text = chunk.delta.message.content
|
||||||
|
if delta_text is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not self._task_state.llm_result.prompt_messages:
|
||||||
|
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||||
|
|
||||||
|
if self._output_moderation_handler:
|
||||||
|
if self._output_moderation_handler.should_direct_output():
|
||||||
|
# stop subscribe new token when output moderation should direct output
|
||||||
|
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
||||||
|
self._queue_manager.publish_chunk_message(LLMResultChunk(
|
||||||
|
model=self._task_state.llm_result.model,
|
||||||
|
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
|
||||||
|
)
|
||||||
|
))
|
||||||
|
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
self._output_moderation_handler.append_new_token(delta_text)
|
||||||
|
|
||||||
|
self._task_state.llm_result.message.content += delta_text
|
||||||
|
response = self._handle_chunk(delta_text)
|
||||||
|
yield self._yield_response(response)
|
||||||
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
|
response = {
|
||||||
|
'event': 'message_replace',
|
||||||
|
'task_id': self._application_generate_entity.task_id,
|
||||||
|
'message_id': self._message.id,
|
||||||
|
'answer': event.text,
|
||||||
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._conversation.mode == 'chat':
|
||||||
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
|
yield self._yield_response(response)
|
||||||
|
elif isinstance(event, QueuePingEvent):
|
||||||
|
yield "event: ping\n\n"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
def _save_message(self, llm_result: LLMResult) -> None:
|
||||||
|
"""
|
||||||
|
Save message.
|
||||||
|
:param llm_result: llm result
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
usage = llm_result.usage
|
||||||
|
|
||||||
|
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||||
|
|
||||||
|
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
|
||||||
|
self._message.message_tokens = usage.prompt_tokens
|
||||||
|
self._message.message_unit_price = usage.prompt_unit_price
|
||||||
|
self._message.message_price_unit = usage.prompt_price_unit
|
||||||
|
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
|
||||||
|
if llm_result.message.content else ''
|
||||||
|
self._message.answer_tokens = usage.completion_tokens
|
||||||
|
self._message.answer_unit_price = usage.completion_unit_price
|
||||||
|
self._message.answer_price_unit = usage.completion_price_unit
|
||||||
|
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||||
|
self._message.total_price = usage.total_price
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
message_was_created.send(
|
||||||
|
self._message,
|
||||||
|
application_generate_entity=self._application_generate_entity,
|
||||||
|
conversation=self._conversation,
|
||||||
|
is_first_message=self._application_generate_entity.conversation_id is None,
|
||||||
|
extras=self._application_generate_entity.extras
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_chunk(self, text: str) -> dict:
|
||||||
|
"""
|
||||||
|
Build the response payload for a message chunk.
|
||||||
|
:param text: text
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
response = {
|
||||||
|
'event': 'message',
|
||||||
|
'id': self._message.id,
|
||||||
|
'task_id': self._application_generate_entity.task_id,
|
||||||
|
'message_id': self._message.id,
|
||||||
|
'answer': text,
|
||||||
|
'created_at': int(self._message.created_at.timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._conversation.mode == 'chat':
|
||||||
|
response['conversation_id'] = self._conversation.id
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_error(self, event: QueueErrorEvent) -> Exception:
|
||||||
|
"""
|
||||||
|
Handle error event.
|
||||||
|
:param event: event
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
logger.debug("error: %s", event.error)
|
||||||
|
e = event.error
|
||||||
|
|
||||||
|
if isinstance(e, InvokeAuthorizationError):
|
||||||
|
return InvokeAuthorizationError('Incorrect API key provided')
|
||||||
|
elif isinstance(e, (InvokeError, ValueError)):
|
||||||
|
return e
|
||||||
|
else:
|
||||||
|
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||||
|
|
||||||
|
def _yield_response(self, response: dict) -> str:
|
||||||
|
"""
|
||||||
|
Yield response.
|
||||||
|
:param response: response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return "data: " + json.dumps(response) + "\n\n"
|
||||||
|
|
||||||
|
def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Prompt messages to prompt for saving.
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
prompts = []
|
||||||
|
if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
|
||||||
|
for prompt_message in prompt_messages:
|
||||||
|
if prompt_message.role == PromptMessageRole.USER:
|
||||||
|
role = 'user'
|
||||||
|
elif prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||||
|
role = 'assistant'
|
||||||
|
elif prompt_message.role == PromptMessageRole.SYSTEM:
|
||||||
|
role = 'system'
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
text = ''
|
||||||
|
files = []
|
||||||
|
if isinstance(prompt_message.content, list):
|
||||||
|
for content in prompt_message.content:
|
||||||
|
if content.type == PromptMessageContentType.TEXT:
|
||||||
|
content = cast(TextPromptMessageContent, content)
|
||||||
|
text += content.data
|
||||||
|
else:
|
||||||
|
content = cast(ImagePromptMessageContent, content)
|
||||||
|
files.append({
|
||||||
|
"type": 'image',
|
||||||
|
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
|
||||||
|
"detail": content.detail.value
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
text = prompt_message.content
|
||||||
|
|
||||||
|
prompts.append({
|
||||||
|
"role": role,
|
||||||
|
"text": text,
|
||||||
|
"files": files
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
prompts.append({
|
||||||
|
"role": 'user',
|
||||||
|
"text": prompt_messages[0].content
|
||||||
|
})
|
||||||
|
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
|
||||||
|
"""
|
||||||
|
Init output moderation.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
|
||||||
|
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
|
||||||
|
|
||||||
|
if sensitive_word_avoidance:
|
||||||
|
return OutputModerationHandler(
|
||||||
|
tenant_id=self._application_generate_entity.tenant_id,
|
||||||
|
app_id=self._application_generate_entity.app_id,
|
||||||
|
rule=ModerationRule(
|
||||||
|
type=sensitive_word_avoidance.type,
|
||||||
|
config=sensitive_word_avoidance.config
|
||||||
|
),
|
||||||
|
on_message_replace_func=self._queue_manager.publish_message_replace
|
||||||
|
)
|
api/core/app_runner/moderation_handler.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import logging
import threading
import time
from typing import Any, Optional, Dict

from flask import current_app, Flask
from pydantic import BaseModel

from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory

logger = logging.getLogger(__name__)


class ModerationRule(BaseModel):
    type: str
    config: Dict[str, Any]


class OutputModerationHandler(BaseModel):
    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    on_message_replace_func: Any

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ''
    is_final_chunk: bool = False
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        return self.final_output is not None

    def get_final_output(self):
        return self.final_output

    def append_new_token(self, token: str):
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.on_message_replace_func(final_output)

        return final_output

    def start_thread(self) -> threading.Thread:
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })

        thread.start()

        return thread

    def stop_thread(self):
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    final_output = result.text + self.buffer[len(moderation_buffer):]

                # trigger replace event
                if self.thread_running:
                    self.on_message_replace_func(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            logger.error("Moderation Output error: %s", e)

        return None
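# --- Illustrative sketch, not part of this commit ----------------------------
# How worker() above throttles output moderation: it polls the shared buffer
# about once per second and only calls moderation once at least `buffer_size`
# new characters have accumulated, or once the final chunk has been flagged.
# A hypothetical, self-contained reduction of that gate:
#
#     def should_moderate(buffer: str, already_checked: int,
#                         buffer_size: int = 300, is_final: bool = False) -> bool:
#         return is_final or (len(buffer) - already_checked) >= buffer_size
# ------------------------------------------------------------------------------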
api/core/application_manager.py (new file, 655 lines)
@@ -0,0 +1,655 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
from typing import cast, Optional, Any, Union, Generator, Tuple
|
||||||
|
|
||||||
|
from flask import Flask, current_app
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from core.app_runner.agent_app_runner import AgentApplicationRunner
|
||||||
|
from core.app_runner.basic_app_runner import BasicApplicationRunner
|
||||||
|
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
|
||||||
|
from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
|
||||||
|
ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
|
||||||
|
AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
|
||||||
|
AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
|
||||||
|
from core.entities.model_entities import ModelStatus
|
||||||
|
from core.file.file_obj import FileObj
|
||||||
|
from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.prompt.prompt_template import PromptTemplateParser
|
||||||
|
from core.provider_manager import ProviderManager
|
||||||
|
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.account import Account
|
||||||
|
from models.model import EndUser, Conversation, Message, MessageFile, App
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationManager:
|
||||||
|
"""
|
||||||
|
This class is responsible for managing application generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def generate(self, tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
app_model_config_id: str,
|
||||||
|
app_model_config_dict: dict,
|
||||||
|
app_model_config_override: bool,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
inputs: dict[str, str],
|
||||||
|
query: Optional[str] = None,
|
||||||
|
files: Optional[list[FileObj]] = None,
|
||||||
|
conversation: Optional[Conversation] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
extras: Optional[dict[str, Any]] = None) \
|
||||||
|
-> Union[dict, Generator]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param tenant_id: workspace ID
|
||||||
|
:param app_id: app ID
|
||||||
|
:param app_model_config_id: app model config id
|
||||||
|
:param app_model_config_dict: app model config dict
|
||||||
|
:param app_model_config_override: app model config override
|
||||||
|
:param user: account or end user
|
||||||
|
:param invoke_from: invoke from source
|
||||||
|
:param inputs: inputs
|
||||||
|
:param query: query
|
||||||
|
:param files: file obj list
|
||||||
|
:param conversation: conversation
|
||||||
|
:param stream: is stream
|
||||||
|
:param extras: extras
|
||||||
|
"""
|
||||||
|
# init task id
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# init application generate entity
|
||||||
|
application_generate_entity = ApplicationGenerateEntity(
|
||||||
|
task_id=task_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_id=app_id,
|
||||||
|
app_model_config_id=app_model_config_id,
|
||||||
|
app_model_config_dict=app_model_config_dict,
|
||||||
|
app_orchestration_config_entity=self._convert_from_app_model_config_dict(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_model_config_dict=app_model_config_dict
|
||||||
|
),
|
||||||
|
app_model_config_override=app_model_config_override,
|
||||||
|
conversation_id=conversation.id if conversation else None,
|
||||||
|
inputs=conversation.inputs if conversation else inputs,
|
||||||
|
query=query.replace('\x00', '') if query else None,
|
||||||
|
files=files if files else [],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=stream,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
extras=extras
|
||||||
|
)
|
||||||
|
|
||||||
|
# init generate records
|
||||||
|
(
|
||||||
|
conversation,
|
||||||
|
message
|
||||||
|
) = self._init_generate_records(application_generate_entity)
|
||||||
|
|
||||||
|
# init queue manager
|
||||||
|
queue_manager = ApplicationQueueManager(
|
||||||
|
task_id=application_generate_entity.task_id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
app_mode=conversation.mode,
|
||||||
|
message_id=message.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# new thread
|
||||||
|
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||||
|
'flask_app': current_app._get_current_object(),
|
||||||
|
'application_generate_entity': application_generate_entity,
|
||||||
|
'queue_manager': queue_manager,
|
||||||
|
'conversation_id': conversation.id,
|
||||||
|
'message_id': message.id,
|
||||||
|
})
|
||||||
|
|
||||||
|
worker_thread.start()
|
||||||
|
|
||||||
|
# return response or stream generator
|
||||||
|
return self._handle_response(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message,
|
||||||
|
stream=stream
|
||||||
|
)
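# Illustrative note, not part of this commit: generate() is the producer side
# of a producer/consumer pair. It starts _generate_worker() in a background
# thread, which runs the app runner and publishes events through the
# ApplicationQueueManager, while the calling thread consumes those events via
# GenerateTaskPipeline (a blocking dict or a streamed generator, depending on
# `stream`).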
|
||||||
|
|
||||||
|
def _generate_worker(self, flask_app: Flask,
|
||||||
|
application_generate_entity: ApplicationGenerateEntity,
|
||||||
|
queue_manager: ApplicationQueueManager,
|
||||||
|
conversation_id: str,
|
||||||
|
message_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Generate worker in a new thread.
|
||||||
|
:param flask_app: Flask app
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param conversation_id: conversation ID
|
||||||
|
:param message_id: message ID
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
with flask_app.app_context():
|
||||||
|
try:
|
||||||
|
# get conversation and message
|
||||||
|
conversation = self._get_conversation(conversation_id)
|
||||||
|
message = self._get_message(message_id)
|
||||||
|
|
||||||
|
if application_generate_entity.app_orchestration_config_entity.agent:
|
||||||
|
# agent app
|
||||||
|
runner = AgentApplicationRunner()
|
||||||
|
runner.run(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# basic app
|
||||||
|
runner = BasicApplicationRunner()
|
||||||
|
runner.run(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
except ConversationTaskStoppedException:
|
||||||
|
pass
|
||||||
|
except InvokeAuthorizationError:
|
||||||
|
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.exception("Validation Error when generating")
|
||||||
|
queue_manager.publish_error(e)
|
||||||
|
except (ValueError, InvokeError) as e:
|
||||||
|
queue_manager.publish_error(e)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Unknown Error when generating")
|
||||||
|
queue_manager.publish_error(e)
|
||||||
|
finally:
|
||||||
|
db.session.remove()
|
||||||
|
|
||||||
|
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
|
||||||
|
queue_manager: ApplicationQueueManager,
|
||||||
|
conversation: Conversation,
|
||||||
|
message: Message,
|
||||||
|
stream: bool = False) -> Union[dict, Generator]:
|
||||||
|
"""
|
||||||
|
Handle response.
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param conversation: conversation
|
||||||
|
:param message: message
|
||||||
|
:param stream: is stream
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# init generate task pipeline
|
||||||
|
generate_task_pipeline = GenerateTaskPipeline(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return generate_task_pipeline.process(stream=stream)
|
||||||
|
except ValueError as e:
|
||||||
|
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
|
raise ConversationTaskStoppedException()
|
||||||
|
else:
|
||||||
|
logger.exception(e)
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
db.session.remove()
|
||||||
|
|
||||||
|
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
|
||||||
|
-> AppOrchestrationConfigEntity:
|
||||||
|
"""
|
||||||
|
Convert app model config dict to entity.
|
||||||
|
:param tenant_id: tenant ID
|
||||||
|
:param app_model_config_dict: app model config dict
|
||||||
|
:raises ProviderTokenNotInitError: provider token not init error
|
||||||
|
:return: app orchestration config entity
|
||||||
|
"""
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
copy_app_model_config_dict = app_model_config_dict.copy()
|
||||||
|
|
||||||
|
provider_manager = ProviderManager()
|
||||||
|
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=copy_app_model_config_dict['model']['provider'],
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_name = provider_model_bundle.configuration.provider.provider
|
||||||
|
model_name = copy_app_model_config_dict['model']['name']
|
||||||
|
|
||||||
|
model_type_instance = provider_model_bundle.model_type_instance
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
|
||||||
|
# check model credentials
|
||||||
|
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
model=copy_app_model_config_dict['model']['name']
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_credentials is None:
|
||||||
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
|
||||||
|
|
||||||
|
# check model
|
||||||
|
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||||
|
model=copy_app_model_config_dict['model']['name'],
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider_model is None:
|
||||||
|
model_name = copy_app_model_config_dict['model']['name']
|
||||||
|
raise ValueError(f"Model {model_name} does not exist.")
|
||||||
|
|
||||||
|
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||||
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
|
||||||
|
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||||
|
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
|
||||||
|
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||||
|
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||||
|
|
||||||
|
# model config
|
||||||
|
completion_params = copy_app_model_config_dict['model'].get('completion_params')
|
||||||
|
stop = []
|
||||||
|
if 'stop' in completion_params:
|
||||||
|
stop = completion_params['stop']
|
||||||
|
del completion_params['stop']
|
||||||
|
|
||||||
|
# get model mode
|
||||||
|
model_mode = copy_app_model_config_dict['model'].get('mode')
|
||||||
|
if not model_mode:
|
||||||
|
mode_enum = model_type_instance.get_model_mode(
|
||||||
|
model=copy_app_model_config_dict['model']['name'],
|
||||||
|
credentials=model_credentials
|
||||||
|
)
|
||||||
|
|
||||||
|
model_mode = mode_enum.value
|
||||||
|
|
||||||
|
model_schema = model_type_instance.get_model_schema(
|
||||||
|
copy_app_model_config_dict['model']['name'],
|
||||||
|
model_credentials
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_schema:
|
||||||
|
raise ValueError(f"Model {model_name} does not exist.")
|
||||||
|
|
||||||
|
properties['model_config'] = ModelConfigEntity(
|
||||||
|
provider=copy_app_model_config_dict['model']['provider'],
|
||||||
|
model=copy_app_model_config_dict['model']['name'],
|
||||||
|
model_schema=model_schema,
|
||||||
|
mode=model_mode,
|
||||||
|
provider_model_bundle=provider_model_bundle,
|
||||||
|
credentials=model_credentials,
|
||||||
|
parameters=completion_params,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# prompt template
|
||||||
|
prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
|
||||||
|
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||||
|
simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
|
||||||
|
properties['prompt_template'] = PromptTemplateEntity(
|
||||||
|
prompt_type=prompt_type,
|
||||||
|
simple_prompt_template=simple_prompt_template
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
advanced_chat_prompt_template = None
|
||||||
|
chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
|
||||||
|
if chat_prompt_config:
|
||||||
|
chat_prompt_messages = []
|
||||||
|
for message in chat_prompt_config.get("prompt", []):
|
||||||
|
chat_prompt_messages.append({
|
||||||
|
"text": message["text"],
|
||||||
|
"role": PromptMessageRole.value_of(message["role"])
|
||||||
|
})
|
||||||
|
|
||||||
|
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
|
||||||
|
messages=chat_prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
advanced_completion_prompt_template = None
|
||||||
|
completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
|
||||||
|
if completion_prompt_config:
|
||||||
|
completion_prompt_template_params = {
|
||||||
|
'prompt': completion_prompt_config['prompt']['text'],
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'conversation_histories_role' in completion_prompt_config:
|
||||||
|
completion_prompt_template_params['role_prefix'] = {
|
||||||
|
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
|
||||||
|
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
|
||||||
|
}
|
||||||
|
|
||||||
|
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||||
|
**completion_prompt_template_params
|
||||||
|
)
|
||||||
|
|
||||||
|
properties['prompt_template'] = PromptTemplateEntity(
|
||||||
|
prompt_type=prompt_type,
|
||||||
|
advanced_chat_prompt_template=advanced_chat_prompt_template,
|
||||||
|
advanced_completion_prompt_template=advanced_completion_prompt_template
|
||||||
|
)
|
||||||
|
|
||||||
|
# external data variables
|
||||||
|
properties['external_data_variables'] = []
|
||||||
|
external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
|
||||||
|
for external_data_tool in external_data_tools:
|
||||||
|
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
properties['external_data_variables'].append(
|
||||||
|
ExternalDataVariableEntity(
|
||||||
|
variable=external_data_tool['variable'],
|
||||||
|
type=external_data_tool['type'],
|
||||||
|
config=external_data_tool['config']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# show retrieve source
|
||||||
|
show_retrieve_source = False
|
||||||
|
retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
|
||||||
|
if retriever_resource_dict:
|
||||||
|
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
|
||||||
|
show_retrieve_source = True
|
||||||
|
|
||||||
|
properties['show_retrieve_source'] = show_retrieve_source
|
||||||
|
|
||||||
|
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
|
||||||
|
and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
|
||||||
|
'enabled']:
|
||||||
|
agent_dict = copy_app_model_config_dict.get('agent_mode')
|
||||||
|
if agent_dict['strategy'] in ['router', 'react_router']:
|
||||||
|
dataset_ids = []
|
||||||
|
for tool in agent_dict.get('tools', []):
|
||||||
|
key = list(tool.keys())[0]
|
||||||
|
|
||||||
|
if key != 'dataset':
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_item = tool[key]
|
||||||
|
|
||||||
|
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dataset_id = tool_item['id']
|
||||||
|
dataset_ids.append(dataset_id)
|
||||||
|
|
||||||
|
dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
|
||||||
|
query_variable = copy_app_model_config_dict.get('dataset_query_variable')
|
||||||
|
if dataset_configs['retrieval_model'] == 'single':
|
||||||
|
properties['dataset'] = DatasetEntity(
|
||||||
|
dataset_ids=dataset_ids,
|
||||||
|
retrieve_config=DatasetRetrieveConfigEntity(
|
||||||
|
query_variable=query_variable,
|
||||||
|
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||||
|
dataset_configs['retrieval_model']
|
||||||
|
),
|
||||||
|
single_strategy=agent_dict['strategy']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
properties['dataset'] = DatasetEntity(
|
||||||
|
dataset_ids=dataset_ids,
|
||||||
|
retrieve_config=DatasetRetrieveConfigEntity(
|
||||||
|
query_variable=query_variable,
|
||||||
|
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||||
|
dataset_configs['retrieval_model']
|
||||||
|
),
|
||||||
|
top_k=dataset_configs.get('top_k'),
|
||||||
|
score_threshold=dataset_configs.get('score_threshold'),
|
||||||
|
reranking_model=dataset_configs.get('reranking_model')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if agent_dict['strategy'] == 'react':
|
||||||
|
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||||
|
else:
|
||||||
|
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||||
|
|
||||||
|
agent_tools = []
|
||||||
|
for tool in agent_dict.get('tools', []):
|
||||||
|
key = list(tool.keys())[0]
|
||||||
|
tool_item = tool[key]
|
||||||
|
|
||||||
|
agent_tool_properties = {
|
||||||
|
"tool_id": key
|
||||||
|
}
|
||||||
|
|
||||||
|
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
agent_tool_properties["config"] = tool_item
|
||||||
|
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||||
|
|
||||||
|
properties['agent'] = AgentEntity(
|
||||||
|
provider=properties['model_config'].provider,
|
||||||
|
model=properties['model_config'].model,
|
||||||
|
strategy=strategy,
|
||||||
|
tools=agent_tools
|
||||||
|
)
|
||||||
|
|
||||||
|
# file upload
|
||||||
|
file_upload_dict = copy_app_model_config_dict.get('file_upload')
|
||||||
|
if file_upload_dict:
|
||||||
|
if 'image' in file_upload_dict and file_upload_dict['image']:
|
||||||
|
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
|
||||||
|
properties['file_upload'] = FileUploadEntity(
|
||||||
|
image_config={
|
||||||
|
'number_limits': file_upload_dict['image']['number_limits'],
|
||||||
|
'detail': file_upload_dict['image']['detail'],
|
||||||
|
'transfer_methods': file_upload_dict['image']['transfer_methods']
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# opening statement
|
||||||
|
properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
|
||||||
|
|
||||||
|
# suggested questions after answer
|
||||||
|
suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
|
||||||
|
if suggested_questions_after_answer_dict:
|
||||||
|
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
|
||||||
|
properties['suggested_questions_after_answer'] = True
|
||||||
|
|
||||||
|
# more like this
|
||||||
|
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
|
||||||
|
if more_like_this_dict:
|
||||||
|
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
|
||||||
|
properties['more_like_this'] = True
|
||||||
|
|
||||||
|
# speech to text
|
||||||
|
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
|
||||||
|
if speech_to_text_dict:
|
||||||
|
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
|
||||||
|
properties['speech_to_text'] = True
|
||||||
|
|
||||||
|
# sensitive word avoidance
|
||||||
|
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
|
||||||
|
if sensitive_word_avoidance_dict:
|
||||||
|
if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
|
||||||
|
properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
|
||||||
|
type=sensitive_word_avoidance_dict.get('type'),
|
||||||
|
config=sensitive_word_avoidance_dict.get('config'),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AppOrchestrationConfigEntity(**properties)
|
||||||
|
|
||||||
|
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
||||||
|
-> Tuple[Conversation, Message]:
|
||||||
|
"""
|
||||||
|
Initialize generate records
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
||||||
|
|
||||||
|
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
|
||||||
|
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        model_schema = model_type_instance.get_model_schema(
            model=app_orchestration_config_entity.model_config.model,
            credentials=app_orchestration_config_entity.model_config.credentials
        )

        app_record = (db.session.query(App)
                      .filter(App.id == application_generate_entity.app_id).first())

        app_mode = app_record.mode

        # get from source
        end_user_id = None
        account_id = None
        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            from_source = 'api'
            end_user_id = application_generate_entity.user_id
        else:
            from_source = 'console'
            account_id = application_generate_entity.user_id

        override_model_configs = None
        if application_generate_entity.app_model_config_override:
            override_model_configs = application_generate_entity.app_model_config_dict

        introduction = ''
        if app_mode == 'chat':
            # get conversation introduction
            introduction = self._get_conversation_introduction(application_generate_entity)

        if not application_generate_entity.conversation_id:
            conversation = Conversation(
                app_id=app_record.id,
                app_model_config_id=application_generate_entity.app_model_config_id,
                model_provider=app_orchestration_config_entity.model_config.provider,
                model_id=app_orchestration_config_entity.model_config.model,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                mode=app_mode,
                name='New conversation',
                inputs=application_generate_entity.inputs,
                introduction=introduction,
                system_instruction="",
                system_instruction_tokens=0,
                status='normal',
                from_source=from_source,
                from_end_user_id=end_user_id,
                from_account_id=account_id,
            )

            db.session.add(conversation)
            db.session.commit()
        else:
            conversation = (
                db.session.query(Conversation)
                .filter(
                    Conversation.id == application_generate_entity.conversation_id,
                    Conversation.app_id == app_record.id
                ).first()
            )

        currency = model_schema.pricing.currency if model_schema.pricing else 'USD'

        message = Message(
            app_id=app_record.id,
            model_provider=app_orchestration_config_entity.model_config.provider,
            model_id=app_orchestration_config_entity.model_config.model,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
            conversation_id=conversation.id,
            inputs=application_generate_entity.inputs,
            query=application_generate_entity.query or "",
            message="",
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency=currency,
            from_source=from_source,
            from_end_user_id=end_user_id,
            from_account_id=account_id,
            agent_based=app_orchestration_config_entity.agent is not None
        )

        db.session.add(message)
        db.session.commit()

        for file in application_generate_entity.files:
            message_file = MessageFile(
                message_id=message.id,
                type=file.type.value,
                transfer_method=file.transfer_method.value,
                url=file.url,
                upload_file_id=file.upload_file_id,
                created_by_role=('account' if account_id else 'end_user'),
                created_by=account_id or end_user_id,
            )
            db.session.add(message_file)
            db.session.commit()

        return conversation, message

    def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
        """
        Get conversation introduction
        :param application_generate_entity: application generate entity
        :return: conversation introduction
        """
        app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
        introduction = app_orchestration_config_entity.opening_statement

        if introduction:
            try:
                inputs = application_generate_entity.inputs
                prompt_template = PromptTemplateParser(template=introduction)
                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
                introduction = prompt_template.format(prompt_inputs)
            except KeyError:
                pass

        return introduction

    def _get_conversation(self, conversation_id: str) -> Conversation:
        """
        Get conversation by conversation id
        :param conversation_id: conversation id
        :return: conversation
        """
        conversation = (
            db.session.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )

        return conversation

    def _get_message(self, message_id: str) -> Message:
        """
        Get message by message id
        :param message_id: message id
        :return: message
        """
        message = (
            db.session.query(Message)
            .filter(Message.id == message_id)
            .first()
        )

        return message
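The opening-statement handling in _get_conversation_introduction only substitutes template variables that are actually present in the user inputs. A minimal, self-contained sketch of that behaviour follows; TinyTemplateParser is a hypothetical stand-in that assumes Dify's "{{variable}}" placeholder syntax, while the real logic lives in core.prompt.prompt_template.PromptTemplateParser.

import re

# Illustrative stand-in for PromptTemplateParser; the real class may use a
# different variable syntax, this only mirrors the usage pattern above.
class TinyTemplateParser:
    VARIABLE_PATTERN = re.compile(r"\{\{(\w+)\}\}")

    def __init__(self, template: str) -> None:
        self.template = template
        # the variable names referenced by the template
        self.variable_keys = self.VARIABLE_PATTERN.findall(template)

    def format(self, inputs: dict) -> str:
        # substitute known variables; leave unknown placeholders untouched
        def replace(match: re.Match) -> str:
            return str(inputs.get(match.group(1), match.group(0)))
        return self.VARIABLE_PATTERN.sub(replace, self.template)


template = TinyTemplateParser("Hi {{name}}, ask me about {{topic}}.")
inputs = {"name": "Alice", "topic": "billing", "unused": "ignored"}
# only keys the template actually references are passed, mirroring the code above
prompt_inputs = {k: inputs[k] for k in template.variable_keys if k in inputs}
print(template.format(prompt_inputs))  # -> Hi Alice, ask me about billing.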
228
api/core/application_queue_manager.py
Normal file
@ -0,0 +1,228 @@
import queue
import time
from typing import Generator, Any

from sqlalchemy.orm import DeclarativeMeta

from core.entities.application_entities import InvokeFrom
from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
    QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
    QueueMessageEvent, QueueMessage, AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought


class ApplicationQueueManager:
    def __init__(self, task_id: str,
                 user_id: str,
                 invoke_from: InvokeFrom,
                 conversation_id: str,
                 app_mode: str,
                 message_id: str) -> None:
        if not user_id:
            raise ValueError("user is required")

        self._task_id = task_id
        self._user_id = user_id
        self._invoke_from = invoke_from
        self._conversation_id = str(conversation_id)
        self._app_mode = app_mode
        self._message_id = str(message_id)

        user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
        redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")

        q = queue.Queue()

        self._q = q

    def listen(self) -> Generator:
        """
        Listen to queue
        :return:
        """
        # wait for 10 minutes to stop listen
        listen_timeout = 600
        start_time = time.time()
        last_ping_time = 0

        while True:
            try:
                message = self._q.get(timeout=1)
                if message is None:
                    break

                yield message
            except queue.Empty:
                continue
            finally:
                elapsed_time = time.time() - start_time
                if elapsed_time >= listen_timeout or self._is_stopped():
                    # publish two messages to make sure the client can receive the stop signal
                    # and stop listening after the stop signal processed
                    self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
                    self.stop_listen()

                if elapsed_time // 10 > last_ping_time:
                    self.publish(QueuePingEvent())
                    last_ping_time = elapsed_time // 10

    def stop_listen(self) -> None:
        """
        Stop listen to queue
        :return:
        """
        self._q.put(None)

    def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
        """
        Publish chunk message to channel

        :param chunk: chunk
        :return:
        """
        self.publish(QueueMessageEvent(
            chunk=chunk
        ))

    def publish_message_replace(self, text: str) -> None:
        """
        Publish message replace
        :param text: text
        :return:
        """
        self.publish(QueueMessageReplaceEvent(
            text=text
        ))

    def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
        """
        Publish retriever resources
        :return:
        """
        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))

    def publish_annotation_reply(self, message_annotation_id: str) -> None:
        """
        Publish annotation reply
        :param message_annotation_id: message annotation id
        :return:
        """
        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))

    def publish_message_end(self, llm_result: LLMResult) -> None:
        """
        Publish message end
        :param llm_result: llm result
        :return:
        """
        self.publish(QueueMessageEndEvent(llm_result=llm_result))
        self.stop_listen()

    def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
        """
        Publish agent thought
        :param message_agent_thought: message agent thought
        :return:
        """
        self.publish(QueueAgentThoughtEvent(
            agent_thought_id=message_agent_thought.id
        ))

    def publish_error(self, e) -> None:
        """
        Publish error
        :param e: error
        :return:
        """
        self.publish(QueueErrorEvent(
            error=e
        ))
        self.stop_listen()

    def publish(self, event: AppQueueEvent) -> None:
        """
        Publish event to queue
        :param event:
        :return:
        """
        self._check_for_sqlalchemy_models(event.dict())

        message = QueueMessage(
            task_id=self._task_id,
            message_id=self._message_id,
            conversation_id=self._conversation_id,
            app_mode=self._app_mode,
            event=event
        )

        self._q.put(message)

        if isinstance(event, QueueStopEvent):
            self.stop_listen()

    @classmethod
    def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
        """
        Set task stop flag
        :return:
        """
        result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
        if result is None:
            return

        user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
        if result != f"{user_prefix}-{user_id}":
            return

        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
        redis_client.setex(stopped_cache_key, 600, 1)

    def _is_stopped(self) -> bool:
        """
        Check if task is stopped
        :return:
        """
        stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
        result = redis_client.get(stopped_cache_key)
        if result is not None:
            redis_client.delete(stopped_cache_key)
            return True

        return False

    @classmethod
    def _generate_task_belong_cache_key(cls, task_id: str) -> str:
        """
        Generate task belong cache key
        :param task_id: task id
        :return:
        """
        return f"generate_task_belong:{task_id}"

    @classmethod
    def _generate_stopped_cache_key(cls, task_id: str) -> str:
        """
        Generate stopped cache key
        :param task_id: task id
        :return:
        """
        return f"generate_task_stopped:{task_id}"

    def _check_for_sqlalchemy_models(self, data: Any):
        # from entity to dict or list
        if isinstance(data, dict):
            for key, value in data.items():
                self._check_for_sqlalchemy_models(value)
        elif isinstance(data, list):
            for item in data:
                self._check_for_sqlalchemy_models(item)
        else:
            if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
                raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
                                "that cause thread safety issues is not allowed.")


class ConversationTaskStoppedException(Exception):
    pass
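ApplicationQueueManager decouples the worker that generates a response from the request thread that streams it back: the worker publishes events and listen() yields them until a stop sentinel or timeout. The sketch below reproduces that producer/consumer shape with only the standard library; the dict events are placeholders for the pydantic QueueMessageEvent / QueueMessageEndEvent classes used above.

import queue
import threading
import time

# Plain queue standing in for the one wrapped by ApplicationQueueManager.
q = queue.Queue()

def worker() -> None:
    # the generation worker publishes chunks, then an end event, then stops
    for chunk in ["Hel", "lo", "!"]:
        q.put({"event": "message", "chunk": chunk})   # ~ publish_chunk_message()
        time.sleep(0.1)
    q.put({"event": "message_end"})                   # ~ publish_message_end()
    q.put(None)                                       # ~ stop_listen() sentinel

threading.Thread(target=worker, daemon=True).start()

def listen():
    # mirrors listen(): block on the queue and yield until the None sentinel
    while True:
        message = q.get(timeout=10)
        if message is None:
            break
        yield message

for message in listen():
    print(message)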
|
@ -2,30 +2,40 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from typing import Any, Dict, List, Union, Optional
|
from typing import Any, Dict, List, Union, Optional, cast
|
||||||
|
|
||||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
|
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
|
||||||
|
|
||||||
|
from core.application_queue_manager import ApplicationQueueManager
|
||||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||||
from core.conversation_message_task import ConversationMessageTask
|
from core.entities.application_entities import ModelConfigEntity
|
||||||
from core.model_providers.models.entity.message import PromptMessage
|
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
||||||
from core.model_providers.models.llm.base import BaseLLM
|
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import MessageChain, MessageAgentThought, Message
|
||||||
|
|
||||||
|
|
||||||
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||||
"""Callback Handler that prints to std out."""
|
"""Callback Handler that prints to std out."""
|
||||||
raise_error: bool = True
|
raise_error: bool = True
|
||||||
|
|
||||||
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
def __init__(self, model_config: ModelConfigEntity,
|
||||||
|
queue_manager: ApplicationQueueManager,
|
||||||
|
message: Message,
|
||||||
|
message_chain: MessageChain) -> None:
|
||||||
"""Initialize callback handler."""
|
"""Initialize callback handler."""
|
||||||
self.model_instance = model_instance
|
self.model_config = model_config
|
||||||
self.conversation_message_task = conversation_message_task
|
self.queue_manager = queue_manager
|
||||||
|
self.message = message
|
||||||
|
self.message_chain = message_chain
|
||||||
|
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
|
||||||
|
self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
self._agent_loops = []
|
self._agent_loops = []
|
||||||
self._current_loop = None
|
self._current_loop = None
|
||||||
self._message_agent_thought = None
|
self._message_agent_thought = None
|
||||||
self.current_chain = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agent_loops(self) -> List[AgentLoop]:
|
def agent_loops(self) -> List[AgentLoop]:
|
||||||
|
@ -46,65 +56,60 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||||
"""Whether to ignore chain callbacks."""
|
"""Whether to ignore chain callbacks."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
|
||||||
|
if not self._current_loop:
|
||||||
|
# Agent start with a LLM query
|
||||||
|
self._current_loop = AgentLoop(
|
||||||
|
position=len(self._agent_loops) + 1,
|
||||||
|
prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
|
||||||
|
status='llm_started',
|
||||||
|
started_at=time.perf_counter()
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
|
||||||
|
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||||
|
self._current_loop.status = 'llm_end'
|
||||||
|
if result.usage:
|
||||||
|
self._current_loop.prompt_tokens = result.usage.prompt_tokens
|
||||||
|
else:
|
||||||
|
self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
|
||||||
|
model=self.model_config.model,
|
||||||
|
credentials=self.model_config.credentials,
|
||||||
|
prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_message = result.message
|
||||||
|
if completion_message.tool_calls:
|
||||||
|
self._current_loop.completion \
|
||||||
|
= json.dumps({'function_call': completion_message.tool_calls})
|
||||||
|
else:
|
||||||
|
self._current_loop.completion = completion_message.content
|
||||||
|
|
||||||
|
if result.usage:
|
||||||
|
self._current_loop.completion_tokens = result.usage.completion_tokens
|
||||||
|
else:
|
||||||
|
self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
|
||||||
|
model=self.model_config.model,
|
||||||
|
credentials=self.model_config.credentials,
|
||||||
|
prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
|
||||||
|
)
|
||||||
|
|
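The on_llm_before_invoke / on_llm_after_invoke hooks above prefer the token usage reported by the provider and only fall back to get_num_tokens when it is missing. A rough, self-contained illustration of that fallback follows; the whitespace tokenizer is a placeholder for the model-specific counter.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Usage:
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None

def rough_token_count(text: str) -> int:
    # placeholder for the model-specific tokenizer behind get_num_tokens()
    return len(text.split())

def resolve_tokens(prompt: str, completion: str, usage: Optional[Usage]) -> tuple:
    # prefer provider-reported usage, otherwise re-count locally, mirroring
    # the branches in on_llm_before_invoke / on_llm_after_invoke above
    prompt_tokens = usage.prompt_tokens if usage and usage.prompt_tokens else rough_token_count(prompt)
    completion_tokens = usage.completion_tokens if usage and usage.completion_tokens else rough_token_count(completion)
    return prompt_tokens, completion_tokens

print(resolve_tokens("What is the weather?", "It is sunny.", None))          # counted locally
print(resolve_tokens("What is the weather?", "It is sunny.", Usage(12, 5)))  # provider-reported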
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
messages: List[List[BaseMessage]],
|
messages: List[List[BaseMessage]],
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not self._current_loop:
|
pass
|
||||||
# Agent start with a LLM query
|
|
||||||
self._current_loop = AgentLoop(
|
|
||||||
position=len(self._agent_loops) + 1,
|
|
||||||
prompt="\n".join([message.content for message in messages[0]]),
|
|
||||||
status='llm_started',
|
|
||||||
started_at=time.perf_counter()
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Print out the prompts."""
|
pass
|
||||||
# serialized={'name': 'OpenAI'}
|
|
||||||
# prompts=['Answer the following questions...\nThought:']
|
|
||||||
# kwargs={}
|
|
||||||
if not self._current_loop:
|
|
||||||
# Agent start with a LLM query
|
|
||||||
self._current_loop = AgentLoop(
|
|
||||||
position=len(self._agent_loops) + 1,
|
|
||||||
prompt=prompts[0],
|
|
||||||
status='llm_started',
|
|
||||||
started_at=time.perf_counter()
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
# kwargs={}
|
pass
|
||||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
|
||||||
self._current_loop.status = 'llm_end'
|
|
||||||
if response.llm_output:
|
|
||||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
|
||||||
else:
|
|
||||||
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
|
|
||||||
[PromptMessage(content=self._current_loop.prompt)]
|
|
||||||
)
|
|
||||||
completion_generation = response.generations[0][0]
|
|
||||||
if isinstance(completion_generation, ChatGeneration):
|
|
||||||
completion_message = completion_generation.message
|
|
||||||
if 'function_call' in completion_message.additional_kwargs:
|
|
||||||
self._current_loop.completion \
|
|
||||||
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
|
|
||||||
else:
|
|
||||||
self._current_loop.completion = response.generations[0][0].text
|
|
||||||
else:
|
|
||||||
self._current_loop.completion = completion_generation.text
|
|
||||||
|
|
||||||
if response.llm_output:
|
|
||||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
|
||||||
else:
|
|
||||||
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
|
|
||||||
[PromptMessage(content=self._current_loop.completion)]
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
|
@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||||
if completion is not None:
|
if completion is not None:
|
||||||
self._current_loop.completion = completion
|
self._current_loop.completion = completion
|
||||||
|
|
||||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
self._message_agent_thought = self._init_agent_thought()
|
||||||
self.current_chain,
|
|
||||||
self._current_loop
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
|
@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||||
self._current_loop.completed_at = time.perf_counter()
|
self._current_loop.completed_at = time.perf_counter()
|
||||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||||
|
|
||||||
self.conversation_message_task.on_agent_end(
|
self._complete_agent_thought(self._message_agent_thought)
|
||||||
self._message_agent_thought, self.model_instance, self._current_loop
|
|
||||||
)
|
|
||||||
|
|
||||||
self._agent_loops.append(self._current_loop)
|
self._agent_loops.append(self._current_loop)
|
||||||
self._current_loop = None
|
self._current_loop = None
|
||||||
|
@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||||
self._current_loop.completed_at = time.perf_counter()
|
self._current_loop.completed_at = time.perf_counter()
|
||||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||||
self._current_loop.thought = '[DONE]'
|
self._current_loop.thought = '[DONE]'
|
||||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
self._message_agent_thought = self._init_agent_thought()
|
||||||
self.current_chain,
|
|
||||||
self._current_loop
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conversation_message_task.on_agent_end(
|
self._complete_agent_thought(self._message_agent_thought)
|
||||||
self._message_agent_thought, self.model_instance, self._current_loop
|
|
||||||
)
|
|
||||||
|
|
||||||
self._agent_loops.append(self._current_loop)
|
self._agent_loops.append(self._current_loop)
|
||||||
self._current_loop = None
|
self._current_loop = None
|
||||||
self._message_agent_thought = None
|
self._message_agent_thought = None
|
||||||
elif not self._current_loop and self._agent_loops:
|
elif not self._current_loop and self._agent_loops:
|
||||||
self._agent_loops[-1].status = 'agent_finish'
|
self._agent_loops[-1].status = 'agent_finish'
|
||||||
|
|
||||||
|
def _init_agent_thought(self) -> MessageAgentThought:
|
||||||
|
message_agent_thought = MessageAgentThought(
|
||||||
|
message_id=self.message.id,
|
||||||
|
message_chain_id=self.message_chain.id,
|
||||||
|
position=self._current_loop.position,
|
||||||
|
thought=self._current_loop.thought,
|
||||||
|
tool=self._current_loop.tool_name,
|
||||||
|
tool_input=self._current_loop.tool_input,
|
||||||
|
message=self._current_loop.prompt,
|
||||||
|
message_price_unit=0,
|
||||||
|
answer=self._current_loop.completion,
|
||||||
|
answer_price_unit=0,
|
||||||
|
created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
|
||||||
|
created_by=(self.message.from_account_id
|
||||||
|
if self.message.from_source == 'console' else self.message.from_end_user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(message_agent_thought)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
self.queue_manager.publish_agent_thought(message_agent_thought)
|
||||||
|
|
||||||
|
return message_agent_thought
|
||||||
|
|
||||||
|
def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
|
||||||
|
loop_message_tokens = self._current_loop.prompt_tokens
|
||||||
|
loop_answer_tokens = self._current_loop.completion_tokens
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
llm_usage = self.model_type_instance._calc_response_usage(
|
||||||
|
self.model_config.model,
|
||||||
|
self.model_config.credentials,
|
||||||
|
loop_message_tokens,
|
||||||
|
loop_answer_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
message_agent_thought.observation = self._current_loop.tool_output
|
||||||
|
message_agent_thought.tool_process_data = '' # currently not support
|
||||||
|
message_agent_thought.message_token = loop_message_tokens
|
||||||
|
message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||||
|
message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||||
|
message_agent_thought.answer_token = loop_answer_tokens
|
||||||
|
message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||||
|
message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||||
|
message_agent_thought.latency = self._current_loop.latency
|
||||||
|
message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
|
||||||
|
message_agent_thought.total_price = llm_usage.total_price
|
||||||
|
message_agent_thought.currency = llm_usage.currency
|
||||||
|
db.session.commit()
|
||||||
|
|
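_complete_agent_thought turns the per-loop token counts into prices using the unit prices returned by _calc_response_usage. The arithmetic it implies is roughly the following; the unit prices and the per-1,000-token price unit are illustrative assumptions, not values taken from this commit.

from decimal import Decimal

# Illustrative numbers only; real unit prices come from the provider's pricing
# schema via _calc_response_usage().
prompt_tokens = 320
completion_tokens = 48
prompt_unit_price = Decimal("0.001")       # price per price unit of prompt tokens
completion_unit_price = Decimal("0.002")   # price per price unit of completion tokens
price_unit = Decimal("0.001")              # assumption: prices quoted per 1,000 tokens

prompt_price = Decimal(prompt_tokens) * price_unit * prompt_unit_price
completion_price = Decimal(completion_tokens) * price_unit * completion_unit_price
total_price = prompt_price + completion_price

print(total_price)  # Decimal('0.000416')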
|
@ -1,74 +0,0 @@
import json
import logging
from json import JSONDecodeError

from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler

from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask


class DatasetToolCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""
    raise_error: bool = True

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
        self.queries = []
        self.conversation_message_task = conversation_message_task

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return True

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return True

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return False

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        tool_name: str = serialized.get('name')
        dataset_id = tool_name.removeprefix('dataset-')

        try:
            input_dict = json.loads(input_str.replace("'", "\""))
            query = input_dict.get('query')
        except JSONDecodeError:
            query = input_str

        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        pass

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        logging.debug("Dataset tool on_llm_error: %s", error)
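The removed handler above parsed the agent's tool input as JSON (after swapping single quotes for double quotes) and fell back to treating the raw string as the query. A small sketch of that parsing behaviour:

import json
from json import JSONDecodeError

def extract_query(input_str: str) -> str:
    # LangChain agents often emit tool input as a Python-style dict string,
    # hence the single-quote to double-quote swap before json.loads().
    try:
        input_dict = json.loads(input_str.replace("'", "\""))
        return input_dict.get("query", input_str)
    except JSONDecodeError:
        # plain-text tool input: use it as the query directly
        return input_str

print(extract_query("{'query': 'pricing of plan A'}"))  # -> pricing of plan A
print(extract_query("pricing of plan A"))               # -> pricing of plan A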
|
|
@ -1,16 +0,0 @@
from pydantic import BaseModel


class ChainResult(BaseModel):
    type: str = None
    prompt: dict = None
    completion: dict = None

    status: str = 'chain_started'
    completed: bool = False

    started_at: float = None
    completed_at: float = None

    agent_result: dict = None
    """only when type is 'AgentExecutor'"""
|
|
@ -1,6 +0,0 @@
from pydantic import BaseModel


class DatasetQueryObj(BaseModel):
    dataset_id: str = None
    query: str = None
|
@ -1,8 +0,0 @@
from pydantic import BaseModel


class LLMMessage(BaseModel):
    prompt: str = ''
    prompt_tokens: int = 0
    completion: str = ''
    completion_tokens: int = 0
|
|
@ -1,17 +1,44 @@
-from typing import List
+from typing import List, Union
 
 from langchain.schema import Document
 
-from core.conversation_message_task import ConversationMessageTask
+from core.application_queue_manager import ApplicationQueueManager
+from core.entities.application_entities import InvokeFrom
 from extensions.ext_database import db
-from models.dataset import DocumentSegment
+from models.dataset import DocumentSegment, DatasetQuery
+from models.model import DatasetRetrieverResource
 
 
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
-    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
-        self.conversation_message_task = conversation_message_task
+    def __init__(self, queue_manager: ApplicationQueueManager,
+                 app_id: str,
+                 message_id: str,
+                 user_id: str,
+                 invoke_from: InvokeFrom) -> None:
+        self._queue_manager = queue_manager
+        self._app_id = app_id
+        self._message_id = message_id
+        self._user_id = user_id
+        self._invoke_from = invoke_from
+
+    def on_query(self, query: str, dataset_id: str) -> None:
+        """
+        Handle query.
+        """
+        dataset_query = DatasetQuery(
+            dataset_id=dataset_id,
+            content=query,
+            source='app',
+            source_app_id=self._app_id,
+            created_by_role=('account'
+                             if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
+            created_by=self._user_id
+        )
+
+        db.session.add(dataset_query)
+        db.session.commit()
 
     def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:
 
     def return_retriever_resource_info(self, resource: List):
         """Handle return_retriever_resource_info."""
-        self.conversation_message_task.on_dataset_query_finish(resource)
+        if resource and len(resource) > 0:
+            for item in resource:
+                dataset_retriever_resource = DatasetRetrieverResource(
+                    message_id=self._message_id,
+                    position=item.get('position'),
+                    dataset_id=item.get('dataset_id'),
+                    dataset_name=item.get('dataset_name'),
+                    document_id=item.get('document_id'),
+                    document_name=item.get('document_name'),
+                    data_source_type=item.get('data_source_type'),
+                    segment_id=item.get('segment_id'),
+                    score=item.get('score') if 'score' in item else None,
+                    hit_count=item.get('hit_count') if 'hit_count' else None,
+                    word_count=item.get('word_count') if 'word_count' in item else None,
+                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
+                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
+                    content=item.get('content'),
+                    retriever_from=item.get('retriever_from'),
+                    created_by=self._user_id
+                )
+                db.session.add(dataset_retriever_resource)
+                db.session.commit()
+
+            self._queue_manager.publish_retriever_resources(resource)
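return_retriever_resource_info expects a list of dicts whose keys match the item.get(...) calls above. An illustrative payload (all values made up for the example) looks like this:

# Hypothetical payload for return_retriever_resource_info(); keys mirror the
# item.get(...) calls above, values are invented for illustration only.
resource = [
    {
        "position": 1,
        "dataset_id": "d3adbeef-0000-0000-0000-000000000000",
        "dataset_name": "Product FAQ",
        "document_id": "doc-1",
        "document_name": "faq.md",
        "data_source_type": "upload_file",
        "segment_id": "seg-42",
        "score": 0.87,
        "content": "Refunds are processed within 5 business days.",
        "retriever_from": "dev",
    }
]

# handler.return_retriever_resource_info(resource) would persist one
# DatasetRetrieverResource row per item and then publish the whole list to the
# queue via publish_retriever_resources().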
|
|
@ -1,284 +0,0 @@
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict, List, Union, Optional
|
|
||||||
|
|
||||||
from flask import Flask, current_app
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
|
||||||
from langchain.schema import LLMResult, BaseMessage
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.callback_handler.entity.llm_message import LLMMessage
|
|
||||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
|
||||||
ConversationTaskInterruptException
|
|
||||||
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
|
|
||||||
ImagePromptMessageFile
|
|
||||||
from core.model_providers.models.llm.base import BaseLLM
|
|
||||||
from core.moderation.base import ModerationOutputsResult, ModerationAction
|
|
||||||
from core.moderation.factory import ModerationFactory
|
|
||||||
|
|
||||||
|
|
||||||
class ModerationRule(BaseModel):
|
|
||||||
type: str
|
|
||||||
config: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class LLMCallbackHandler(BaseCallbackHandler):
|
|
||||||
raise_error: bool = True
|
|
||||||
|
|
||||||
def __init__(self, model_instance: BaseLLM,
|
|
||||||
conversation_message_task: ConversationMessageTask):
|
|
||||||
self.model_instance = model_instance
|
|
||||||
self.llm_message = LLMMessage()
|
|
||||||
self.start_at = None
|
|
||||||
self.conversation_message_task = conversation_message_task
|
|
||||||
|
|
||||||
self.output_moderation_handler = None
|
|
||||||
self.init_output_moderation()
|
|
||||||
|
|
||||||
def init_output_moderation(self):
|
|
||||||
app_model_config = self.conversation_message_task.app_model_config
|
|
||||||
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
|
|
||||||
|
|
||||||
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
|
|
||||||
self.output_moderation_handler = OutputModerationHandler(
|
|
||||||
tenant_id=self.conversation_message_task.tenant_id,
|
|
||||||
app_id=self.conversation_message_task.app.id,
|
|
||||||
rule=ModerationRule(
|
|
||||||
type=sensitive_word_avoidance_dict.get("type"),
|
|
||||||
config=sensitive_word_avoidance_dict.get("config")
|
|
||||||
),
|
|
||||||
on_message_replace_func=self.conversation_message_task.on_message_replace
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def always_verbose(self) -> bool:
|
|
||||||
"""Whether to call verbose callbacks even if verbose is False."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def on_chat_model_start(
|
|
||||||
self,
|
|
||||||
serialized: Dict[str, Any],
|
|
||||||
messages: List[List[BaseMessage]],
|
|
||||||
**kwargs: Any
|
|
||||||
) -> Any:
|
|
||||||
real_prompts = []
|
|
||||||
for message in messages[0]:
|
|
||||||
if message.type == 'human':
|
|
||||||
role = 'user'
|
|
||||||
elif message.type == 'ai':
|
|
||||||
role = 'assistant'
|
|
||||||
else:
|
|
||||||
role = 'system'
|
|
||||||
|
|
||||||
real_prompts.append({
|
|
||||||
"role": role,
|
|
||||||
"text": message.content,
|
|
||||||
"files": [{
|
|
||||||
"type": file.type.value,
|
|
||||||
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
|
|
||||||
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
|
|
||||||
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
|
|
||||||
})
|
|
||||||
|
|
||||||
self.llm_message.prompt = real_prompts
|
|
||||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
|
|
||||||
|
|
||||||
def on_llm_start(
|
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
self.llm_message.prompt = [{
|
|
||||||
"role": 'user',
|
|
||||||
"text": prompts[0]
|
|
||||||
}]
|
|
||||||
|
|
||||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
|
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
|
||||||
if self.output_moderation_handler:
|
|
||||||
self.output_moderation_handler.stop_thread()
|
|
||||||
|
|
||||||
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
|
|
||||||
completion=response.generations[0][0].text,
|
|
||||||
public_event=True if self.conversation_message_task.streaming else False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.llm_message.completion = response.generations[0][0].text
|
|
||||||
|
|
||||||
if not self.conversation_message_task.streaming:
|
|
||||||
self.conversation_message_task.append_message_text(self.llm_message.completion)
|
|
||||||
|
|
||||||
if response.llm_output and 'token_usage' in response.llm_output:
|
|
||||||
if 'prompt_tokens' in response.llm_output['token_usage']:
|
|
||||||
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
|
||||||
|
|
||||||
if 'completion_tokens' in response.llm_output['token_usage']:
|
|
||||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
|
||||||
else:
|
|
||||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
|
||||||
[PromptMessage(content=self.llm_message.completion)])
|
|
||||||
else:
|
|
||||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
|
||||||
[PromptMessage(content=self.llm_message.completion)])
|
|
||||||
|
|
||||||
self.conversation_message_task.save_message(self.llm_message)
|
|
||||||
|
|
||||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
|
||||||
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
|
|
||||||
# stop subscribe new token when output moderation should direct output
|
|
||||||
ex = ConversationTaskInterruptException()
|
|
||||||
self.on_llm_error(error=ex)
|
|
||||||
raise ex
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.conversation_message_task.append_message_text(token)
|
|
||||||
self.llm_message.completion += token
|
|
||||||
|
|
||||||
if self.output_moderation_handler:
|
|
||||||
self.output_moderation_handler.append_new_token(token)
|
|
||||||
except ConversationTaskStoppedException as ex:
|
|
||||||
self.on_llm_error(error=ex)
|
|
||||||
raise ex
|
|
||||||
|
|
||||||
def on_llm_error(
|
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
|
||||||
if self.output_moderation_handler:
|
|
||||||
self.output_moderation_handler.stop_thread()
|
|
||||||
|
|
||||||
if isinstance(error, ConversationTaskStoppedException):
|
|
||||||
if self.conversation_message_task.streaming:
|
|
||||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
|
||||||
[PromptMessage(content=self.llm_message.completion)]
|
|
||||||
)
|
|
||||||
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
|
|
||||||
if isinstance(error, ConversationTaskInterruptException):
|
|
||||||
self.llm_message.completion = self.output_moderation_handler.get_final_output()
|
|
||||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
|
||||||
[PromptMessage(content=self.llm_message.completion)]
|
|
||||||
)
|
|
||||||
self.conversation_message_task.save_message(llm_message=self.llm_message)
|
|
||||||
else:
|
|
||||||
logging.debug("on_llm_error: %s", error)
|
|
||||||
|
|
||||||
|
|
||||||
class OutputModerationHandler(BaseModel):
|
|
||||||
DEFAULT_BUFFER_SIZE: int = 300
|
|
||||||
|
|
||||||
tenant_id: str
|
|
||||||
app_id: str
|
|
||||||
|
|
||||||
rule: ModerationRule
|
|
||||||
on_message_replace_func: Any
|
|
||||||
|
|
||||||
thread: Optional[threading.Thread] = None
|
|
||||||
thread_running: bool = True
|
|
||||||
buffer: str = ''
|
|
||||||
is_final_chunk: bool = False
|
|
||||||
final_output: Optional[str] = None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
def should_direct_output(self):
|
|
||||||
return self.final_output is not None
|
|
||||||
|
|
||||||
def get_final_output(self):
|
|
||||||
return self.final_output
|
|
||||||
|
|
||||||
def append_new_token(self, token: str):
|
|
||||||
self.buffer += token
|
|
||||||
|
|
||||||
if not self.thread:
|
|
||||||
self.thread = self.start_thread()
|
|
||||||
|
|
||||||
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
|
|
||||||
self.buffer = completion
|
|
||||||
self.is_final_chunk = True
|
|
||||||
|
|
||||||
result = self.moderation(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_id,
|
|
||||||
moderation_buffer=completion
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result or not result.flagged:
|
|
||||||
return completion
|
|
||||||
|
|
||||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
|
||||||
final_output = result.preset_response
|
|
||||||
else:
|
|
||||||
final_output = result.text
|
|
||||||
|
|
||||||
if public_event:
|
|
||||||
self.on_message_replace_func(final_output)
|
|
||||||
|
|
||||||
return final_output
|
|
||||||
|
|
||||||
def start_thread(self) -> threading.Thread:
|
|
||||||
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
|
|
||||||
thread = threading.Thread(target=self.worker, kwargs={
|
|
||||||
'flask_app': current_app._get_current_object(),
|
|
||||||
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
|
|
||||||
})
|
|
||||||
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
return thread
|
|
||||||
|
|
||||||
def stop_thread(self):
|
|
||||||
if self.thread and self.thread.is_alive():
|
|
||||||
self.thread_running = False
|
|
||||||
|
|
||||||
def worker(self, flask_app: Flask, buffer_size: int):
|
|
||||||
with flask_app.app_context():
|
|
||||||
current_length = 0
|
|
||||||
while self.thread_running:
|
|
||||||
moderation_buffer = self.buffer
|
|
||||||
buffer_length = len(moderation_buffer)
|
|
||||||
if not self.is_final_chunk:
|
|
||||||
chunk_length = buffer_length - current_length
|
|
||||||
if 0 <= chunk_length < buffer_size:
|
|
||||||
time.sleep(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
current_length = buffer_length
|
|
||||||
|
|
||||||
result = self.moderation(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_id,
|
|
||||||
moderation_buffer=moderation_buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result or not result.flagged:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
|
||||||
final_output = result.preset_response
|
|
||||||
self.final_output = final_output
|
|
||||||
else:
|
|
||||||
final_output = result.text + self.buffer[len(moderation_buffer):]
|
|
||||||
|
|
||||||
# trigger replace event
|
|
||||||
if self.thread_running:
|
|
||||||
self.on_message_replace_func(final_output)
|
|
||||||
|
|
||||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
|
||||||
break
|
|
||||||
|
|
||||||
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
|
|
||||||
try:
|
|
||||||
moderation_factory = ModerationFactory(
|
|
||||||
name=self.rule.type,
|
|
||||||
app_id=app_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
config=self.rule.config
|
|
||||||
)
|
|
||||||
|
|
||||||
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Moderation Output error: %s", e)
|
|
||||||
|
|
||||||
return None
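The removed OutputModerationHandler buffers streamed tokens and re-moderates the buffer from a background thread whenever it grows by a configurable chunk size. The sketch below reproduces that loop with the standard library only; flag_text is a stand-in for the ModerationFactory call, and the details are an approximation of the code above, not a drop-in replacement.

import threading
import time

def flag_text(text: str) -> bool:
    # stand-in for ModerationFactory.moderation_for_outputs()
    return "forbidden" in text

class BufferedModerator:
    def __init__(self, buffer_size: int = 20) -> None:
        self.buffer = ""
        self.buffer_size = buffer_size
        self.flagged = False
        self._running = True
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def append(self, token: str) -> None:
        # called for every streamed token, like append_new_token() above
        self.buffer += token

    def stop(self) -> None:
        self._running = False

    def _worker(self) -> None:
        checked_length = 0
        while self._running:
            if len(self.buffer) - checked_length < self.buffer_size:
                time.sleep(0.1)
                continue
            checked_length = len(self.buffer)
            if flag_text(self.buffer):
                self.flagged = True   # the real handler replaces the message here
                break

moderator = BufferedModerator()
for token in ["This ", "answer ", "contains ", "forbidden ", "words ", "indeed."]:
    moderator.append(token)
time.sleep(0.5)
moderator.stop()
print(moderator.flagged)  # True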
|
|
|
@ -1,76 +0,0 @@
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
from typing import Any, Dict, Union
|
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
|
||||||
|
|
||||||
from core.callback_handler.entity.chain_result import ChainResult
|
|
||||||
from core.conversation_message_task import ConversationMessageTask
|
|
||||||
|
|
||||||
|
|
||||||
class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
|
||||||
"""Callback Handler that prints to std out."""
|
|
||||||
raise_error: bool = True
|
|
||||||
|
|
||||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
|
||||||
"""Initialize callback handler."""
|
|
||||||
self._current_chain_result = None
|
|
||||||
self._current_chain_message = None
|
|
||||||
self.conversation_message_task = conversation_message_task
|
|
||||||
self.agent_callback = None
|
|
||||||
|
|
||||||
def clear_chain_results(self) -> None:
|
|
||||||
self._current_chain_result = None
|
|
||||||
self._current_chain_message = None
|
|
||||||
if self.agent_callback:
|
|
||||||
self.agent_callback.current_chain = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def always_verbose(self) -> bool:
|
|
||||||
"""Whether to call verbose callbacks even if verbose is False."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ignore_llm(self) -> bool:
|
|
||||||
"""Whether to ignore LLM callbacks."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ignore_agent(self) -> bool:
|
|
||||||
"""Whether to ignore agent callbacks."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def on_chain_start(
|
|
||||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Print out that we are entering a chain."""
|
|
||||||
if not self._current_chain_result:
|
|
||||||
chain_type = serialized['id'][-1]
|
|
||||||
if chain_type:
|
|
||||||
self._current_chain_result = ChainResult(
|
|
||||||
type=chain_type,
|
|
||||||
prompt=inputs,
|
|
||||||
started_at=time.perf_counter()
|
|
||||||
)
|
|
||||||
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
|
|
||||||
if self.agent_callback:
|
|
||||||
self.agent_callback.current_chain = self._current_chain_message
|
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
|
||||||
"""Print out that we finished a chain."""
|
|
||||||
if self._current_chain_result and self._current_chain_result.status == 'chain_started':
|
|
||||||
self._current_chain_result.status = 'chain_ended'
|
|
||||||
self._current_chain_result.completion = outputs
|
|
||||||
self._current_chain_result.completed = True
|
|
||||||
self._current_chain_result.completed_at = time.perf_counter()
|
|
||||||
|
|
||||||
self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
|
|
||||||
|
|
||||||
self.clear_chain_results()
|
|
||||||
|
|
||||||
def on_chain_error(
|
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
logging.debug("Dataset tool on_chain_error: %s", error)
|
|
||||||
self.clear_chain_results()
|
|
|
@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent action."""
         tool = action.tool
         tool_input = action.tool_input
-        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
-        thought = action.log[:action_name_position].strip() if action.log else ''
+        try:
+            action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
+            thought = action.log[:action_name_position].strip() if action.log else ''
+        except ValueError:
+            thought = ''
+
         log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
         print_text("\n[on_agent_action]\n" + log + "\n", color='green')
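The try/except added to on_agent_action guards against agent logs that contain no "\nAction:" marker, where str.index raises ValueError. A quick illustration of both cases:

def parse_thought(log: str) -> str:
    # str.index raises ValueError when "\nAction:" is absent, e.g. for a final
    # answer; the try/except added above turns that into an empty thought.
    try:
        action_name_position = log.index("\nAction:") + 1 if log else -1
        return log[:action_name_position].strip() if log else ''
    except ValueError:
        return ''

print(parse_thought("Thought: I should look this up\nAction: search"))  # keeps the thought
print(parse_thought("Final Answer: 42"))                                # '' instead of a crash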
|
|
@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import LLMResult, Generation
 from langchain.schema.language_model import BaseLanguageModel
 
-from core.model_providers.models.entity.message import to_prompt_messages
-from core.model_providers.models.llm.base import BaseLLM
+from core.agent.agent.agent_llm_callback import AgentLLMCallback
+from core.entities.application_entities import ModelConfigEntity
+from core.model_manager import ModelInstance
+from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.third_party.langchain.llms.fake import FakeLLM
 
 
 class LLMChain(LCLLMChain):
-    model_instance: BaseLLM
+    model_config: ModelConfigEntity
     """The language model instance to use."""
     llm: BaseLanguageModel = FakeLLM(response="")
+    parameters: Dict[str, Any] = {}
+    agent_llm_callback: Optional[AgentLLMCallback] = None
 
     def generate(
         self,
@ -23,14 +27,23 @@ class LLMChain(LCLLMChain):
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
         messages = prompts[0].to_messages()
-        prompt_messages = to_prompt_messages(messages)
-        result = self.model_instance.run(
-            messages=prompt_messages,
-            stop=stop
+        prompt_messages = lc_messages_to_prompt_messages(messages)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=self.model_config.provider_model_bundle,
+            model=self.model_config.model,
+        )
+
+        result = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            stream=False,
+            stop=stop,
+            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
+            model_parameters=self.parameters
         )
 
         generations = [
-            [Generation(text=result.content)]
+            [Generation(text=result.message.content)]
         ]
 
         return LLMResult(generations=generations)
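The updated generate() takes the text from the runtime result's result.message.content and wraps it back into LangChain's Generation / LLMResult so downstream chain code is unchanged. A minimal illustration of that mapping, using only the LangChain types the diff already imports:

from langchain.schema import Generation, LLMResult

# The model runtime returns a result whose .message.content holds the text;
# the chain re-wraps that text in LangChain's result types.
completion_text = "Paris is the capital of France."

lc_result = LLMResult(generations=[[Generation(text=completion_text)]])
print(lc_result.generations[0][0].text)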
|
|
@ -1,501 +0,0 @@
|
||||||
import concurrent
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from typing import Optional, List, Union, Tuple
|
|
||||||
|
|
||||||
from flask import current_app, Flask
|
|
||||||
from requests.exceptions import ChunkedEncodingError
|
|
||||||
|
|
||||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
|
||||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
|
||||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
|
||||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
|
||||||
ConversationTaskInterruptException
|
|
||||||
from core.embedding.cached_embedding import CacheEmbedding
|
|
||||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
|
||||||
from core.file.file_obj import FileObj
|
|
||||||
from core.index.vector_index.vector_index import VectorIndex
|
|
||||||
from core.model_providers.error import LLMBadRequestError
|
|
||||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
|
||||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
|
||||||
from core.model_providers.model_factory import ModelFactory
|
|
||||||
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
|
|
||||||
from core.model_providers.models.llm.base import BaseLLM
|
|
||||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
|
||||||
from core.prompt.prompt_transform import PromptTransform
|
|
||||||
from models.dataset import Dataset
|
|
||||||
from models.model import App, AppModelConfig, Account, Conversation, EndUser
|
|
||||||
from core.moderation.base import ModerationException, ModerationAction
|
|
||||||
from core.moderation.factory import ModerationFactory
|
|
||||||
from services.annotation_service import AppAnnotationService
|
|
||||||
from services.dataset_service import DatasetCollectionBindingService
|
|
||||||
|
|
||||||
|
|
||||||
class Completion:
|
|
||||||
@classmethod
|
|
||||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
|
||||||
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
|
|
||||||
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
|
|
||||||
auto_generate_name: bool = True, from_source: str = 'console'):
|
|
||||||
"""
|
|
||||||
errors: ProviderTokenNotInitError
|
|
||||||
"""
|
|
||||||
query = PromptTemplateParser.remove_template_variables(query)
|
|
||||||
|
|
||||||
memory = None
|
|
||||||
if conversation:
|
|
||||||
# get memory of conversation (read-only)
|
|
||||||
memory = cls.get_memory_from_conversation(
|
|
||||||
tenant_id=app.tenant_id,
|
|
||||||
app_model_config=app_model_config,
|
|
||||||
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            files=files,
            streaming=streaming,
            model_instance=final_model_instance,
            auto_generate_name=auto_generate_name
        )

        prompt_message_files = [file.prompt_message_file for file in files]

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs,
            files=prompt_message_files
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        try:
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)

            try:
                # process sensitive_word_avoidance
                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
            except ModerationException as e:
                cls.run_final_llm(
                    model_instance=final_model_instance,
                    mode=app.mode,
                    app_model_config=app_model_config,
                    query=query,
                    inputs=inputs,
                    files=prompt_message_files,
                    agent_execute_result=None,
                    conversation_message_task=conversation_message_task,
                    memory=memory,
                    fake_response=str(e)
                )
                return

            # check annotation reply
            annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
            if annotation_reply:
                return

            # fill in variable inputs from external data tools if exists
            external_data_tools = app_model_config.external_data_tools_list
            if external_data_tools:
                inputs = cls.fill_in_inputs_from_external_data_tools(
                    tenant_id=app.tenant_id,
                    app_id=app.id,
                    external_data_tools=external_data_tools,
                    inputs=inputs,
                    query=query
                )

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                tenant_id=app.tenant_id,
                retriever_from=retriever_from
            )

            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)

            # run agent executor
            agent_execute_result = None
            if query_for_agent and agent_executor:
                should_use_agent = agent_executor.should_use_agent(query_for_agent)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query_for_agent)

            # When no extra pre prompt is specified,
            # the output of the agent can be used directly as the main output content without calling LLM again
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                files=prompt_message_files,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
            return
        except ChunkedEncodingError as e:
            # Interrupt by LLM (like OpenAI), handle it.
            logging.warning(f'ChunkedEncodingError: {e}')
            return

    @classmethod
    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
                              query: str):
        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
            return inputs, query

        type = app_model_config.sensitive_word_avoidance_dict['type']

        moderation = ModerationFactory(type, app_id, tenant_id,
                                       app_model_config.sensitive_word_avoidance_dict['config'])
        moderation_result = moderation.moderation_for_inputs(inputs, query)

        if not moderation_result.flagged:
            return inputs, query

        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationException(moderation_result.preset_response)
        elif moderation_result.action == ModerationAction.OVERRIDED:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return inputs, query

    @classmethod
    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
                                                inputs: dict, query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        # Group tools by type and config
        grouped_tools = {}
        for tool in external_data_tools:
            if not tool.get("enabled"):
                continue

            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
            grouped_tools.setdefault(tool_key, []).append(tool)

        results = {}
        with ThreadPoolExecutor() as executor:
            futures = {}
            for tool in external_data_tools:
                if not tool.get("enabled"):
                    continue

                future = executor.submit(
                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
                    inputs, query
                )

                futures[future] = tool

            for future in concurrent.futures.as_completed(futures):
                tool_variable, result = future.result()
                results[tool_variable] = result

        inputs.update(results)
        return inputs

    @classmethod
    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
        with flask_app.app_context():
            tool_variable = external_data_tool.get("variable")
            tool_type = external_data_tool.get("type")
            tool_config = external_data_tool.get("config")

            external_data_tool_factory = ExternalDataToolFactory(
                name=tool_type,
                tenant_id=tenant_id,
                app_id=app_id,
                variable=tool_variable,
                config=tool_config
            )

            # query external data tool
            result = external_data_tool_factory.query(
                inputs=inputs,
                query=query
            )

            return tool_variable, result

    @classmethod
    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
        if app.mode != 'completion':
            return query

        return inputs.get(app_model_config.dataset_query_variable, "")

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      files: List[PromptMessageFile],
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
        prompt_transform = PromptTransform()

        # get llm prompt
        if app_model_config.prompt_type == 'simple':
            prompt_messages, stop_words = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )

        model_config = app_model_config.model_dict
        completion_params = model_config.get("completion_params", {})
        stop_words = completion_params.get("stop", [])

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words if stop_words else None,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
                                       from_source: str) -> bool:
        """Get memory messages."""
        app_model_config = conversation_message_task.app_model_config
        app = conversation_message_task.app
        annotation_reply = app_model_config.annotation_reply_dict
        if annotation_reply['enabled']:
            try:
                score_threshold = annotation_reply.get('score_threshold', 1)
                embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
                embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
                # get embedding model
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=app.tenant_id,
                    model_provider_name=embedding_provider_name,
                    model_name=embedding_model_name
                )
                embeddings = CacheEmbedding(embedding_model)

                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    embedding_provider_name,
                    embedding_model_name,
                    'annotation'
                )

                dataset = Dataset(
                    id=app.id,
                    tenant_id=app.tenant_id,
                    indexing_technique='high_quality',
                    embedding_model_provider=embedding_provider_name,
                    embedding_model=embedding_model_name,
                    collection_binding_id=dataset_collection_binding.id
                )

                vector_index = VectorIndex(
                    dataset=dataset,
                    config=current_app.config,
                    embeddings=embeddings,
                    attributes=['doc_id', 'annotation_id', 'app_id']
                )

                documents = vector_index.search(
                    conversation_message_task.query,
                    search_type='similarity_score_threshold',
                    search_kwargs={
                        'k': 1,
                        'score_threshold': score_threshold,
                        'filter': {
                            'group_id': [dataset.id]
                        }
                    }
                )
                if documents:
                    annotation_id = documents[0].metadata['annotation_id']
                    score = documents[0].metadata['score']
                    annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                    if annotation:
                        conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
                        # insert annotation history
                        AppAnnotationService.add_annotation_history(annotation.id,
                                                                    app.id,
                                                                    annotation.question,
                                                                    annotation.content,
                                                                    conversation_message_task.query,
                                                                    conversation_message_task.user.id,
                                                                    conversation_message_task.message.id,
                                                                    from_source,
                                                                    score)
                        return True
            except Exception as e:
                logging.warning(f'Query annotation failed, exception: {str(e)}.')
                return False

        return False

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only for calc token in memory
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_transform = PromptTransform()

        # get prompt without memory and context
        if app_model_config.prompt_type == 'simple':
            prompt_messages, _ = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                     "or shrink the max token, or switch to a llm with a larger token limit size.")

        return rest_tokens

    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

        # update model instance max tokens
        model_kwargs = model_instance.get_model_kwargs()
        model_kwargs.max_tokens = max_tokens
        model_instance.set_model_kwargs(model_kwargs)
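
For illustration, a minimal sketch of the token-budget rule that recale_llm_max_tokens applies, with plain integers standing in for the model instance and its kwargs:

    # prompt_tokens + max_tokens must fit within the model's context limit;
    # otherwise max_tokens is shrunk, but never below 16
    def recalc_max_tokens(model_limit: int, prompt_tokens: int, max_tokens: int) -> int:
        if prompt_tokens + max_tokens > model_limit:
            return max(model_limit - prompt_tokens, 16)
        return max_tokens

    assert recalc_max_tokens(4096, 1000, 512) == 512   # fits, unchanged
    assert recalc_max_tokens(4096, 4000, 512) == 96    # shrunk to the remaining budget
    assert recalc_max_tokens(4096, 4095, 512) == 16    # floor of 16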

@@ -1,517 +0,0 @@
import json
import time
from typing import Optional, Union, List

from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
    MessageChain, DatasetRetrieverResource, MessageFile


class ConversationMessageTask:
    def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
                 auto_generate_name: bool = True):
        self.start_at = time.perf_counter()

        self.task_id = task_id

        self.app = app
        self.tenant_id = app.tenant_id
        self.app_model_config = app_model_config
        self.is_override = is_override

        self.user = user
        self.inputs = inputs
        self.query = query
        self.files = files
        self.streaming = streaming

        self.conversation = conversation
        self.is_new_conversation = False

        self.model_instance = model_instance

        self.message = None

        self.retriever_resource = None
        self.auto_generate_name = auto_generate_name

        self.model_dict = self.app_model_config.model_dict
        self.provider_name = self.model_dict.get('provider')
        self.model_name = self.model_dict.get('name')
        self.mode = app.mode

        self.init()

        self._pub_handler = PubHandler(
            user=self.user,
            task_id=self.task_id,
            message=self.message,
            conversation=self.conversation,
            chain_pub=False,  # disabled currently
            agent_thought_pub=True
        )

    def init(self):
        override_model_configs = None
        if self.is_override:
            override_model_configs = self.app_model_config.to_dict()

        introduction = ''
        system_instruction = ''
        system_instruction_tokens = 0
        if self.mode == 'chat':
            introduction = self.app_model_config.opening_statement
            if introduction:
                prompt_template = PromptTemplateParser(template=introduction)
                prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
                try:
                    introduction = prompt_template.format(prompt_inputs)
                except KeyError:
                    pass

            if self.app_model_config.pre_prompt:
                system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
                system_instruction = system_message.content
                model_instance = ModelFactory.get_text_generation_model(
                    tenant_id=self.tenant_id,
                    model_provider_name=self.provider_name,
                    model_name=self.model_name
                )
                system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))

        if not self.conversation:
            self.is_new_conversation = True
            self.conversation = Conversation(
                app_id=self.app.id,
                app_model_config_id=self.app_model_config.id,
                model_provider=self.provider_name,
                model_id=self.model_name,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                mode=self.mode,
                name='New conversation',
                inputs=self.inputs,
                introduction=introduction,
                system_instruction=system_instruction,
                system_instruction_tokens=system_instruction_tokens,
                status='normal',
                from_source=('console' if isinstance(self.user, Account) else 'api'),
                from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
                from_account_id=(self.user.id if isinstance(self.user, Account) else None),
            )

            db.session.add(self.conversation)
            db.session.commit()

        self.message = Message(
            app_id=self.app.id,
            model_provider=self.provider_name,
            model_id=self.model_name,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
            conversation_id=self.conversation.id,
            inputs=self.inputs,
            query=self.query,
            message="",
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency=self.model_instance.get_currency(),
            from_source=('console' if isinstance(self.user, Account) else 'api'),
            from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
            from_account_id=(self.user.id if isinstance(self.user, Account) else None),
            agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
        )

        db.session.add(self.message)
        db.session.commit()

        for file in self.files:
            message_file = MessageFile(
                message_id=self.message.id,
                type=file.type.value,
                transfer_method=file.transfer_method.value,
                url=file.url,
                upload_file_id=file.upload_file_id,
                created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
                created_by=self.user.id
            )
            db.session.add(message_file)
            db.session.commit()

    def append_message_text(self, text: str):
        if text is not None:
            self._pub_handler.pub_text(text)

    def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
        message_tokens = llm_message.prompt_tokens
        answer_tokens = llm_message.completion_tokens

        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
        message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)

        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
        total_price = message_total_price + answer_total_price

        self.message.message = llm_message.prompt
        self.message.message_tokens = message_tokens
        self.message.message_unit_price = message_unit_price
        self.message.message_price_unit = message_price_unit
        self.message.answer = PromptTemplateParser.remove_template_variables(
            llm_message.completion.strip()) if llm_message.completion else ''
        self.message.answer_tokens = answer_tokens
        self.message.answer_unit_price = answer_unit_price
        self.message.answer_price_unit = answer_price_unit
        self.message.provider_response_latency = time.perf_counter() - self.start_at
        self.message.total_price = total_price

        db.session.commit()

        message_was_created.send(
            self.message,
            conversation=self.conversation,
            is_first_message=self.is_new_conversation,
            auto_generate_name=self.auto_generate_name
        )

        if not by_stopped:
            self.end()

    def init_chain(self, chain_result: ChainResult):
        message_chain = MessageChain(
            message_id=self.message.id,
            type=chain_result.type,
            input=json.dumps(chain_result.prompt),
            output=''
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
        message_chain.output = json.dumps(chain_result.completion)
        db.session.commit()

        self._pub_handler.pub_chain(message_chain)

    def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
        message_agent_thought = MessageAgentThought(
            message_id=self.message.id,
            message_chain_id=message_chain.id,
            position=agent_loop.position,
            thought=agent_loop.thought,
            tool=agent_loop.tool_name,
            tool_input=agent_loop.tool_input,
            message=agent_loop.prompt,
            message_price_unit=0,
            answer=agent_loop.completion,
            answer_price_unit=0,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(message_agent_thought)
        db.session.commit()

        self._pub_handler.pub_agent_thought(message_agent_thought)

        return message_agent_thought

    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                     agent_loop: AgentLoop):
        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

        loop_message_tokens = agent_loop.prompt_tokens
        loop_answer_tokens = agent_loop.completion_tokens

        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
        loop_total_price = loop_message_total_price + loop_answer_total_price

        message_agent_thought.observation = agent_loop.tool_output
        message_agent_thought.tool_process_data = ''  # currently not support
        message_agent_thought.message_token = loop_message_tokens
        message_agent_thought.message_unit_price = agent_message_unit_price
        message_agent_thought.message_price_unit = agent_message_price_unit
        message_agent_thought.answer_token = loop_answer_tokens
        message_agent_thought.answer_unit_price = agent_answer_unit_price
        message_agent_thought.answer_price_unit = agent_answer_price_unit
        message_agent_thought.latency = agent_loop.latency
        message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
        message_agent_thought.total_price = loop_total_price
        message_agent_thought.currency = agent_model_instance.get_currency()
        db.session.commit()

    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
        dataset_query = DatasetQuery(
            dataset_id=dataset_query_obj.dataset_id,
            content=dataset_query_obj.query,
            source='app',
            source_app_id=self.app.id,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(dataset_query)
        db.session.commit()

    def on_dataset_query_finish(self, resource: List):
        if resource and len(resource) > 0:
            for item in resource:
                dataset_retriever_resource = DatasetRetrieverResource(
                    message_id=self.message.id,
                    position=item.get('position'),
                    dataset_id=item.get('dataset_id'),
                    dataset_name=item.get('dataset_name'),
                    document_id=item.get('document_id'),
                    document_name=item.get('document_name'),
                    data_source_type=item.get('data_source_type'),
                    segment_id=item.get('segment_id'),
                    score=item.get('score') if 'score' in item else None,
                    hit_count=item.get('hit_count') if 'hit_count' else None,
                    word_count=item.get('word_count') if 'word_count' in item else None,
                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
                    content=item.get('content'),
                    retriever_from=item.get('retriever_from'),
                    created_by=self.user.id
                )
                db.session.add(dataset_retriever_resource)
                db.session.commit()
        self.retriever_resource = resource

    def on_message_replace(self, text: str):
        if text is not None:
            self._pub_handler.pub_message_replace(text)

    def message_end(self):
        self._pub_handler.pub_message_end(self.retriever_resource)

    def end(self):
        self._pub_handler.pub_message_end(self.retriever_resource)
        self._pub_handler.pub_end()

    def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
        self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
        self._pub_handler.pub_end()


class PubHandler:
    def __init__(self, user: Union[Account, EndUser], task_id: str,
                 message: Message, conversation: Conversation,
                 chain_pub: bool = False, agent_thought_pub: bool = False):
        self._channel = PubHandler.generate_channel_name(user, task_id)
        self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)

        self._task_id = task_id
        self._message = message
        self._conversation = conversation
        self._chain_pub = chain_pub
        self._agent_thought_pub = agent_thought_pub

    @classmethod
    def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
        if not user:
            raise ValueError("user is required")

        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
        return "generate_result:{}-{}".format(user_str, task_id)

    @classmethod
    def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
        return "generate_result_stopped:{}-{}".format(user_str, task_id)

    def pub_text(self, text: str):
        content = {
            'event': 'message',
            'data': {
                'task_id': self._task_id,
                'message_id': str(self._message.id),
                'text': text,
                'mode': self._conversation.mode,
                'conversation_id': str(self._conversation.id)
            }
        }

        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_message_replace(self, text: str):
        content = {
            'event': 'message_replace',
            'data': {
                'task_id': self._task_id,
                'message_id': str(self._message.id),
                'text': text,
                'mode': self._conversation.mode,
                'conversation_id': str(self._conversation.id)
            }
        }

        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_chain(self, message_chain: MessageChain):
        if self._chain_pub:
            content = {
                'event': 'chain',
                'data': {
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_chain.id,
                    'type': message_chain.type,
                    'input': json.loads(message_chain.input),
                    'output': json.loads(message_chain.output),
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            }

            redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
        if self._agent_thought_pub:
            content = {
                'event': 'agent_thought',
                'data': {
                    'id': message_agent_thought.id,
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_agent_thought.message_chain_id,
                    'position': message_agent_thought.position,
                    'thought': message_agent_thought.thought,
                    'tool': message_agent_thought.tool,
                    'tool_input': message_agent_thought.tool_input,
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            }

            redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_message_end(self, retriever_resource: List):
        content = {
            'event': 'message_end',
            'data': {
                'task_id': self._task_id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id,
            }
        }
        if retriever_resource:
            content['data']['retriever_resources'] = retriever_resource
        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
        content = {
            'event': 'annotation',
            'data': {
                'task_id': self._task_id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id,
                'text': text,
                'annotation_id': annotation_id,
                'annotation_author_name': annotation_author_name
            }
        }
        self._message.answer = text
        self._message.provider_response_latency = time.perf_counter() - start_at

        db.session.commit()

        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_end(self):
        content = {
            'event': 'end',
        }

        redis_client.publish(self._channel, json.dumps(content))

    @classmethod
    def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
        content = {
            'error': type(e).__name__,
            'description': e.description if getattr(e, 'description', None) is not None else str(e)
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    def _is_stopped(self):
        return redis_client.get(self._stopped_cache_key) is not None

    @classmethod
    def ping(cls, user: Union[Account, EndUser], task_id: str):
        content = {
            'event': 'ping'
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    @classmethod
    def stop(cls, user: Union[Account, EndUser], task_id: str):
        stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
        redis_client.setex(stopped_cache_key, 600, 1)


class ConversationTaskStoppedException(Exception):
    pass


class ConversationTaskInterruptException(Exception):
    pass
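
For illustration, a minimal sketch of how the stop flag in the removed PubHandler round-trips through Redis; the account object and task id are placeholders:

    # stop() sets a short-lived key ("generate_result_stopped:account-<user_id>-<task_id>", TTL 600s);
    # every subsequent pub_* call sees it via _is_stopped() and raises ConversationTaskStoppedException,
    # which the generate loop catches and treats as a clean early return.
    PubHandler.stop(user=some_account, task_id='task-123')   # placeholder values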
@@ -1,9 +1,11 @@
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Dict, Optional, Sequence, cast

 from langchain.schema import Document
 from sqlalchemy import func

-from core.model_providers.model_factory import ModelFactory
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment

@@ -69,10 +71,12 @@ class DatasetDocumentStore:
         max_position = 0
         embedding_model = None
         if self._dataset.indexing_technique == 'high_quality':
-            embedding_model = ModelFactory.get_embedding_model(
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
                 tenant_id=self._dataset.tenant_id,
-                model_provider_name=self._dataset.embedding_model_provider,
-                model_name=self._dataset.embedding_model
+                provider=self._dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=self._dataset.embedding_model
             )

         for doc in docs:
@@ -89,7 +93,16 @@ class DatasetDocumentStore:
             )

             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
+            if embedding_model:
+                model_type_instance = embedding_model.model_type_instance
+                model_type_instance = cast(TextEmbeddingModel, model_type_instance)
+                tokens = model_type_instance.get_num_tokens(
+                    model=embedding_model.model,
+                    credentials=embedding_model.credentials,
+                    texts=[doc.page_content]
+                )
+            else:
+                tokens = 0

             if not segment_document:
                 max_position += 1
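
For illustration, a minimal sketch of the new token-counting path shown in the hunk above; the tenant, provider, and model names are placeholders:

    model_manager = ModelManager()
    embedding_model = model_manager.get_model_instance(
        tenant_id='tenant-1',                      # placeholder
        provider='openai',                         # placeholder
        model_type=ModelType.TEXT_EMBEDDING,
        model='text-embedding-ada-002'             # placeholder
    )
    # the ModelInstance wraps a type-specific runtime model; cast to access get_num_tokens
    model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
    tokens = model_type_instance.get_num_tokens(
        model=embedding_model.model,
        credentials=embedding_model.credentials,
        texts=['some document chunk']
    )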
@@ -1,19 +1,22 @@
 import logging
-from typing import List
+from typing import List, Optional

 import numpy as np
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError

-from core.model_providers.models.embedding.base import BaseEmbedding
+from core.model_manager import ModelInstance
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding

+logger = logging.getLogger(__name__)
+

 class CacheEmbedding(Embeddings):
-    def __init__(self, embeddings: BaseEmbedding):
-        self._embeddings = embeddings
+    def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
+        self._model_instance = model_instance
+        self._user = user

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
@@ -22,7 +25,7 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
             if embedding:
                 text_embeddings[i] = embedding.get_embedding()
             else:
@@ -30,15 +33,21 @@ class CacheEmbedding(Embeddings):

         if embedding_queue_indices:
             try:
-                embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
+                embedding_result = self._model_instance.invoke_text_embedding(
+                    texts=[texts[i] for i in embedding_queue_indices],
+                    user=self._user
+                )
+
+                embedding_results = embedding_result.embeddings
             except Exception as ex:
-                raise self._embeddings.handle_exceptions(ex)
+                logger.error('Failed to embed documents: ', ex)
+                raise ex

             for i, indice in enumerate(embedding_queue_indices):
                 hash = helper.generate_text_hash(texts[indice])

                 try:
-                    embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+                    embedding = Embedding(model_name=self._model_instance.model, hash=hash)
                     vector = embedding_results[i]
                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                     text_embeddings[indice] = normalized_embedding
@@ -58,18 +67,23 @@ class CacheEmbedding(Embeddings):
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
+        embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
         if embedding:
             return embedding.get_embedding()

         try:
-            embedding_results = self._embeddings.client.embed_query(text)
+            embedding_result = self._model_instance.invoke_text_embedding(
+                texts=[text],
+                user=self._user
+            )
+
+            embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
         except Exception as ex:
-            raise self._embeddings.handle_exceptions(ex)
+            raise ex

         try:
-            embedding = Embedding(model_name=self._embeddings.name, hash=hash)
+            embedding = Embedding(model_name=self._model_instance.model, hash=hash)
             embedding.set_embedding(embedding_results)
             db.session.add(embedding)
             db.session.commit()
@@ -79,4 +93,3 @@ class CacheEmbedding(Embeddings):
             logging.exception('Failed to add embedding to db')

         return embedding_results
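
For illustration, a minimal sketch of constructing the reworked CacheEmbedding from a ModelInstance (as obtained above from ModelManager.get_model_instance); the user id and query text are placeholders:

    embeddings = CacheEmbedding(model_instance=embedding_model, user='end-user-1')   # placeholder user
    vectors = embeddings.embed_documents(['first chunk', 'second chunk'])
    query_vector = embeddings.embed_query('what is model runtime?')
    # vectors are L2-normalized and cached per (model_name, text hash) in the Embedding table,
    # so repeated calls for the same text skip the provider round trip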
265
api/core/entities/application_entities.py
Normal file
@@ -0,0 +1,265 @@
from enum import Enum
from typing import Optional, Any, cast

from pydantic import BaseModel

from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import AIModelEntity


class ModelConfigEntity(BaseModel):
    """
    Model Config Entity.
    """
    provider: str
    model: str
    model_schema: AIModelEntity
    mode: str
    provider_model_bundle: ProviderModelBundle
    credentials: dict[str, Any] = {}
    parameters: dict[str, Any] = {}
    stop: list[str] = []


class AdvancedChatMessageEntity(BaseModel):
    """
    Advanced Chat Message Entity.
    """
    text: str
    role: PromptMessageRole


class AdvancedChatPromptTemplateEntity(BaseModel):
    """
    Advanced Chat Prompt Template Entity.
    """
    messages: list[AdvancedChatMessageEntity]


class AdvancedCompletionPromptTemplateEntity(BaseModel):
    """
    Advanced Completion Prompt Template Entity.
    """
    class RolePrefixEntity(BaseModel):
        """
        Role Prefix Entity.
        """
        user: str
        assistant: str

    prompt: str
    role_prefix: Optional[RolePrefixEntity] = None


class PromptTemplateEntity(BaseModel):
    """
    Prompt Template Entity.
    """
    class PromptType(Enum):
        """
        Prompt Type.
        'simple', 'advanced'
        """
        SIMPLE = 'simple'
        ADVANCED = 'advanced'

        @classmethod
        def value_of(cls, value: str) -> 'PromptType':
            """
            Get value of given mode.

            :param value: mode value
            :return: mode
            """
            for mode in cls:
                if mode.value == value:
                    return mode
            raise ValueError(f'invalid prompt type value {value}')

    prompt_type: PromptType
    simple_prompt_template: Optional[str] = None
    advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
    advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None


class ExternalDataVariableEntity(BaseModel):
    """
    External Data Variable Entity.
    """
    variable: str
    type: str
    config: dict[str, Any] = {}


class DatasetRetrieveConfigEntity(BaseModel):
    """
    Dataset Retrieve Config Entity.
    """
    class RetrieveStrategy(Enum):
        """
        Dataset Retrieve Strategy.
        'single' or 'multiple'
        """
        SINGLE = 'single'
        MULTIPLE = 'multiple'

        @classmethod
        def value_of(cls, value: str) -> 'RetrieveStrategy':
            """
            Get value of given mode.

            :param value: mode value
            :return: mode
            """
            for mode in cls:
                if mode.value == value:
                    return mode
            raise ValueError(f'invalid retrieve strategy value {value}')

    query_variable: Optional[str] = None  # Only when app mode is completion

    retrieve_strategy: RetrieveStrategy
    single_strategy: Optional[str] = None  # for temp
    top_k: Optional[int] = None
    score_threshold: Optional[float] = None
    reranking_model: Optional[dict] = None


class DatasetEntity(BaseModel):
    """
    Dataset Config Entity.
    """
    dataset_ids: list[str]
    retrieve_config: DatasetRetrieveConfigEntity


class SensitiveWordAvoidanceEntity(BaseModel):
    """
    Sensitive Word Avoidance Entity.
    """
    type: str
    config: dict[str, Any] = {}


class FileUploadEntity(BaseModel):
    """
    File Upload Entity.
    """
    image_config: Optional[dict[str, Any]] = None


class AgentToolEntity(BaseModel):
    """
    Agent Tool Entity.
    """
    tool_id: str
    config: dict[str, Any] = {}


class AgentEntity(BaseModel):
    """
    Agent Entity.
    """
    class Strategy(Enum):
        """
        Agent Strategy.
        """
        CHAIN_OF_THOUGHT = 'chain-of-thought'
        FUNCTION_CALLING = 'function-calling'

    provider: str
    model: str
    strategy: Strategy
    tools: list[AgentToolEntity] = []


class AppOrchestrationConfigEntity(BaseModel):
    """
    App Orchestration Config Entity.
    """
    model_config: ModelConfigEntity
    prompt_template: PromptTemplateEntity
    external_data_variables: list[ExternalDataVariableEntity] = []
    agent: Optional[AgentEntity] = None

    # features
    dataset: Optional[DatasetEntity] = None
    file_upload: Optional[FileUploadEntity] = None
    opening_statement: Optional[str] = None
    suggested_questions_after_answer: bool = False
    show_retrieve_source: bool = False
    more_like_this: bool = False
    speech_to_text: bool = False
    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None


class InvokeFrom(Enum):
    """
    Invoke From.
    """
    SERVICE_API = 'service-api'
    WEB_APP = 'web-app'
    EXPLORE = 'explore'
    DEBUGGER = 'debugger'

    @classmethod
    def value_of(cls, value: str) -> 'InvokeFrom':
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f'invalid invoke from value {value}')

    def to_source(self) -> str:
        """
        Get source of invoke from.

        :return: source
        """
        if self == InvokeFrom.WEB_APP:
            return 'web_app'
        elif self == InvokeFrom.DEBUGGER:
            return 'dev'
        elif self == InvokeFrom.EXPLORE:
            return 'explore_app'
        elif self == InvokeFrom.SERVICE_API:
            return 'api'

        return 'dev'


class ApplicationGenerateEntity(BaseModel):
    """
    Application Generate Entity.
    """
    task_id: str
    tenant_id: str

    app_id: str
    app_model_config_id: str
    # for save
    app_model_config_dict: dict
    app_model_config_override: bool

    # Converted from app_model_config to Entity object, or directly covered by external input
    app_orchestration_config_entity: AppOrchestrationConfigEntity

    conversation_id: Optional[str] = None
    inputs: dict[str, str]
    query: Optional[str] = None
    files: list[FileObj] = []
    user_id: str

    # extras
    stream: bool
    invoke_from: InvokeFrom

    # extra parameters, like: auto_generate_conversation_name
    extras: dict[str, Any] = {}
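
For illustration, a quick sketch of the InvokeFrom helpers defined in this new file:

    source = InvokeFrom.value_of('web-app')          # -> InvokeFrom.WEB_APP
    assert source.to_source() == 'web_app'
    assert InvokeFrom.DEBUGGER.to_source() == 'dev'  # debugger calls are recorded as 'dev'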
128
api/core/entities/message_entities.py
Normal file
@@ -0,0 +1,128 @@
import enum
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \
    ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage


class PromptMessageFileType(enum.Enum):
    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        for member in PromptMessageFileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class PromptMessageFile(BaseModel):
    type: PromptMessageFileType
    data: Any


class ImagePromptMessageFile(PromptMessageFile):
    class DETAIL(enum.Enum):
        LOW = 'low'
        HIGH = 'high'

    type: PromptMessageFileType = PromptMessageFileType.IMAGE
    detail: DETAIL = DETAIL.LOW


class LCHumanMessageWithFiles(HumanMessage):
    # content: Union[str, List[Union[str, Dict]]]
    content: str
    files: list[PromptMessageFile]


def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
    prompt_messages = []
    for message in messages:
        if isinstance(message, HumanMessage):
            if isinstance(message, LCHumanMessageWithFiles):
                file_prompt_message_contents = []
                for file in message.files:
                    if file.type == PromptMessageFileType.IMAGE:
                        file = cast(ImagePromptMessageFile, file)
                        file_prompt_message_contents.append(ImagePromptMessageContent(
                            data=file.data,
                            detail=ImagePromptMessageContent.DETAIL.HIGH
                            if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
                        ))

                prompt_message_contents = [TextPromptMessageContent(data=message.content)]
                prompt_message_contents.extend(file_prompt_message_contents)

                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
            else:
                prompt_messages.append(UserPromptMessage(content=message.content))
        elif isinstance(message, AIMessage):
            message_kwargs = {
                'content': message.content
            }

            if 'function_call' in message.additional_kwargs:
                message_kwargs['tool_calls'] = [
                    AssistantPromptMessage.ToolCall(
                        id=message.additional_kwargs['function_call']['id'],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=message.additional_kwargs['function_call']['name'],
                            arguments=message.additional_kwargs['function_call']['arguments']
                        )
                    )
                ]

            prompt_messages.append(AssistantPromptMessage(**message_kwargs))
        elif isinstance(message, SystemMessage):
            prompt_messages.append(SystemPromptMessage(content=message.content))
        elif isinstance(message, FunctionMessage):
            prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))

    return prompt_messages


def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
    messages = []
    for prompt_message in prompt_messages:
        if isinstance(prompt_message, UserPromptMessage):
            if isinstance(prompt_message.content, str):
                messages.append(HumanMessage(content=prompt_message.content))
            else:
                message_contents = []
                for content in prompt_message.content:
                    if isinstance(content, TextPromptMessageContent):
                        message_contents.append(content.data)
                    elif isinstance(content, ImagePromptMessageContent):
                        message_contents.append({
                            'type': 'image',
                            'data': content.data,
                            'detail': content.detail.value
                        })

                messages.append(HumanMessage(content=message_contents))
        elif isinstance(prompt_message, AssistantPromptMessage):
            message_kwargs = {
                'content': prompt_message.content
            }

            if prompt_message.tool_calls:
                message_kwargs['additional_kwargs'] = {
                    'function_call': {
                        'id': prompt_message.tool_calls[0].id,
                        'name': prompt_message.tool_calls[0].function.name,
                        'arguments': prompt_message.tool_calls[0].function.arguments
                    }
                }

            messages.append(AIMessage(**message_kwargs))
        elif isinstance(prompt_message, SystemPromptMessage):
            messages.append(SystemMessage(content=prompt_message.content))
        elif isinstance(prompt_message, ToolPromptMessage):
            messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))

    return messages
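A minimal usage sketch of the conversion helpers above (illustrative only, not part of this commit). The import path of the helper module itself is not shown in this hunk, so `prompt_message_util` below is a placeholder name; the langchain message classes are assumed to come from `langchain.schema`.

# Illustrative sketch (not part of the commit); the module path is a placeholder.
from langchain.schema import AIMessage

from prompt_message_util import (  # placeholder import path, not shown in this hunk
    ImagePromptMessageFile, LCHumanMessageWithFiles, lc_messages_to_prompt_messages
)

lc_messages = [
    LCHumanMessageWithFiles(
        content="What is in this picture?",
        files=[ImagePromptMessageFile(data="data:image/png;base64,...",
                                      detail=ImagePromptMessageFile.DETAIL.HIGH)]
    ),
    AIMessage(content="It looks like a cat."),
]

# Each HumanMessage with files becomes a UserPromptMessage whose content mixes
# TextPromptMessageContent and ImagePromptMessageContent entries.
prompt_messages = lc_messages_to_prompt_messages(lc_messages)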
71 api/core/entities/model_entities.py Normal file
@@ -0,0 +1,71 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ProviderModel, ModelType
from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderEntity


class ModelStatus(Enum):
    """
    Enum class for model status.
    """
    ACTIVE = "active"
    NO_CONFIGURE = "no-configure"
    QUOTA_EXCEEDED = "quota-exceeded"
    NO_PERMISSION = "no-permission"


class SimpleModelProviderEntity(BaseModel):
    """
    Simple provider.
    """
    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    supported_model_types: list[ModelType]

    def __init__(self, provider_entity: ProviderEntity) -> None:
        """
        Init simple provider.

        :param provider_entity: provider entity
        """
        super().__init__(
            provider=provider_entity.provider,
            label=provider_entity.label,
            icon_small=provider_entity.icon_small,
            icon_large=provider_entity.icon_large,
            supported_model_types=provider_entity.supported_model_types
        )


class ModelWithProviderEntity(ProviderModel):
    """
    Model with provider entity.
    """
    provider: SimpleModelProviderEntity
    status: ModelStatus


class DefaultModelProviderEntity(BaseModel):
    """
    Default model provider entity.
    """
    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    supported_model_types: list[ModelType]


class DefaultModelEntity(BaseModel):
    """
    Default model entity.
    """
    model: str
    model_type: ModelType
    provider: DefaultModelProviderEntity
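A brief construction sketch for the entities above (illustrative only, not part of this commit). The I18nObject field name used below is an assumption; that class lives in the model runtime and its definition is not shown in this hunk.

# Illustrative sketch (not part of the commit); I18nObject field names are assumed.
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType

default_model = DefaultModelEntity(
    model='gpt-3.5-turbo',
    model_type=ModelType.LLM,
    provider=DefaultModelProviderEntity(
        provider='openai',
        label=I18nObject(en_US='OpenAI'),  # field name assumed, not shown in this commit
        supported_model_types=[ModelType.LLM]
    )
)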
657 api/core/entities/provider_configuration.py Normal file
@@ -0,0 +1,657 @@
import datetime
import json
import time
from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator

from pydantic import BaseModel

from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.utils import encoders
from extensions.ext_database import db
from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider


class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.
    """
    tenant_id: str
    provider: ProviderEntity
    preferred_provider_type: ProviderType
    using_provider_type: ProviderType
    system_configuration: SystemConfiguration
    custom_configuration: CustomConfiguration

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        :param model_type: model type
        :param model: model name
        :return:
        """
        if self.using_provider_type == ProviderType.SYSTEM:
            return self.system_configuration.credentials
        else:
            if self.custom_configuration.models:
                for model_configuration in self.custom_configuration.models:
                    if model_configuration.model_type == model_type and model_configuration.model == model:
                        return model_configuration.credentials

            if self.custom_configuration.provider:
                return self.custom_configuration.provider.credentials
            else:
                return None

    def get_system_configuration_status(self) -> SystemConfigurationStatus:
        """
        Get system configuration status.
        :return:
        """
        if self.system_configuration.enabled is False:
            return SystemConfigurationStatus.UNSUPPORTED

        current_quota_type = self.system_configuration.current_quota_type
        current_quota_configuration = next(
            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
            None
        )

        return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
            SystemConfigurationStatus.QUOTA_EXCEEDED

    def is_custom_configuration_available(self) -> bool:
        """
        Check custom configuration available.
        :return:
        """
        return (self.custom_configuration.provider is not None
                or len(self.custom_configuration.models) > 0)

    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Get custom credentials.

        :param obfuscated: obfuscated secret data in credentials
        :return:
        """
        if self.custom_configuration.provider is None:
            return None

        credentials = self.custom_configuration.provider.credentials
        if not obfuscated:
            return credentials

        # Obfuscate credentials
        return self._obfuscated_credentials(
            credentials=credentials,
            credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema else []
        )

    def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
        """
        Validate custom credentials.
        :param credentials: provider credentials
        :return:
        """
        # get provider
        provider_record = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        # Get provider credential secret variables
        provider_credential_secret_variables = self._extract_secret_variables(
            self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema else []
        )

        if provider_record:
            try:
                original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
            except JSONDecodeError:
                original_credentials = {}

            # decrypt credentials
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
                    if value == '[__HIDDEN__]' and key in original_credentials:
                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

        model_provider_factory.provider_credentials_validate(
            self.provider.provider,
            credentials
        )

        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

        return provider_record, credentials

    def add_or_update_custom_credentials(self, credentials: dict) -> None:
        """
        Add or update custom provider credentials.
        :param credentials:
        :return:
        """
        # validate custom provider config
        provider_record, credentials = self.custom_credentials_validate(credentials)

        # save provider
        # Note: Do not switch the preferred provider, which allows users to use quotas first
        if provider_record:
            provider_record.encrypted_config = json.dumps(credentials)
            provider_record.is_valid = True
            provider_record.updated_at = datetime.datetime.utcnow()
            db.session.commit()
        else:
            provider_record = Provider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(credentials),
                is_valid=True
            )
            db.session.add(provider_record)
            db.session.commit()

        self.switch_preferred_provider_type(ProviderType.CUSTOM)

    def delete_custom_credentials(self) -> None:
        """
        Delete custom provider credentials.
        :return:
        """
        # get provider
        provider_record = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        # delete provider
        if provider_record:
            self.switch_preferred_provider_type(ProviderType.SYSTEM)

            db.session.delete(provider_record)
            db.session.commit()

    def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
            -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param obfuscated: obfuscated secret data in credentials
        :return:
        """
        if not self.custom_configuration.models:
            return None

        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                if not obfuscated:
                    return credentials

                # Obfuscate credentials
                return self._obfuscated_credentials(
                    credentials=credentials,
                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                    if self.provider.model_credential_schema else []
                )

        return None

    def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
            -> Tuple[ProviderModel, dict]:
        """
        Validate custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # get provider model
        provider_model_record = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type()
            ).first()

        # Get provider credential secret variables
        provider_credential_secret_variables = self._extract_secret_variables(
            self.provider.model_credential_schema.credential_form_schemas
            if self.provider.model_credential_schema else []
        )

        if provider_model_record:
            try:
                original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
            except JSONDecodeError:
                original_credentials = {}

            # decrypt credentials
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
                    if value == '[__HIDDEN__]' and key in original_credentials:
                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

        model_provider_factory.model_credentials_validate(
            provider=self.provider.provider,
            model_type=model_type,
            model=model,
            credentials=credentials
        )

        model_schema = (
            model_provider_factory.get_provider_instance(self.provider.provider)
            .get_model_instance(model_type)._get_customizable_model_schema(
                model=model,
                credentials=credentials
            )
        )

        if model_schema:
            credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))

        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

        return provider_model_record, credentials

    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
        """
        Add or update custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # validate custom model config
        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

        # save provider model
        # Note: Do not switch the preferred provider, which allows users to use quotas first
        if provider_model_record:
            provider_model_record.encrypted_config = json.dumps(credentials)
            provider_model_record.is_valid = True
            provider_model_record.updated_at = datetime.datetime.utcnow()
            db.session.commit()
        else:
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                encrypted_config=json.dumps(credentials),
                is_valid=True
            )
            db.session.add(provider_model_record)
            db.session.commit()

    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
        """
        Delete custom model credentials.
        :param model_type: model type
        :param model: model name
        :return:
        """
        # get provider model
        provider_model_record = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type()
            ).first()

        # delete provider model
        if provider_model_record:
            db.session.delete(provider_model_record)
            db.session.commit()

    def get_provider_instance(self) -> ModelProvider:
        """
        Get provider instance.
        :return:
        """
        return model_provider_factory.get_provider_instance(self.provider.provider)

    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
        """
        Get current model type instance.

        :param model_type: model type
        :return:
        """
        # Get provider instance
        provider_instance = self.get_provider_instance()

        # Get model instance of LLM
        return provider_instance.get_model_instance(model_type)

    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
        """
        Switch preferred provider type.
        :param provider_type:
        :return:
        """
        if provider_type == self.preferred_provider_type:
            return

        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
            return

        # get preferred provider
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == self.tenant_id,
                TenantPreferredModelProvider.provider_name == self.provider.provider
            ).first()

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = provider_type.value
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
        """
        Extract secret input form variables.

        :param credential_form_schemas:
        :return:
        """
        secret_input_form_variables = []
        for credential_form_schema in credential_form_schemas:
            if credential_form_schema.type == FormType.SECRET_INPUT:
                secret_input_form_variables.append(credential_form_schema.variable)

        return secret_input_form_variables

    def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
        """
        Obfuscated credentials.

        :param credentials: credentials
        :param credential_form_schemas: credential form schemas
        :return:
        """
        # Get provider credential secret variables
        credential_secret_variables = self._extract_secret_variables(
            credential_form_schemas
        )

        # Obfuscate provider credentials
        copy_credentials = credentials.copy()
        for key, value in copy_credentials.items():
            if key in credential_secret_variables:
                copy_credentials[key] = encrypter.obfuscated_token(value)

        return copy_credentials

    def get_provider_model(self, model_type: ModelType,
                           model: str,
                           only_active: bool = False) -> Optional[ModelWithProviderEntity]:
        """
        Get provider model.
        :param model_type: model type
        :param model: model name
        :param only_active: return active model only
        :return:
        """
        provider_models = self.get_provider_models(model_type, only_active)

        for provider_model in provider_models:
            if provider_model.model == model:
                return provider_model

        return None

    def get_provider_models(self, model_type: Optional[ModelType] = None,
                            only_active: bool = False) -> list[ModelWithProviderEntity]:
        """
        Get provider models.
        :param model_type: model type
        :param only_active: only active models
        :return:
        """
        provider_instance = self.get_provider_instance()

        model_types = []
        if model_type:
            model_types.append(model_type)
        else:
            model_types = provider_instance.get_provider_schema().supported_model_types

        if self.using_provider_type == ProviderType.SYSTEM:
            provider_models = self._get_system_provider_models(
                model_types=model_types,
                provider_instance=provider_instance
            )
        else:
            provider_models = self._get_custom_provider_models(
                model_types=model_types,
                provider_instance=provider_instance
            )

        if only_active:
            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

        # resort provider_models
        return sorted(provider_models, key=lambda x: x.model_type.value)

    def _get_system_provider_models(self,
                                    model_types: list[ModelType],
                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
        """
        Get system provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :return:
        """
        provider_models = []
        for model_type in model_types:
            provider_models.extend(
                [
                    ModelWithProviderEntity(
                        **m.dict(),
                        provider=SimpleModelProviderEntity(self.provider),
                        status=ModelStatus.ACTIVE
                    )
                    for m in provider_instance.models(model_type)
                ]
            )

        for quota_configuration in self.system_configuration.quota_configurations:
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_llms = quota_configuration.restrict_llms
            if not restrict_llms:
                break

            # if llm name not in restricted llm list, remove it
            for m in provider_models:
                if m.model_type == ModelType.LLM and m.model not in restrict_llms:
                    m.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    m.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models

    def _get_custom_provider_models(self,
                                    model_types: list[ModelType],
                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
        """
        Get custom provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :return:
        """
        provider_models = []

        credentials = None
        if self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        for model_type in model_types:
            if model_type not in self.provider.supported_model_types:
                continue

            models = provider_instance.models(model_type)
            for m in models:
                provider_models.append(
                    ModelWithProviderEntity(
                        **m.dict(),
                        provider=SimpleModelProviderEntity(self.provider),
                        status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
                    )
                )

        # custom models
        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type not in model_types:
                continue

            custom_model_schema = (
                provider_instance.get_model_instance(model_configuration.model_type)
                .get_customizable_model_schema_from_credentials(
                    model_configuration.model,
                    model_configuration.credentials
                )
            )

            if not custom_model_schema:
                continue

            provider_models.append(
                ModelWithProviderEntity(
                    **custom_model_schema.dict(),
                    provider=SimpleModelProviderEntity(self.provider),
                    status=ModelStatus.ACTIVE
                )
            )

        return provider_models


class ProviderConfigurations(BaseModel):
    """
    Model class for provider configuration dict.
    """
    tenant_id: str
    configurations: Dict[str, ProviderConfiguration] = {}

    def __init__(self, tenant_id: str):
        super().__init__(tenant_id=tenant_id)

    def get_models(self,
                   provider: Optional[str] = None,
                   model_type: Optional[ModelType] = None,
                   only_active: bool = False) \
            -> list[ModelWithProviderEntity]:
        """
        Get available models.

        If preferred provider type is `system`:
          Get the current **system mode** if provider supported,
          if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
          If there is no model configured in custom mode, it is treated as no_configure.
        system > custom > no_configure

        If preferred provider type is `custom`:
          If custom credentials are configured, it is treated as custom mode.
          Otherwise, get the current **system mode** if supported,
          If all system modes are not available (no quota), it is treated as no_configure.
        custom > system > no_configure

        If real mode is `system`, use system credentials to get models,
          paid quotas > provider free quotas > system free quotas
          include pre-defined models (exclude GPT-4, status marked as `no_permission`).
        If real mode is `custom`, use workspace custom credentials to get models,
          include pre-defined models, custom models (manually appended).
        If real mode is `no_configure`, only return pre-defined models from `model runtime`.
          (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
        model status marked as `active` is available.

        :param provider: provider name
        :param model_type: model type
        :param only_active: only active models
        :return:
        """
        all_models = []
        for provider_configuration in self.values():
            if provider and provider_configuration.provider.provider != provider:
                continue

            all_models.extend(provider_configuration.get_provider_models(model_type, only_active))

        return all_models

    def to_list(self) -> List[ProviderConfiguration]:
        """
        Convert to list.

        :return:
        """
        return list(self.values())

    def __getitem__(self, key):
        return self.configurations[key]

    def __setitem__(self, key, value):
        self.configurations[key] = value

    def __iter__(self):
        return iter(self.configurations)

    def values(self) -> Iterator[ProviderConfiguration]:
        return self.configurations.values()

    def get(self, key, default=None):
        return self.configurations.get(key, default)


class ProviderModelBundle(BaseModel):
    """
    Provider model bundle.
    """
    configuration: ProviderConfiguration
    provider_instance: ModelProvider
    model_type_instance: AIModel

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
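A usage sketch for the configuration container above (illustrative only, not part of this commit). It assumes the `configurations` dict has already been populated by the provider manager elsewhere in the codebase.

# Illustrative sketch (not part of the commit): listing active LLMs for one tenant.
from core.entities.provider_configuration import ProviderConfigurations
from core.model_runtime.entities.model_entities import ModelType

configurations = ProviderConfigurations(tenant_id='tenant-1')
# configurations['openai'] = openai_provider_configuration  # populated elsewhere

active_llms = configurations.get_models(
    provider='openai',
    model_type=ModelType.LLM,
    only_active=True
)
for model in active_llms:
    print(model.provider.provider, model.model, model.status.value)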
67 api/core/entities/provider_entities.py Normal file
@@ -0,0 +1,67 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel

from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType


class QuotaUnit(Enum):
    TIMES = 'times'
    TOKENS = 'tokens'


class SystemConfigurationStatus(Enum):
    """
    Enum class for system configuration status.
    """
    ACTIVE = 'active'
    QUOTA_EXCEEDED = 'quota-exceeded'
    UNSUPPORTED = 'unsupported'


class QuotaConfiguration(BaseModel):
    """
    Model class for provider quota configuration.
    """
    quota_type: ProviderQuotaType
    quota_unit: QuotaUnit
    quota_limit: int
    quota_used: int
    is_valid: bool
    restrict_llms: list[str] = []


class SystemConfiguration(BaseModel):
    """
    Model class for provider system configuration.
    """
    enabled: bool
    current_quota_type: Optional[ProviderQuotaType] = None
    quota_configurations: list[QuotaConfiguration] = []
    credentials: Optional[dict] = None


class CustomProviderConfiguration(BaseModel):
    """
    Model class for provider custom configuration.
    """
    credentials: dict


class CustomModelConfiguration(BaseModel):
    """
    Model class for provider custom model configuration.
    """
    model: str
    model_type: ModelType
    credentials: dict


class CustomConfiguration(BaseModel):
    """
    Model class for provider custom configuration.
    """
    provider: Optional[CustomProviderConfiguration] = None
    models: list[CustomModelConfiguration] = []
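A construction sketch for the custom-configuration entities above (illustrative only, not part of this commit); the credential keys and model name are placeholder values.

# Illustrative sketch (not part of the commit): workspace-level custom configuration
# carrying provider credentials plus one manually appended model.
from core.entities.provider_entities import (
    CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration
)
from core.model_runtime.entities.model_entities import ModelType

custom_configuration = CustomConfiguration(
    provider=CustomProviderConfiguration(credentials={'openai_api_key': 'sk-...'}),
    models=[
        CustomModelConfiguration(
            model='my-finetuned-model',
            model_type=ModelType.LLM,
            credentials={'openai_api_key': 'sk-...'}
        )
    ]
)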
118 api/core/entities/queue_entities.py Normal file
@@ -0,0 +1,118 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk


class QueueEvent(Enum):
    """
    QueueEvent enum
    """
    MESSAGE = "message"
    MESSAGE_REPLACE = "message-replace"
    MESSAGE_END = "message-end"
    RETRIEVER_RESOURCES = "retriever-resources"
    ANNOTATION_REPLY = "annotation-reply"
    AGENT_THOUGHT = "agent-thought"
    ERROR = "error"
    PING = "ping"
    STOP = "stop"


class AppQueueEvent(BaseModel):
    """
    QueueEvent entity
    """
    event: QueueEvent


class QueueMessageEvent(AppQueueEvent):
    """
    QueueMessageEvent entity
    """
    event = QueueEvent.MESSAGE
    chunk: LLMResultChunk


class QueueMessageReplaceEvent(AppQueueEvent):
    """
    QueueMessageReplaceEvent entity
    """
    event = QueueEvent.MESSAGE_REPLACE
    text: str


class QueueRetrieverResourcesEvent(AppQueueEvent):
    """
    QueueRetrieverResourcesEvent entity
    """
    event = QueueEvent.RETRIEVER_RESOURCES
    retriever_resources: list[dict]


class AnnotationReplyEvent(AppQueueEvent):
    """
    AnnotationReplyEvent entity
    """
    event = QueueEvent.ANNOTATION_REPLY
    message_annotation_id: str


class QueueMessageEndEvent(AppQueueEvent):
    """
    QueueMessageEndEvent entity
    """
    event = QueueEvent.MESSAGE_END
    llm_result: LLMResult


class QueueAgentThoughtEvent(AppQueueEvent):
    """
    QueueAgentThoughtEvent entity
    """
    event = QueueEvent.AGENT_THOUGHT
    agent_thought_id: str


class QueueErrorEvent(AppQueueEvent):
    """
    QueueErrorEvent entity
    """
    event = QueueEvent.ERROR
    error: Any


class QueuePingEvent(AppQueueEvent):
    """
    QueuePingEvent entity
    """
    event = QueueEvent.PING


class QueueStopEvent(AppQueueEvent):
    """
    QueueStopEvent entity
    """
    class StopBy(Enum):
        """
        Stop by enum
        """
        USER_MANUAL = "user-manual"
        ANNOTATION_REPLY = "annotation-reply"
        OUTPUT_MODERATION = "output-moderation"

    event = QueueEvent.STOP
    stopped_by: StopBy


class QueueMessage(BaseModel):
    """
    QueueMessage entity
    """
    task_id: str
    message_id: str
    conversation_id: str
    app_mode: str
    event: AppQueueEvent
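A minimal sketch of how these queue entities compose (illustrative only, not part of this commit): wrapping a stop event in a QueueMessage and dispatching on the event type on the consumer side. The identifiers are placeholder values.

# Illustrative sketch (not part of the commit).
from core.entities.queue_entities import AppQueueEvent, QueueMessage, QueuePingEvent, QueueStopEvent

message = QueueMessage(
    task_id='task-1',
    message_id='message-1',
    conversation_id='conversation-1',
    app_mode='chat',
    event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
)

def handle(event: AppQueueEvent) -> None:
    if isinstance(event, QueueStopEvent):
        print(f'stopped by {event.stopped_by.value}')
    elif isinstance(event, QueuePingEvent):
        pass  # keep-alive, nothing to do

handle(message.event)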
@@ -14,26 +14,6 @@ class LLMBadRequestError(LLMError):
     description = "Bad Request"
 
 
-class LLMAPIConnectionError(LLMError):
-    """Raised when the LLM returns API connection error."""
-    description = "API Connection Error"
-
-
-class LLMAPIUnavailableError(LLMError):
-    """Raised when the LLM returns API unavailable error."""
-    description = "API Unavailable Error"
-
-
-class LLMRateLimitError(LLMError):
-    """Raised when the LLM returns rate limit error."""
-    description = "Rate Limit Error"
-
-
-class LLMAuthorizationError(LLMError):
-    """Raised when the LLM returns authorization error."""
-    description = "Authorization Error"
-
-
 class ProviderTokenNotInitError(Exception):
     """
     Custom exception raised when the provider token is not initialized.
35 api/core/external_data_tool/weather_search/schema.json Normal file
@@ -0,0 +1,35 @@
{
  "label": {
    "en-US": "Weather Search",
    "zh-Hans": "天气查询"
  },
  "form_schema": [
    {
      "type": "select",
      "label": {
        "en-US": "Temperature Unit",
        "zh-Hans": "温度单位"
      },
      "variable": "temperature_unit",
      "required": true,
      "options": [
        {
          "label": {
            "en-US": "Fahrenheit",
            "zh-Hans": "华氏度"
          },
          "value": "fahrenheit"
        },
        {
          "label": {
            "en-US": "Centigrade",
            "zh-Hans": "摄氏度"
          },
          "value": "centigrade"
        }
      ],
      "default": "centigrade",
      "placeholder": "Please select temperature unit"
    }
  ]
}
45 api/core/external_data_tool/weather_search/weather_search.py Normal file
@@ -0,0 +1,45 @@
from typing import Optional

from core.external_data_tool.base import ExternalDataTool


class WeatherSearch(ExternalDataTool):
    """
    The name of the custom type must be unique and must match the directory and file name.
    """
    name: str = "weather_search"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        schema.json validation. It will be called when the user saves the config.

        Example:
            .. code-block:: python
                config = {
                    "temperature_unit": "centigrade"
                }

        :param tenant_id: the id of workspace
        :param config: the variables of form config
        :return:
        """

        if not config.get('temperature_unit'):
            raise ValueError('temperature unit is required')

    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """
        Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        """
        city = inputs.get('city')
        temperature_unit = self.config.get('temperature_unit')

        if temperature_unit == 'fahrenheit':
            return f'Weather in {city} is 32°F'
        else:
            return f'Weather in {city} is 0°C'
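A sketch of the validation entry point above (illustrative only, not part of this commit). How the tool instance itself is constructed, and how self.config is injected, is handled by the ExternalDataTool base class, which is not shown in this hunk.

# Illustrative sketch (not part of the commit): config validation as it would run on save.
from core.external_data_tool.weather_search.weather_search import WeatherSearch

WeatherSearch.validate_config(
    tenant_id='tenant-1',
    config={'temperature_unit': 'centigrade'}
)  # raises ValueError if temperature_unit is missing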
325 api/core/features/agent_runner.py Normal file
@@ -0,0 +1,325 @@
import logging
from typing import cast, Optional, List

from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, WikipediaQueryRun, Tool
from pydantic import BaseModel, Field

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
    AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message

logger = logging.getLogger(__name__)


class AgentRunnerFeature:
    def __init__(self, tenant_id: str,
                 app_orchestration_config: AppOrchestrationConfigEntity,
                 model_config: ModelConfigEntity,
                 config: AgentEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 user_id: str,
                 agent_llm_callback: AgentLLMCallback,
                 callback: AgentLoopGatherCallbackHandler,
                 memory: Optional[TokenBufferMemory] = None,) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param app_orchestration_config: app orchestration config
        :param model_config: model config
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param agent_llm_callback: agent llm callback
        :param callback: callback
        :param memory: memory
        """
        self.tenant_id = tenant_id
        self.app_orchestration_config = app_orchestration_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.agent_llm_callback = agent_llm_callback
        self.callback = callback
        self.memory = memory

    def run(self, query: str,
            invoke_from: InvokeFrom) -> Optional[str]:
        """
        Retrieve agent loop result.
        :param query: query
        :param invoke_from: invoke from
        :return:
        """
        provider = self.config.provider
        model = self.config.model
        tool_configs = self.config.tools

        # check whether the model supports tool calling
        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model,
            credentials=self.model_config.credentials
        )

        if not model_schema:
            return None

        planning_strategy = PlanningStrategy.REACT
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features \
                    or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.FUNCTION_CALL

        tools = self.to_tools(
            tool_configs=tool_configs,
            invoke_from=invoke_from,
            callbacks=[self.callback, DifyStdOutCallbackHandler()],
        )

        if len(tools) == 0:
            return None

        agent_configuration = AgentConfiguration(
            strategy=planning_strategy,
            model_config=self.model_config,
            tools=tools,
            memory=self.memory,
            max_iterations=10,
            max_execution_time=400.0,
            early_stopping_method="generate",
            agent_llm_callback=self.agent_llm_callback,
            callbacks=[self.callback, DifyStdOutCallbackHandler()]
        )

        agent_executor = AgentExecutor(agent_configuration)

        try:
            # check whether the agent should be used at all
            should_use_agent = agent_executor.should_use_agent(query)
            if not should_use_agent:
                return None

            result = agent_executor.run(query)
            return result.output
        except Exception as ex:
            logger.exception("agent_executor run failed")
            return None

    def to_tools(self, tool_configs: list[AgentToolEntity],
                 invoke_from: InvokeFrom,
                 callbacks: list[BaseCallbackHandler]) \
            -> Optional[List[BaseTool]]:
        """
        Convert tool configs to tools
        :param tool_configs: tool configs
        :param invoke_from: invoke from
        :param callbacks: callbacks
        """
        tools = []
        for tool_config in tool_configs:
            tool = None
            if tool_config.tool_id == "dataset":
                tool = self.to_dataset_retriever_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "web_reader":
                tool = self.to_web_reader_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "google_search":
                tool = self.to_google_search_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "wikipedia":
                tool = self.to_wikipedia_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "current_datetime":
                tool = self.to_current_datetime_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )

            if tool:
                if tool.callbacks is not None:
                    tool.callbacks.extend(callbacks)
                else:
                    tool.callbacks = callbacks

                tools.append(tool)

        return tools

    def to_dataset_retriever_tool(self, tool_config: dict,
                                  invoke_from: InvokeFrom) \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
        :param tool_config: tool config
        :param invoke_from: invoke from
        """
        show_retrieve_source = self.app_orchestration_config.show_retrieve_source

        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=self.queue_manager,
            app_id=self.message.app_id,
            message_id=self.message.id,
            user_id=self.user_id,
            invoke_from=invoke_from
        )

        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        # pass if dataset is not available
        if not dataset:
            return None

        # pass if dataset is not available
        if (dataset and dataset.available_document_count == 0
                and dataset.available_document_count == 0):
            return None

        # get retrieval model config
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }

        retrieval_model_config = dataset.retrieval_model \
            if dataset.retrieval_model else default_retrieval_model

        # get top k
        top_k = retrieval_model_config['top_k']

        # get score threshold
        score_threshold = None
        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
        if score_threshold_enabled:
            score_threshold = retrieval_model_config.get("score_threshold")

        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            top_k=top_k,
            score_threshold=score_threshold,
            hit_callbacks=[hit_callback],
            return_resource=show_retrieve_source,
            retriever_from=invoke_from.to_source()
        )

        return tool

    def to_web_reader_tool(self, tool_config: dict,
                           invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for reading web pages
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        model_parameters = {
            "temperature": 0,
            "max_tokens": 500
        }

        tool = WebReaderTool(
            model_config=self.model_config,
            model_parameters=model_parameters,
            max_chunk_length=4000,
            continue_reading=True
        )

        return tool

    def to_google_search_tool(self, tool_config: dict,
                              invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for performing a Google search and extracting snippets and webpages
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput
        )

        return tool

    def to_current_datetime_tool(self, tool_config: dict,
                                 invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for getting the current date and time
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        return DatetimeTool()

    def to_wikipedia_tool(self, tool_config: dict,
                          invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for searching Wikipedia
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput
        )
119 api/core/features/annotation_reply.py Normal file
@@ -0,0 +1,119 @@
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
from core.embedding.cached_embedding import CacheEmbedding
|
||||||
|
from core.entities.application_entities import InvokeFrom
|
||||||
|
from core.index.vector_index.vector_index import VectorIndex
|
||||||
|
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Message, AppAnnotationSetting, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService

logger = logging.getLogger(__name__)


class AnnotationReplyFeature:
    def query(self, app_record: App,
              message: Message,
              query: str,
              user_id: str,
              invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_record.id).first()

        if not annotation_setting:
            return None

        collection_binding_detail = annotation_setting.collection_binding_detail

        try:
            score_threshold = annotation_setting.score_threshold or 1
            embedding_provider_name = collection_binding_detail.provider_name
            embedding_model_name = collection_binding_detail.model_name

            model_manager = ModelManager()
            model_instance = model_manager.get_model_instance(
                tenant_id=app_record.tenant_id,
                provider=embedding_provider_name,
                model_type=ModelType.TEXT_EMBEDDING,
                model=embedding_model_name
            )

            # get embedding model
            embeddings = CacheEmbedding(model_instance)

            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_provider_name,
                embedding_model_name,
                'annotation'
            )

            dataset = Dataset(
                id=app_record.id,
                tenant_id=app_record.tenant_id,
                indexing_technique='high_quality',
                embedding_model_provider=embedding_provider_name,
                embedding_model=embedding_model_name,
                collection_binding_id=dataset_collection_binding.id
            )

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings,
                attributes=['doc_id', 'annotation_id', 'app_id']
            )

            documents = vector_index.search(
                query=query,
                search_type='similarity_score_threshold',
                search_kwargs={
                    'k': 1,
                    'score_threshold': score_threshold,
                    'filter': {
                        'group_id': [dataset.id]
                    }
                }
            )

            if documents:
                annotation_id = documents[0].metadata['annotation_id']
                score = documents[0].metadata['score']
                annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                if annotation:
                    if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
                        from_source = 'api'
                    else:
                        from_source = 'console'

                    # insert annotation history
                    AppAnnotationService.add_annotation_history(annotation.id,
                                                                app_record.id,
                                                                annotation.question,
                                                                annotation.content,
                                                                query,
                                                                user_id,
                                                                message.id,
                                                                from_source,
                                                                score)

                    return annotation
        except Exception as e:
            logger.warning(f'Query annotation failed, exception: {str(e)}.')
            return None

        return None
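A minimal usage sketch of the annotation reply feature above (illustrative only, not part of this commit; app_record, message and user_id are assumed to come from the request context, and the import path is an assumption based on the features package layout):

    from core.entities.application_entities import InvokeFrom
    from core.features.annotation_reply import AnnotationReplyFeature  # assumed module path

    annotation = AnnotationReplyFeature().query(
        app_record=app_record,      # models.model.App row, loaded elsewhere
        message=message,            # models.model.Message row, loaded elsewhere
        query="How do I reset my password?",
        user_id=user_id,
        invoke_from=InvokeFrom.WEB_APP,
    )
    if annotation:
        print(annotation.content)   # reply taken directly from the matched annotation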
181  api/core/features/dataset_retrieval.py  Normal file
@@ -0,0 +1,181 @@
from typing import cast, Optional, List

from langchain.tools import BaseTool

from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetEntity, ModelConfigEntity, InvokeFrom, DatasetRetrieveConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset


class DatasetRetrievalFeature:
    def retrieve(self, tenant_id: str,
                 model_config: ModelConfigEntity,
                 config: DatasetEntity,
                 query: str,
                 invoke_from: InvokeFrom,
                 show_retrieve_source: bool,
                 hit_callback: DatasetIndexToolCallbackHandler,
                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
        """
        Retrieve dataset.
        :param tenant_id: tenant id
        :param model_config: model config
        :param config: dataset config
        :param query: query
        :param invoke_from: invoke from
        :param show_retrieve_source: show retrieve source
        :param hit_callback: hit callback
        :param memory: memory
        :return:
        """
        dataset_ids = config.dataset_ids
        retrieve_config = config.retrieve_config

        # check model is support tool calling
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model_config.model,
            credentials=model_config.credentials
        )

        if not model_schema:
            return None

        planning_strategy = PlanningStrategy.REACT_ROUTER
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features \
                    or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.ROUTER

        dataset_retriever_tools = self.to_dataset_retriever_tool(
            tenant_id=tenant_id,
            dataset_ids=dataset_ids,
            retrieve_config=retrieve_config,
            return_resource=show_retrieve_source,
            invoke_from=invoke_from,
            hit_callback=hit_callback
        )

        if len(dataset_retriever_tools) == 0:
            return None

        agent_configuration = AgentConfiguration(
            strategy=planning_strategy,
            model_config=model_config,
            tools=dataset_retriever_tools,
            memory=memory,
            max_iterations=10,
            max_execution_time=400.0,
            early_stopping_method="generate"
        )

        agent_executor = AgentExecutor(agent_configuration)

        should_use_agent = agent_executor.should_use_agent(query)
        if not should_use_agent:
            return None

        result = agent_executor.run(query)

        return result.output

    def to_dataset_retriever_tool(self, tenant_id: str,
                                  dataset_ids: list[str],
                                  retrieve_config: DatasetRetrieveConfigEntity,
                                  return_resource: bool,
                                  invoke_from: InvokeFrom,
                                  hit_callback: DatasetIndexToolCallbackHandler) \
            -> Optional[List[BaseTool]]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
        :param tenant_id: tenant id
        :param dataset_ids: dataset ids
        :param retrieve_config: retrieve config
        :param return_resource: return resource
        :param invoke_from: invoke from
        :param hit_callback: hit callback
        """
        tools = []
        available_datasets = []
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == tenant_id,
                Dataset.id == dataset_id
            ).first()

            # pass if dataset is not available
            if not dataset:
                continue

            # pass if dataset is not available
            if (dataset and dataset.available_document_count == 0
                    and dataset.available_document_count == 0):
                continue

            available_datasets.append(dataset)

        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            # get retrieval model config
            default_retrieval_model = {
                'search_method': 'semantic_search',
                'reranking_enable': False,
                'reranking_model': {
                    'reranking_provider_name': '',
                    'reranking_model_name': ''
                },
                'top_k': 2,
                'score_threshold_enabled': False
            }

            for dataset in available_datasets:
                retrieval_model_config = dataset.retrieval_model \
                    if dataset.retrieval_model else default_retrieval_model

                # get top k
                top_k = retrieval_model_config['top_k']

                # get score threshold
                score_threshold = None
                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                if score_threshold_enabled:
                    score_threshold = retrieval_model_config.get("score_threshold")

                tool = DatasetRetrieverTool.from_dataset(
                    dataset=dataset,
                    top_k=top_k,
                    score_threshold=score_threshold,
                    hit_callbacks=[hit_callback],
                    return_resource=return_resource,
                    retriever_from=invoke_from.to_source()
                )

                tools.append(tool)
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            tool = DatasetMultiRetrieverTool.from_dataset(
                dataset_ids=[dataset.id for dataset in available_datasets],
                tenant_id=tenant_id,
                top_k=retrieve_config.top_k or 2,
                score_threshold=(retrieve_config.score_threshold or 0.5)
                if retrieve_config.reranking_model.get('score_threshold_enabled', False) else None,
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source(),
                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
            )

            tools.append(tool)

        return tools
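A rough sketch of how the retrieval feature above is expected to be driven (illustrative only, not part of this commit; model_config_entity, dataset_entity and hit_callback are assumed to be built by the app runner):

    from core.features.dataset_retrieval import DatasetRetrievalFeature

    context = DatasetRetrievalFeature().retrieve(
        tenant_id=tenant_id,
        model_config=model_config_entity,   # ModelConfigEntity from the orchestration config
        config=dataset_entity,              # DatasetEntity carrying dataset_ids and retrieve_config
        query="What is the refund policy?",
        invoke_from=InvokeFrom.SERVICE_API,
        show_retrieve_source=True,
        hit_callback=hit_callback,           # DatasetIndexToolCallbackHandler
    )
    # context is either None (no usable datasets, or the router declined) or text to inject into the prompt.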
96  api/core/features/external_data_fetch.py  Normal file
@@ -0,0 +1,96 @@
import concurrent
import json
import logging

from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Optional

from flask import current_app, Flask

from core.entities.application_entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory

logger = logging.getLogger(__name__)


class ExternalDataFetchFeature:
    def fetch(self, tenant_id: str,
              app_id: str,
              external_data_tools: list[ExternalDataVariableEntity],
              inputs: dict,
              query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        # Group tools by type and config
        grouped_tools = {}
        for tool in external_data_tools:
            tool_key = (tool.type, json.dumps(tool.config, sort_keys=True))
            grouped_tools.setdefault(tool_key, []).append(tool)

        results = {}
        with ThreadPoolExecutor() as executor:
            futures = {}
            for tool in external_data_tools:
                future = executor.submit(
                    self._query_external_data_tool,
                    current_app._get_current_object(),
                    tenant_id,
                    app_id,
                    tool,
                    inputs,
                    query
                )

                futures[future] = tool

            for future in concurrent.futures.as_completed(futures):
                tool_variable, result = future.result()
                results[tool_variable] = result

        inputs.update(results)
        return inputs

    def _query_external_data_tool(self, flask_app: Flask,
                                  tenant_id: str,
                                  app_id: str,
                                  external_data_tool: ExternalDataVariableEntity,
                                  inputs: dict,
                                  query: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Query external data tool.
        :param flask_app: flask app
        :param tenant_id: tenant id
        :param app_id: app id
        :param external_data_tool: external data tool
        :param inputs: inputs
        :param query: query
        :return:
        """
        with flask_app.app_context():
            tool_variable = external_data_tool.variable
            tool_type = external_data_tool.type
            tool_config = external_data_tool.config

            external_data_tool_factory = ExternalDataToolFactory(
                name=tool_type,
                tenant_id=tenant_id,
                app_id=app_id,
                variable=tool_variable,
                config=tool_config
            )

            # query external data tool
            result = external_data_tool_factory.query(
                inputs=inputs,
                query=query
            )

            return tool_variable, result
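A minimal sketch of the fetch flow above (illustrative only, not part of this commit; the variable name, query and tool_config are made-up values assumed to be prepared elsewhere):

    tools = [
        ExternalDataVariableEntity(variable="weather", type="api", config=tool_config),  # config shape depends on the tool type
    ]
    inputs = ExternalDataFetchFeature().fetch(
        tenant_id=tenant_id,
        app_id=app_id,
        external_data_tools=tools,
        inputs={"city": "Berlin"},
        query="Do I need an umbrella today?",
    )
    # Each tool runs on the thread pool with the Flask app context pushed; inputs now also carries a "weather" key.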
32  api/core/features/hosting_moderation.py  Normal file
@@ -0,0 +1,32 @@
import logging

from core.entities.application_entities import ApplicationGenerateEntity
from core.helper import moderation
from core.model_runtime.entities.message_entities import PromptMessage

logger = logging.getLogger(__name__)


class HostingModerationFeature:
    def check(self, application_generate_entity: ApplicationGenerateEntity,
              prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param prompt_messages: prompt messages
        :return:
        """
        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
        model_config = app_orchestration_config.model_config

        text = ""
        for prompt_message in prompt_messages:
            if isinstance(prompt_message.content, str):
                text += prompt_message.content + "\n"

        moderation_result = moderation.check_moderation(
            model_config,
            text
        )

        return moderation_result
50  api/core/features/moderation.py  Normal file
@@ -0,0 +1,50 @@
import logging
from typing import Tuple

from core.entities.application_entities import AppOrchestrationConfigEntity
from core.moderation.base import ModerationAction, ModerationException
from core.moderation.factory import ModerationFactory

logger = logging.getLogger(__name__)


class ModerationFeature:
    def check(self, app_id: str,
              tenant_id: str,
              app_orchestration_config_entity: AppOrchestrationConfigEntity,
              inputs: dict,
              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        if not app_orchestration_config_entity.sensitive_word_avoidance:
            return False, inputs, query

        sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance
        moderation_type = sensitive_word_avoidance_config.type

        moderation_factory = ModerationFactory(
            name=moderation_type,
            app_id=app_id,
            tenant_id=tenant_id,
            config=sensitive_word_avoidance_config.config
        )

        moderation_result = moderation_factory.moderation_for_inputs(inputs, query)

        if not moderation_result.flagged:
            return False, inputs, query

        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationException(moderation_result.preset_response)
        elif moderation_result.action == ModerationAction.OVERRIDED:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return True, inputs, query
@@ -4,7 +4,7 @@ from typing import Optional
 from pydantic import BaseModel

 from core.file.upload_file_parser import UploadFileParser
-from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from extensions.ext_database import db
 from models.model import UploadFile

@@ -50,14 +50,14 @@ class FileObj(BaseModel):
         return self._get_data(force_url=True)

     @property
-    def prompt_message_file(self) -> PromptMessageFile:
+    def prompt_message_content(self) -> ImagePromptMessageContent:
         if self.type == FileType.IMAGE:
             image_config = self.file_config.get('image')

-            return ImagePromptMessageFile(
+            return ImagePromptMessageContent(
                 data=self.data,
-                detail=ImagePromptMessageFile.DETAIL.HIGH
-                if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
+                detail=ImagePromptMessageContent.DETAIL.HIGH
+                if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
             )

     def _get_data(self, force_url: bool = False) -> Optional[str]:
@@ -3,10 +3,10 @@ import logging

 from langchain.schema import OutputParserException

-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, MessageType
-from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_manager import ModelManager
+from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser

 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser

@@ -26,17 +26,22 @@ class LLMGenerator:

         prompt += query + "\n"

-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=1,
-                max_tokens=100
-            )
+            model_type=ModelType.LLM,
         )

-        prompts = [PromptMessage(content=prompt)]
-        response = model_instance.run(prompts)
-        answer = response.content
+        prompts = [UserPromptMessage(content=prompt)]
+        response = model_instance.invoke_llm(
+            prompt_messages=prompts,
+            model_parameters={
+                "max_tokens": 100,
+                "temperature": 1
+            },
+            stream=False
+        )
+        answer = response.message.content

         result_dict = json.loads(answer)
         answer = result_dict['Your Output']

@@ -62,22 +67,28 @@ class LLMGenerator:
         })

         try:
-            model_instance = ModelFactory.get_text_generation_model(
+            model_manager = ModelManager()
+            model_instance = model_manager.get_default_model_instance(
                 tenant_id=tenant_id,
-                model_kwargs=ModelKwargs(
-                    max_tokens=256,
-                    temperature=0
-                )
+                model_type=ModelType.LLM,
             )
-        except ProviderTokenNotInitError:
+        except InvokeAuthorizationError:
             return []

-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]

         try:
-            output = model_instance.run(prompt_messages)
-            questions = output_parser.parse(output.content)
-        except LLMError:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 256,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            questions = output_parser.parse(response.message.content)
+        except InvokeError:
             questions = []
         except Exception as e:
             logging.exception(e)

@@ -105,20 +116,26 @@ class LLMGenerator:
             remove_template_variables=False
         )

-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=512,
-                temperature=0
-            )
+            model_type=ModelType.LLM,
         )

-        prompt_messages = [PromptMessage(content=prompt)]
+        prompt_messages = [UserPromptMessage(content=prompt)]

         try:
-            output = model_instance.run(prompt_messages)
-            rule_config = output_parser.parse(output.content)
-        except LLMError as e:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={
+                    "max_tokens": 512,
+                    "temperature": 0
+                },
+                stream=False
+            )
+
+            rule_config = output_parser.parse(response.message.content)
+        except InvokeError as e:
             raise e
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')

@@ -136,18 +153,24 @@ class LLMGenerator:
     def generate_qa_document(cls, tenant_id: str, query, document_language: str):
         prompt = GENERATOR_QA_PROMPT.format(language=document_language)

-        model_instance = ModelFactory.get_text_generation_model(
+        model_manager = ModelManager()
+        model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=2000
-            )
+            model_type=ModelType.LLM,
         )

-        prompts = [
-            PromptMessage(content=prompt, type=MessageType.SYSTEM),
-            PromptMessage(content=query)
+        prompt_messages = [
+            SystemPromptMessage(content=prompt),
+            UserPromptMessage(content=query)
         ]

-        response = model_instance.run(prompts)
-        answer = response.content
+        response = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={
+                "max_tokens": 2000
+            },
+            stream=False
+        )
+
+        answer = response.message.content
         return answer.strip()
@@ -18,3 +18,17 @@ def encrypt_token(tenant_id: str, token: str):

 def decrypt_token(tenant_id: str, token: str):
     return rsa.decrypt(base64.b64decode(token), tenant_id)
+
+
+def batch_decrypt_token(tenant_id: str, tokens: list[str]):
+    rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)
+
+    return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens]
+
+
+def get_decrypt_decoding(tenant_id: str):
+    return rsa.get_decrypt_decoding(tenant_id)
+
+
+def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
+    return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)
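The new batch helper loads the tenant's RSA decoding material once (via get_decrypt_decoding) and reuses it for every token, instead of performing a full decrypt per token. A rough equivalence (illustrative only; encrypted_tokens is assumed to be a list of base64-encoded strings):

    # one get_decrypt_decoding() call for the whole batch
    plain = batch_decrypt_token(tenant_id, encrypted_tokens)

    # versus one full decrypt (including key loading) per token
    plain = [decrypt_token(tenant_id, t) for t in encrypted_tokens]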
22  api/core/helper/lru_cache.py  Normal file
@@ -0,0 +1,22 @@
from collections import OrderedDict
from typing import Any


class LRUCache:
    def __init__(self, capacity: int):
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key: Any) -> Any:
        if key not in self.cache:
            return None
        else:
            self.cache.move_to_end(key)  # move the key to the end of the OrderedDict
            return self.cache[key]

    def put(self, key: Any, value: Any) -> None:
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)  # pop the first item
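A quick usage sketch of the LRUCache added above (illustrative only, not part of this commit):

    cache = LRUCache(capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")        # returns 1 and marks "a" as most recently used
    cache.put("c", 3)     # exceeds capacity, so "b" (the least recently used key) is evicted
    assert cache.get("b") is None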
@@ -1,18 +1,27 @@
 import logging
 import random

-import openai
-
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.providers.base import BaseModelProvider
-from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
+from core.entities.application_entities import ModelConfigEntity
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
+from extensions.ext_hosting_provider import hosting_configuration
 from models.provider import ProviderType

+logger = logging.getLogger(__name__)
+

-def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
-    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
-        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
-                and model_provider.provider_name in hosted_config.moderation.providers:
+def check_moderation(model_config: ModelConfigEntity, text: str) -> bool:
+    moderation_config = hosting_configuration.moderation_config
+    if (moderation_config and moderation_config.enabled is True
+            and 'openai' in hosting_configuration.provider_map
+            and hosting_configuration.provider_map['openai'].enabled is True
+    ):
+        using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
+        provider_name = model_config.provider
+        if using_provider_type == ProviderType.SYSTEM \
+                and provider_name in moderation_config.providers:
+            hosting_openai_config = hosting_configuration.provider_map['openai']
+
             # 2000 text per chunk
             length = 2000
             text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

@@ -23,14 +32,17 @@ def check_moderation(...):
             text_chunk = random.choice(text_chunks)

             try:
-                moderation_result = openai.Moderation.create(input=text_chunk,
-                                                              api_key=hosted_model_providers.openai.api_key)
+                model_type_instance = OpenAIModerationModel()
+                moderation_result = model_type_instance.invoke(
+                    model='text-moderation-stable',
+                    credentials=hosting_openai_config.credentials,
+                    text=text_chunk
+                )
+
+                if moderation_result is True:
+                    return True
             except Exception as ex:
-                logging.exception(ex)
-                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+                logger.exception(ex)
+                raise InvokeBadRequestError('Rate limit exceeded, please try again later.')

-            for result in moderation_result.results:
-                if result['flagged'] is True:
-                    return False
-
-    return True
+    return False
213  api/core/hosting_configuration.py  Normal file
@@ -0,0 +1,213 @@
import os
from typing import Optional

from flask import Flask
from pydantic import BaseModel

from core.entities.provider_entities import QuotaUnit
from models.provider import ProviderQuotaType


class HostingQuota(BaseModel):
    quota_type: ProviderQuotaType
    restrict_llms: list[str] = []


class TrialHostingQuota(HostingQuota):
    quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
    quota_limit: int = 0
    """Quota limit for the hosting provider models. -1 means unlimited."""


class PaidHostingQuota(HostingQuota):
    quota_type: ProviderQuotaType = ProviderQuotaType.PAID
    stripe_price_id: str = None
    increase_quota: int = 1
    min_quantity: int = 20
    max_quantity: int = 100


class FreeHostingQuota(HostingQuota):
    quota_type: ProviderQuotaType = ProviderQuotaType.FREE


class HostingProvider(BaseModel):
    enabled: bool = False
    credentials: Optional[dict] = None
    quota_unit: Optional[QuotaUnit] = None
    quotas: list[HostingQuota] = []


class HostedModerationConfig(BaseModel):
    enabled: bool = False
    providers: list[str] = []


class HostingConfiguration:
    provider_map: dict[str, HostingProvider] = {}
    moderation_config: HostedModerationConfig = None

    def init_app(self, app: Flask):
        if app.config.get('EDITION') != 'CLOUD':
            return

        self.provider_map["openai"] = self.init_openai()
        self.provider_map["anthropic"] = self.init_anthropic()
        self.provider_map["minimax"] = self.init_minimax()
        self.provider_map["spark"] = self.init_spark()
        self.provider_map["zhipuai"] = self.init_zhipuai()

        self.moderation_config = self.init_moderation_config()

    def init_openai(self) -> HostingProvider:
        quota_unit = QuotaUnit.TIMES
        if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
            credentials = {
                "openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"),
            }

            if os.environ.get("HOSTED_OPENAI_API_BASE"):
                credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE")

            if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"):
                credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION")

            quotas = []
            hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
                trial_quota = TrialHostingQuota(
                    quota_limit=hosted_quota_limit,
                    restrict_llms=[
                        "gpt-3.5-turbo",
                        "gpt-3.5-turbo-1106",
                        "gpt-3.5-turbo-instruct",
                        "gpt-3.5-turbo-16k",
                        "text-davinci-003"
                    ]
                )
                quotas.append(trial_quota)

            if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get(
                    "HOSTED_OPENAI_PAID_ENABLED").lower() == 'true':
                paid_quota = PaidHostingQuota(
                    stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
                    increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
                    min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
                    max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
                )
                quotas.append(paid_quota)

            return HostingProvider(
                enabled=True,
                credentials=credentials,
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_anthropic(self) -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true':
            credentials = {
                "anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"),
            }

            if os.environ.get("HOSTED_ANTHROPIC_API_BASE"):
                credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE")

            quotas = []
            hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
                trial_quota = TrialHostingQuota(
                    quota_limit=hosted_quota_limit
                )
                quotas.append(trial_quota)

            if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get(
                    "HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true':
                paid_quota = PaidHostingQuota(
                    stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
                    increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
                    min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
                    max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
                )
                quotas.append(paid_quota)

            return HostingProvider(
                enabled=True,
                credentials=credentials,
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_minimax(self) -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true':
            quotas = [FreeHostingQuota()]

            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_spark(self) -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true':
            quotas = [FreeHostingQuota()]

            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_zhipuai(self) -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true':
            quotas = [FreeHostingQuota()]

            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_moderation_config(self) -> HostedModerationConfig:
        if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \
                and os.environ.get("HOSTED_MODERATION_PROVIDERS"):
            return HostedModerationConfig(
                enabled=True,
                providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',')
            )

        return HostedModerationConfig(
            enabled=False
        )
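The hosting configuration above is driven entirely by environment variables and only activates for the CLOUD edition. A minimal wiring sketch (illustrative only, not part of this commit; the env values are examples):

    import os
    from flask import Flask
    from core.hosting_configuration import HostingConfiguration

    os.environ["HOSTED_OPENAI_ENABLED"] = "true"
    os.environ["HOSTED_OPENAI_API_KEY"] = "sk-example"

    app = Flask(__name__)
    app.config["EDITION"] = "CLOUD"   # init_app returns early for any other edition

    hosting_configuration = HostingConfiguration()
    hosting_configuration.init_app(app)
    assert hosting_configuration.provider_map["openai"].enabled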
@@ -1,18 +1,12 @@
-import json
-
 from flask import current_app
 from langchain.embeddings import OpenAIEmbeddings

 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargs
-from core.model_providers.models.llm.openai_model import OpenAIModel
-from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from models.dataset import Dataset
-from models.provider import Provider, ProviderType


 class IndexBuilder:

@@ -22,10 +16,12 @@ class IndexBuilder:
         if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
             return None

-        embedding_model = ModelFactory.get_embedding_model(
+        model_manager = ModelManager()
+        embedding_model = model_manager.get_model_instance(
             tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
+            model_type=ModelType.TEXT_EMBEDDING,
+            provider=dataset.embedding_model_provider,
+            model=dataset.embedding_model
         )

         embeddings = CacheEmbedding(embedding_model)
@ -18,9 +18,11 @@ from core.data_loader.loader.notion import NotionLoader
|
||||||
from core.docstore.dataset_docstore import DatasetDocumentStore
|
from core.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
from core.generator.llm_generator import LLMGenerator
|
from core.generator.llm_generator import LLMGenerator
|
||||||
from core.index.index import IndexBuilder
|
from core.index.index import IndexBuilder
|
||||||
from core.model_providers.error import ProviderTokenNotInitError
|
from core.model_manager import ModelManager
|
||||||
from core.model_providers.model_factory import ModelFactory
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
from core.model_providers.models.entity.message import MessageType
|
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
|
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
@ -36,6 +38,7 @@ class IndexingRunner:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
|
self.model_manager = ModelManager()
|
||||||
|
|
||||||
def run(self, dataset_documents: List[DatasetDocument]):
|
def run(self, dataset_documents: List[DatasetDocument]):
|
||||||
"""Run the indexing process."""
|
"""Run the indexing process."""
|
||||||
|
@ -210,7 +213,7 @@ class IndexingRunner:
|
||||||
"""
|
"""
|
||||||
Estimate the indexing for the document.
|
Estimate the indexing for the document.
|
||||||
"""
|
"""
|
||||||
embedding_model = None
|
embedding_model_instance = None
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
dataset = Dataset.query.filter_by(
|
dataset = Dataset.query.filter_by(
|
||||||
id=dataset_id
|
id=dataset_id
|
||||||
|
@ -218,15 +221,17 @@ class IndexingRunner:
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError('Dataset not found.')
|
raise ValueError('Dataset not found.')
|
||||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||||
embedding_model = ModelFactory.get_embedding_model(
|
embedding_model_instance = self.model_manager.get_model_instance(
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=tenant_id,
|
||||||
model_provider_name=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_name=dataset.embedding_model
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=dataset.embedding_model
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if indexing_technique == 'high_quality':
|
if indexing_technique == 'high_quality':
|
||||||
embedding_model = ModelFactory.get_embedding_model(
|
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||||
tenant_id=tenant_id
|
tenant_id=tenant_id,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
)
|
)
|
||||||
tokens = 0
|
tokens = 0
|
||||||
preview_texts = []
|
preview_texts = []
|
||||||
|
@ -255,32 +260,56 @@ class IndexingRunner:
|
||||||
for document in documents:
|
for document in documents:
|
||||||
if len(preview_texts) < 5:
|
if len(preview_texts) < 5:
|
||||||
preview_texts.append(document.page_content)
|
preview_texts.append(document.page_content)
|
||||||
if indexing_technique == 'high_quality' or embedding_model:
|
if indexing_technique == 'high_quality' or embedding_model_instance:
|
||||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||||
|
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||||
|
tokens += embedding_model_type_instance.get_num_tokens(
|
||||||
|
model=embedding_model_instance.model,
|
||||||
|
credentials=embedding_model_instance.credentials,
|
||||||
|
texts=[self.filter_string(document.page_content)]
|
||||||
|
)
|
||||||
|
|
||||||
if doc_form and doc_form == 'qa_model':
|
if doc_form and doc_form == 'qa_model':
|
||||||
text_generation_model = ModelFactory.get_text_generation_model(
|
model_instance = self.model_manager.get_default_model_instance(
|
||||||
tenant_id=tenant_id
|
tenant_id=tenant_id,
|
||||||
|
model_type=ModelType.LLM
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_type_instance = model_instance.model_type_instance
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
|
||||||
if len(preview_texts) > 0:
|
if len(preview_texts) > 0:
|
||||||
# qa model document
|
# qa model document
|
||||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||||
doc_language)
|
doc_language)
|
||||||
document_qa_list = self.format_split_text(response)
|
document_qa_list = self.format_split_text(response)
|
||||||
|
price_info = model_type_instance.get_price(
|
||||||
|
model=model_instance.model,
|
||||||
|
credentials=model_instance.credentials,
|
||||||
|
price_type=PriceType.INPUT,
|
||||||
|
tokens=total_segments * 2000,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"total_segments": total_segments * 20,
|
"total_segments": total_segments * 20,
|
||||||
"tokens": total_segments * 2000,
|
"tokens": total_segments * 2000,
|
||||||
"total_price": '{:f}'.format(
|
"total_price": '{:f}'.format(price_info.total_amount),
|
||||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
"currency": price_info.currency,
|
||||||
"currency": embedding_model.get_currency(),
|
|
||||||
"qa_preview": document_qa_list,
|
"qa_preview": document_qa_list,
|
||||||
"preview": preview_texts
|
"preview": preview_texts
|
||||||
}
|
}
|
||||||
|
if embedding_model_instance:
|
||||||
|
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
|
||||||
|
embedding_price_info = embedding_model_type_instance.get_price(
|
||||||
|
model=embedding_model_instance.model,
|
||||||
|
credentials=embedding_model_instance.credentials,
|
||||||
|
price_type=PriceType.INPUT,
|
||||||
|
tokens=tokens
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"total_segments": total_segments,
|
"total_segments": total_segments,
|
||||||
"tokens": tokens,
|
"tokens": tokens,
|
||||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
"total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
|
||||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
"currency": embedding_price_info.currency if embedding_model_instance else 'USD',
|
||||||
"preview": preview_texts
|
"preview": preview_texts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,7 +319,7 @@ class IndexingRunner:
|
||||||
"""
|
"""
|
||||||
Estimate the indexing for the document.
|
Estimate the indexing for the document.
|
||||||
"""
|
"""
|
||||||
embedding_model = None
|
embedding_model_instance = None
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
dataset = Dataset.query.filter_by(
|
dataset = Dataset.query.filter_by(
|
||||||
id=dataset_id
|
id=dataset_id
|
||||||
|
@ -298,15 +327,17 @@ class IndexingRunner:
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError('Dataset not found.')
|
raise ValueError('Dataset not found.')
|
||||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||||
embedding_model = ModelFactory.get_embedding_model(
|
embedding_model_instance = self.model_manager.get_model_instance(
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=tenant_id,
|
||||||
model_provider_name=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_name=dataset.embedding_model
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=dataset.embedding_model
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if indexing_technique == 'high_quality':
|
if indexing_technique == 'high_quality':
|
||||||
embedding_model = ModelFactory.get_embedding_model(
|
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||||
tenant_id=tenant_id
|
tenant_id=tenant_id,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING
|
||||||
)
|
)
|
||||||
# load data from notion
|
# load data from notion
|
||||||
tokens = 0
|
tokens = 0
|
||||||
|
@ -349,35 +380,63 @@ class IndexingRunner:
|
||||||
processing_rule=processing_rule
|
processing_rule=processing_rule
|
||||||
)
|
)
|
||||||
total_segments += len(documents)
|
total_segments += len(documents)
|
||||||
|
|
||||||
|
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||||
|
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||||
|
|
||||||
for document in documents:
|
for document in documents:
|
||||||
if len(preview_texts) < 5:
|
if len(preview_texts) < 5:
|
||||||
preview_texts.append(document.page_content)
|
preview_texts.append(document.page_content)
|
||||||
if indexing_technique == 'high_quality' or embedding_model:
|
if indexing_technique == 'high_quality' or embedding_model_instance:
|
||||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
tokens += embedding_model_type_instance.get_num_tokens(
|
||||||
|
model=embedding_model_instance.model,
|
||||||
|
credentials=embedding_model_instance.credentials,
|
||||||
|
texts=[document.page_content]
|
||||||
|
)
|
||||||
|
|
||||||
if doc_form and doc_form == 'qa_model':
|
if doc_form and doc_form == 'qa_model':
|
||||||
text_generation_model = ModelFactory.get_text_generation_model(
|
model_instance = self.model_manager.get_default_model_instance(
|
||||||
tenant_id=tenant_id
|
tenant_id=tenant_id,
|
||||||
|
model_type=ModelType.LLM
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_type_instance = model_instance.model_type_instance
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
if len(preview_texts) > 0:
|
if len(preview_texts) > 0:
|
||||||
# qa model document
|
# qa model document
|
||||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||||
doc_language)
|
doc_language)
|
||||||
document_qa_list = self.format_split_text(response)
|
document_qa_list = self.format_split_text(response)
|
||||||
|
|
||||||
|
price_info = model_type_instance.get_price(
|
||||||
|
model=model_instance.model,
|
||||||
|
credentials=model_instance.credentials,
|
||||||
|
price_type=PriceType.INPUT,
|
||||||
|
tokens=total_segments * 2000,
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_segments": total_segments * 20,
|
"total_segments": total_segments * 20,
|
||||||
"tokens": total_segments * 2000,
|
"tokens": total_segments * 2000,
|
||||||
"total_price": '{:f}'.format(
|
"total_price": '{:f}'.format(price_info.total_amount),
|
||||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
"currency": price_info.currency,
|
||||||
"currency": embedding_model.get_currency(),
|
|
||||||
"qa_preview": document_qa_list,
|
"qa_preview": document_qa_list,
|
||||||
"preview": preview_texts
|
"preview": preview_texts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||||
|
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||||
|
embedding_price_info = embedding_model_type_instance.get_price(
|
||||||
|
model=embedding_model_instance.model,
|
||||||
|
credentials=embedding_model_instance.credentials,
|
||||||
|
price_type=PriceType.INPUT,
|
||||||
|
tokens=tokens
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"total_segments": total_segments,
|
"total_segments": total_segments,
|
||||||
"tokens": tokens,
|
"tokens": tokens,
|
||||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
"total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
|
||||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
"currency": embedding_price_info.currency if embedding_model_instance else 'USD',
|
||||||
"preview": preview_texts
|
"preview": preview_texts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -656,25 +715,36 @@ class IndexingRunner:
|
||||||
"""
|
"""
|
||||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||||
embedding_model = None
|
embedding_model_instance = None
|
||||||
if dataset.indexing_technique == 'high_quality':
|
if dataset.indexing_technique == 'high_quality':
|
||||||
embedding_model = ModelFactory.get_embedding_model(
|
embedding_model_instance = self.model_manager.get_model_instance(
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
model_provider_name=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_name=dataset.embedding_model
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=dataset.embedding_model
|
||||||
)
|
)
|
||||||
|
|
||||||
# chunk nodes by chunk size
|
# chunk nodes by chunk size
|
||||||
indexing_start_at = time.perf_counter()
|
indexing_start_at = time.perf_counter()
|
||||||
tokens = 0
|
tokens = 0
|
||||||
chunk_size = 100
|
chunk_size = 100
|
||||||
|
|
||||||
|
embedding_model_type_instance = None
|
||||||
|
if embedding_model_instance:
|
||||||
|
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||||
|
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||||
|
|
||||||
for i in range(0, len(documents), chunk_size):
|
for i in range(0, len(documents), chunk_size):
|
||||||
# check document is paused
|
# check document is paused
|
||||||
self._check_document_paused_status(dataset_document.id)
|
self._check_document_paused_status(dataset_document.id)
|
||||||
chunk_documents = documents[i:i + chunk_size]
|
chunk_documents = documents[i:i + chunk_size]
|
||||||
if dataset.indexing_technique == 'high_quality' or embedding_model:
|
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
|
||||||
tokens += sum(
|
tokens += sum(
|
||||||
embedding_model.get_num_tokens(document.page_content)
|
embedding_model_type_instance.get_num_tokens(
|
||||||
|
embedding_model_instance.model,
|
||||||
|
embedding_model_instance.credentials,
|
||||||
|
[document.page_content]
|
||||||
|
)
|
||||||
for document in chunk_documents
|
for document in chunk_documents
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,95 +0,0 @@
from typing import Any, List, Dict

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage

from core.file.message_file_parser import MessageFileParser
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
from models.model import Conversation, Message


class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
    conversation: Conversation
    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"
    model_instance: BaseLLM
    memory_key: str = "chat_history"
    max_token_limit: int = 2000
    message_limit: int = 10

    @property
    def buffer(self) -> List[BaseMessage]:
        """String buffer of memory."""
        app_model = self.conversation.app

        # fetch limited messages desc, and return reversed
        messages = db.session.query(Message).filter(
            Message.conversation_id == self.conversation.id,
            Message.answer != ''
        ).order_by(Message.created_at.desc()).limit(self.message_limit).all()

        messages = list(reversed(messages))
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)

        chat_messages: List[PromptMessage] = []
        for message in messages:
            files = message.message_files
            if files:
                file_objs = message_file_parser.transform_message_files(
                    files, message.app_model_config
                )

                prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
                chat_messages.append(PromptMessage(
                    content=message.query,
                    type=MessageType.USER,
                    files=prompt_message_files
                ))
            else:
                chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))

            chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))

        if not chat_messages:
            return []

        # prune the chat message if it exceeds the max token limit
        curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.max_token_limit and chat_messages:
                pruned_memory.append(chat_messages.pop(0))
                curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)

        return to_lc_messages(chat_messages)

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        buffer: Any = self.buffer
        if self.return_messages:
            final_buffer: Any = buffer
        else:
            final_buffer = get_buffer_string(
                buffer,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return {self.memory_key: final_buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Nothing should be saved or changed"""
        pass

    def clear(self) -> None:
        """Nothing to clear, got a memory like a vault."""
        pass
@ -1,36 +0,0 @@
from typing import Any, List, Dict

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string

from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory


class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
    memory: ReadOnlyConversationTokenDBBufferSharedMemory

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return self.memory.memory_variables

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load memory variables from memory."""
        buffer: Any = self.memory.buffer

        final_buffer = get_buffer_string(
            buffer,
            human_prefix=self.memory.human_prefix,
            ai_prefix=self.memory.ai_prefix,
        )

        return {self.memory.memory_key: final_buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Nothing should be saved or changed"""
        pass

    def clear(self) -> None:
        """Nothing to clear, got a memory like a vault."""
        pass
109
api/core/memory/token_buffer_memory.py
Normal file
@ -0,0 +1,109 @@
from core.file.message_file_parser import MessageFileParser
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, TextPromptMessageContent, UserPromptMessage, \
    AssistantPromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import model_provider_factory
from extensions.ext_database import db
from models.model import Conversation, Message


class TokenBufferMemory:
    def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
        self.conversation = conversation
        self.model_instance = model_instance

    def get_history_prompt_messages(self, max_token_limit: int = 2000,
                                    message_limit: int = 10) -> list[PromptMessage]:
        """
        Get history prompt messages.
        :param max_token_limit: max token limit
        :param message_limit: message limit
        """
        app_record = self.conversation.app

        # fetch limited messages, and return reversed
        messages = db.session.query(Message).filter(
            Message.conversation_id == self.conversation.id,
            Message.answer != ''
        ).order_by(Message.created_at.desc()).limit(message_limit).all()

        messages = list(reversed(messages))
        message_file_parser = MessageFileParser(
            tenant_id=app_record.tenant_id,
            app_id=app_record.id
        )

        prompt_messages = []
        for message in messages:
            files = message.message_files
            if files:
                file_objs = message_file_parser.transform_message_files(
                    files, message.app_model_config
                )

                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
                for file_obj in file_objs:
                    prompt_message_contents.append(file_obj.prompt_message_content)

                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
            else:
                prompt_messages.append(UserPromptMessage(content=message.query))

            prompt_messages.append(AssistantPromptMessage(content=message.answer))

        if not prompt_messages:
            return []

        # prune the chat message if it exceeds the max token limit
        provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)

        curr_message_tokens = model_type_instance.get_num_tokens(
            self.model_instance.model,
            self.model_instance.credentials,
            prompt_messages
        )

        if curr_message_tokens > max_token_limit:
            pruned_memory = []
            while curr_message_tokens > max_token_limit and prompt_messages:
                pruned_memory.append(prompt_messages.pop(0))
                curr_message_tokens = model_type_instance.get_num_tokens(
                    self.model_instance.model,
                    self.model_instance.credentials,
                    prompt_messages
                )

        return prompt_messages

    def get_history_prompt_text(self, human_prefix: str = "Human",
                                ai_prefix: str = "Assistant",
                                max_token_limit: int = 2000,
                                message_limit: int = 10) -> str:
        """
        Get history prompt text.
        :param human_prefix: human prefix
        :param ai_prefix: ai prefix
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """
        prompt_messages = self.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit
        )

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            message = f"{role}: {m.content}"
            string_messages.append(message)

        return "\n".join(string_messages)
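As a rough usage sketch of the new memory class: it is constructed from a Conversation row plus a ModelInstance (see api/core/model_manager.py below) and can return either structured prompt messages or a flat transcript. The provider and model values here are placeholders, not values from this commit:

# Hypothetical usage sketch; provider and model name are placeholders.
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.model import Conversation

# Assumes an existing conversation row and a tenant with OpenAI credentials configured.
conversation = db.session.query(Conversation).first()

model_instance = ModelManager().get_model_instance(
    tenant_id=conversation.app.tenant_id,
    provider="openai",                 # placeholder provider name
    model_type=ModelType.LLM,
    model="gpt-3.5-turbo"              # placeholder model name
)

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

# Structured prompt messages, pruned to roughly max_token_limit tokens of history.
history_messages = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=10)

# Or a flat "Human: ... / Assistant: ..." transcript for completion-style prompts.
history_text = memory.get_history_prompt_text(human_prefix="Human", ai_prefix="Assistant")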
209
api/core/model_manager.py
Normal file
@ -0,0 +1,209 @@
from typing import Optional, Union, Generator, cast, List, IO

from core.entities.provider_configuration import ProviderModelBundle
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.provider_manager import ProviderManager


class ModelInstance:
    """
    Model instance class
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
        self._provider_model_bundle = provider_model_bundle
        self.model = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        self.model_type_instance = self._provider_model_bundle.model_type_instance

    def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
        """
        Fetch credentials from provider model bundle
        :param provider_model_bundle: provider model bundle
        :param model: model name
        :return:
        """
        credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=provider_model_bundle.model_type_instance.model_type,
            model=model
        )

        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")

        return credentials

    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                   stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception(f"Model type instance is not LargeLanguageModel")

        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks
        )

    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception(f"Model type instance is not TextEmbeddingModel")

        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            texts=texts,
            user=user
        )

    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                      user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if not isinstance(self.model_type_instance, RerankModel):
            raise Exception(f"Model type instance is not RerankModel")

        self.model_type_instance = cast(RerankModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
            user=user
        )

    def invoke_moderation(self, text: str, user: Optional[str] = None) \
            -> bool:
        """
        Invoke moderation model

        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        """
        if not isinstance(self.model_type_instance, ModerationModel):
            raise Exception(f"Model type instance is not ModerationModel")

        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            text=text,
            user=user
        )

    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech-to-text model

        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        if not isinstance(self.model_type_instance, Speech2TextModel):
            raise Exception(f"Model type instance is not Speech2TextModel")

        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            file=file,
            user=user
        )


class ModelManager:
    def __init__(self) -> None:
        self._provider_manager = ProviderManager()

    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
        """
        Get model instance
        :param tenant_id: tenant id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :return:
        """
        provider_model_bundle = self._provider_manager.get_provider_model_bundle(
            tenant_id=tenant_id,
            provider=provider,
            model_type=model_type
        )

        return ModelInstance(provider_model_bundle, model)

    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
        """
        Get default model instance
        :param tenant_id: tenant id
        :param model_type: model type
        :return:
        """
        default_model_entity = self._provider_manager.get_default_model(
            tenant_id=tenant_id,
            model_type=model_type
        )

        if not default_model_entity:
            raise ProviderTokenNotInitError(f"Default model not found for {model_type}")

        return self.get_model_instance(
            tenant_id=tenant_id,
            provider=default_model_entity.provider.provider,
            model_type=model_type,
            model=default_model_entity.model
        )
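A hedged sketch of how the new manager layer might be used end to end: ModelManager resolves a tenant's configured credentials into a ModelInstance, which then dispatches to the type-specific runtime model. The tenant id, provider and model name below are placeholders, and the final line assumes LLMResult exposes the assistant reply as .message:

# Hypothetical usage sketch; tenant id, provider and model name are placeholders.
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType

model_instance = ModelManager().get_model_instance(
    tenant_id="tenant-id",
    provider="openai",
    model_type=ModelType.LLM,
    model="gpt-3.5-turbo"
)

# Non-streaming call; with stream=True the same method returns a chunk generator instead.
result = model_instance.invoke_llm(
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7, "max_tokens": 256},
    stream=False
)

print(result.message.content)  # assumes LLMResult carries the assistant reply as .message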
@ -1,335 +0,0 @@
from typing import Optional

from langchain.callbacks.base import Callbacks

from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel


class ModelFactory:

    @classmethod
    def get_text_generation_model_from_model_config(cls, tenant_id: str,
                                                    model_config: dict,
                                                    streaming: bool = False,
                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
        provider_name = model_config.get("provider")
        model_name = model_config.get("name")
        completion_params = model_config.get("completion_params", {})

        return cls.get_text_generation_model(
            tenant_id=tenant_id,
            model_provider_name=provider_name,
            model_name=model_name,
            model_kwargs=ModelKwargs(
                temperature=completion_params.get('temperature', 0),
                max_tokens=completion_params.get('max_tokens', 256),
                top_p=completion_params.get('top_p', 0),
                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
                presence_penalty=completion_params.get('presence_penalty', 0.1)
            ),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_text_generation_model(cls,
                                  tenant_id: str,
                                  model_provider_name: Optional[str] = None,
                                  model_name: Optional[str] = None,
                                  model_kwargs: Optional[ModelKwargs] = None,
                                  streaming: bool = False,
                                  callbacks: Callbacks = None,
                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
        """
        get text generation model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :param model_kwargs:
        :param streaming:
        :param callbacks:
        :param deduct_quota:
        :return:
        """
        is_default_model = False
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default System Reasoning Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name
            is_default_model = True

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init text generation model
        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)

        try:
            model_instance = model_class(
                model_provider=model_provider,
                name=model_name,
                model_kwargs=model_kwargs,
                streaming=streaming,
                callbacks=callbacks
            )
        except LLMBadRequestError as e:
            if is_default_model:
                raise LLMBadRequestError(f"Default model {model_name} is not available. "
                                         f"Please check your model provider credentials.")
            else:
                raise e

        if is_default_model or not deduct_quota:
            model_instance.deduct_quota = False

        return model_instance

    @classmethod
    def get_embedding_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
        """
        get embedding model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Embedding Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init embedding model
        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_reranking_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
        """
        get reranking model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        if (model_provider_name is None or len(model_provider_name) == 0) and (model_name is None or len(model_name) == 0):
            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Reranking Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init reranking model
        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_speech2text_model(cls,
                              tenant_id: str,
                              model_provider_name: Optional[str] = None,
                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
        """
        get speech to text model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Speech-to-Text Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init speech to text model
        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_moderation_model(cls,
                             tenant_id: str,
                             model_provider_name: str,
                             model_name: str) -> Optional[BaseModeration]:
        """
        get moderation model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init moderation model
        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
        """
        get default model of model type.

        :param tenant_id:
        :param model_type:
        :return:
        """
        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
                TenantDefaultModel.tenant_id == tenant_id,
                TenantDefaultModel.model_type == model_type.value
            ).first()

        if not default_model:
            model_provider_rules = ModelProviderFactory.get_provider_rules()
            for model_provider_name, model_provider_rule in model_provider_rules.items():
                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
                if not model_provider:
                    continue

                model_list = model_provider.get_supported_model_list(model_type)
                if model_list:
                    model_info = model_list[0]
                    default_model = TenantDefaultModel(
                        tenant_id=tenant_id,
                        model_type=model_type.value,
                        provider_name=model_provider_name,
                        model_name=model_info['id']
                    )
                    db.session.add(default_model)
                    db.session.commit()
                    break

        return default_model

    @classmethod
    def update_default_model(cls,
                             tenant_id: str,
                             model_type: ModelType,
                             provider_name: str,
                             model_name: str) -> TenantDefaultModel:
        """
        update default model of model type.

        :param tenant_id:
        :param model_type:
        :param provider_name:
        :param model_name:
        :return:
        """
        model_provider_name = ModelProviderFactory.get_provider_names()
        if provider_name not in model_provider_name:
            raise ValueError(f'Invalid provider name: {provider_name}')

        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        model_list = model_provider.get_supported_model_list(model_type)
        model_ids = [model['id'] for model in model_list]
        if model_name not in model_ids:
            raise ValueError(f'Invalid model name: {model_name}')

        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
                TenantDefaultModel.tenant_id == tenant_id,
                TenantDefaultModel.model_type == model_type.value
            ).first()

        if default_model:
            # update default model
            default_model.provider_name = provider_name
            default_model.model_name = model_name
            db.session.commit()
        else:
            # create default model
            default_model = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.value,
                provider_name=provider_name,
                model_name=model_name,
            )
            db.session.add(default_model)
            db.session.commit()

        return default_model
@ -1,276 +0,0 @@
from typing import Type

from sqlalchemy.exc import IntegrityError

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType

DEFAULT_MODELS = {
    ModelType.TEXT_GENERATION.value: {
        'provider_name': 'openai',
        'model_name': 'gpt-3.5-turbo',
    },
    ModelType.EMBEDDINGS.value: {
        'provider_name': 'openai',
        'model_name': 'text-embedding-ada-002',
    },
    ModelType.SPEECH_TO_TEXT.value: {
        'provider_name': 'openai',
        'model_name': 'whisper-1',
    }
}


class ModelProviderFactory:
    @classmethod
    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
        if provider_name == 'openai':
            from core.model_providers.providers.openai_provider import OpenAIProvider
            return OpenAIProvider
        elif provider_name == 'anthropic':
            from core.model_providers.providers.anthropic_provider import AnthropicProvider
            return AnthropicProvider
        elif provider_name == 'minimax':
            from core.model_providers.providers.minimax_provider import MinimaxProvider
            return MinimaxProvider
        elif provider_name == 'spark':
            from core.model_providers.providers.spark_provider import SparkProvider
            return SparkProvider
        elif provider_name == 'tongyi':
            from core.model_providers.providers.tongyi_provider import TongyiProvider
            return TongyiProvider
        elif provider_name == 'wenxin':
            from core.model_providers.providers.wenxin_provider import WenxinProvider
            return WenxinProvider
        elif provider_name == 'zhipuai':
            from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
            return ZhipuAIProvider
        elif provider_name == 'chatglm':
            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
            return ChatGLMProvider
        elif provider_name == 'baichuan':
            from core.model_providers.providers.baichuan_provider import BaichuanProvider
            return BaichuanProvider
        elif provider_name == 'azure_openai':
            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
            return AzureOpenAIProvider
        elif provider_name == 'replicate':
            from core.model_providers.providers.replicate_provider import ReplicateProvider
            return ReplicateProvider
        elif provider_name == 'huggingface_hub':
            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
            return HuggingfaceHubProvider
        elif provider_name == 'xinference':
            from core.model_providers.providers.xinference_provider import XinferenceProvider
            return XinferenceProvider
        elif provider_name == 'openllm':
            from core.model_providers.providers.openllm_provider import OpenLLMProvider
            return OpenLLMProvider
        elif provider_name == 'localai':
            from core.model_providers.providers.localai_provider import LocalAIProvider
            return LocalAIProvider
        elif provider_name == 'cohere':
            from core.model_providers.providers.cohere_provider import CohereProvider
            return CohereProvider
        elif provider_name == 'jina':
            from core.model_providers.providers.jina_provider import JinaProvider
            return JinaProvider
        else:
            raise NotImplementedError

    @classmethod
    def get_provider_names(cls):
        """
        Returns a list of provider names.
        """
        return list(provider_rules.keys())

    @classmethod
    def get_provider_rules(cls):
        """
        Returns a list of provider rules.

        :return:
        """
        return provider_rules

    @classmethod
    def get_provider_rule(cls, provider_name: str):
        """
        Returns provider rule.
        """
        return provider_rules[provider_name]

    @classmethod
    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred model provider.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :return:
        """
        # get preferred provider
        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
        if not preferred_provider or not preferred_provider.is_valid:
            return None

        # init model provider
        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
        return model_provider_class(provider=preferred_provider)

    @classmethod
    def get_preferred_type_by_preferred_model_provider(cls,
                                                       tenant_id: str,
                                                       model_provider_name: str,
                                                       preferred_model_provider: TenantPreferredModelProvider):
        """
        get preferred provider type by preferred model provider.

        :param model_provider_name:
        :param preferred_model_provider:
        :return:
        """
        if not preferred_model_provider:
            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            support_provider_types = model_provider_rules['support_provider_types']

            if ProviderType.CUSTOM.value in support_provider_types:
                custom_provider = db.session.query(Provider) \
                    .filter(
                        Provider.tenant_id == tenant_id,
                        Provider.provider_name == model_provider_name,
                        Provider.provider_type == ProviderType.CUSTOM.value,
                        Provider.is_valid == True
                    ).first()

                if custom_provider:
                    return ProviderType.CUSTOM.value

            model_provider = cls.get_model_provider_class(model_provider_name)

            if ProviderType.SYSTEM.value in support_provider_types \
                    and model_provider.is_provider_type_system_supported():
                return ProviderType.SYSTEM.value
            elif ProviderType.CUSTOM.value in support_provider_types:
                return ProviderType.CUSTOM.value
        else:
            return preferred_model_provider.preferred_provider_type

    @classmethod
    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred provider of tenant.

        :param tenant_id:
        :param model_provider_name:
        :return:
        """
        # get preferred provider type
        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)

        # get providers by preferred provider type
        providers = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == model_provider_name,
                Provider.provider_type == preferred_provider_type
            ).all()

        no_system_provider = False
        if preferred_provider_type == ProviderType.SYSTEM.value:
            quota_type_to_provider_dict = {}
            for provider in providers:
                quota_type_to_provider_dict[provider.quota_type] = provider

            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            for quota_type_enum in ProviderQuotaType:
                quota_type = quota_type_enum.value
                if quota_type in model_provider_rules['system_config']['supported_quota_types']:
                    if quota_type in quota_type_to_provider_dict.keys():
                        provider = quota_type_to_provider_dict[quota_type]
                        if provider.is_valid and provider.quota_limit > provider.quota_used:
                            return provider
                    elif quota_type == ProviderQuotaType.TRIAL.value:
                        try:
                            provider = Provider(
                                tenant_id=tenant_id,
                                provider_name=model_provider_name,
                                provider_type=ProviderType.SYSTEM.value,
                                is_valid=True,
                                quota_type=ProviderQuotaType.TRIAL.value,
                                quota_limit=model_provider_rules['system_config']['quota_limit'],
                                quota_used=0
                            )
                            db.session.add(provider)
                            db.session.commit()
                        except IntegrityError:
                            db.session.rollback()
                            provider = db.session.query(Provider) \
                                .filter(
                                    Provider.tenant_id == tenant_id,
                                    Provider.provider_name == model_provider_name,
                                    Provider.provider_type == ProviderType.SYSTEM.value,
                                    Provider.quota_type == ProviderQuotaType.TRIAL.value
                                ).first()

                            if provider.quota_limit == 0:
                                return None

                        return provider

            no_system_provider = True

        if no_system_provider:
            providers = db.session.query(Provider) \
                .filter(
                    Provider.tenant_id == tenant_id,
                    Provider.provider_name == model_provider_name,
                    Provider.provider_type == ProviderType.CUSTOM.value
                ).all()

        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
            if providers:
                return providers[0]
            else:
                try:
                    provider = Provider(
                        tenant_id=tenant_id,
                        provider_name=model_provider_name,
                        provider_type=ProviderType.CUSTOM.value,
                        is_valid=False
                    )
                    db.session.add(provider)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    provider = db.session.query(Provider) \
                        .filter(
                            Provider.tenant_id == tenant_id,
                            Provider.provider_name == model_provider_name,
                            Provider.provider_type == ProviderType.CUSTOM.value
                        ).first()

                return provider

        return None

    @classmethod
    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred provider type of tenant.

        :param tenant_id:
        :param model_provider_name:
        :return:
        """
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name == model_provider_name
            ).first()

        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
Some files were not shown because too many files have changed in this diff.