Feat/dataset service api (#1245)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong 2023-09-27 16:06:32 +08:00 committed by GitHub
parent 54ff03c35d
commit 46154c6705
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 1636 additions and 906 deletions

View File

@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)

View File

@ -19,41 +19,13 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
from libs.helper import TimestampField from libs.helper import TimestampField
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppModelConfig, Site from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
app_detail_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'created_at': TimestampField
}
def _get_app(app_id, tenant_id): def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@ -63,35 +35,6 @@ def _get_app(app_id, tenant_id):
class AppListApi(Resource): class AppListApi(Resource):
prompt_config_fields = {
'prompt_template': fields.String,
}
model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
'created_at': TimestampField
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
}
@setup_required @setup_required
@login_required @login_required
@ -238,18 +181,6 @@ class AppListApi(Resource):
class AppTemplateApi(Resource): class AppTemplateApi(Resource):
template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
}
template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
}
@setup_required @setup_required
@login_required @login_required
@ -268,38 +199,6 @@ class AppTemplateApi(Resource):
class AppApi(Resource): class AppApi(Resource):
site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
}
app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField
}
@setup_required @setup_required
@login_required @login_required

View File

@ -13,107 +13,14 @@ from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app import _get_app
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
conversation_message_detail_fields, conversation_with_summary_pagination_fields
from libs.helper import TimestampField, datetime_string, uuid_value from libs.helper import TimestampField, datetime_string, uuid_value
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Message, MessageAnnotation, Conversation from models.model import Message, MessageAnnotation, Conversation
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
feedback_stat_fields = {
'like': fields.Integer,
'dislike': fields.Integer
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'model': fields.Raw,
'user_input_form': fields.Raw,
'pre_prompt': fields.String,
'agent_mode': fields.Raw,
}
class CompletionConversationApi(Resource): class CompletionConversationApi(Resource):
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]['text'] if value else ''
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
simple_message_detail_fields = {
'inputs': fields.Raw,
'query': fields.String,
'message': MessageTextField,
'answer': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
@setup_required @setup_required
@login_required @login_required
@ -191,21 +98,11 @@ class CompletionConversationApi(Resource):
class CompletionConversationDetailApi(Resource): class CompletionConversationDetailApi(Resource):
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'model_config': fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(conversation_detail_fields) @marshal_with(conversation_message_detail_fields)
def get(self, app_id, conversation_id): def get(self, app_id, conversation_id):
app_id = str(app_id) app_id = str(app_id)
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource):
class ChatConversationApi(Resource): class ChatConversationApi(Resource):
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(conversation_pagination_fields) @marshal_with(conversation_with_summary_pagination_fields)
def get(self, app_id): def get(self, app_id):
app_id = str(app_id) app_id = str(app_id)
@ -356,19 +220,6 @@ class ChatConversationApi(Resource):
class ChatConversationDetailApi(Resource): class ChatConversationDetailApi(Resource):
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
@setup_required @setup_required
@login_required @login_required

View File

@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.login.login import login_required from core.login.login import login_required
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value, TimestampField from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db from extensions.ext_database import db
@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.message_service import MessageService from services.message_service import MessageService
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
class ChatMessageListApi(Resource): class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = { message_infinite_scroll_pagination_fields = {

View File

@ -8,26 +8,11 @@ from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app import _get_app
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from fields.app_fields import app_site_fields
from libs.helper import supported_language from libs.helper import supported_language
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Site from models.model import Site
app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
def parse_app_site_args(): def parse_app_site_args():
parser = reqparse.RequestParser() parser = reqparse.RequestParser()

View File

@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required
from core.data_loader.loader.notion import NotionLoader from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
from libs.helper import TimestampField from libs.helper import TimestampField
from models.dataset import Document from models.dataset import Document
from models.source import DataSourceBinding from models.source import DataSourceBinding
@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30)
class DataSourceApi(Resource): class DataSourceApi(Resource):
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}
@setup_required @setup_required
@login_required @login_required
@ -131,28 +101,6 @@ class DataSourceApi(Resource):
class DataSourceNotionListApi(Resource): class DataSourceNotionListApi(Resource):
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean,
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
@setup_required @setup_required
@login_required @login_required

View File

@ -1,6 +1,9 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from flask import request import flask_restful
from flask import request, current_app
from flask_login import current_user from flask_login import current_user
from controllers.console.apikey import api_key_list, api_key_fields
from core.login.login import login_required from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden from werkzeug.exceptions import NotFound, Forbidden
@ -12,45 +15,16 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType from core.model_providers.models.entity.model_params import ModelType
from libs.helper import TimestampField from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, Document from models.dataset import DocumentSegment, Document
from models.model import UploadFile from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService from services.provider_service import ProviderService
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
}
def _validate_name(name): def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
@ -82,7 +56,8 @@ class DatasetListApi(Resource):
# check embedding setting # check embedding setting
provider_service = ProviderService() provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0: # if len(valid_model_list) == 0:
# raise ProviderNotInitializeError( # raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider " # f"No Embedding Model available. Please configure a valid provider "
@ -157,7 +132,8 @@ class DatasetApi(Resource):
# check embedding setting # check embedding setting
provider_service = ProviderService() provider_service = ProviderService()
# get valid model list # get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
model_names = [] model_names = []
for valid_model in valid_model_list: for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource):
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json') parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource):
class DatasetRelatedAppListApi(Resource): class DatasetRelatedAppListApi(Resource):
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}
@setup_required @setup_required
@login_required @login_required
@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource):
class DatasetIndexingStatusApi(Resource): class DatasetIndexingStatusApi(Resource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
@setup_required @setup_required
@login_required @login_required
@ -400,16 +347,97 @@ class DatasetIndexingStatusApi(Resource):
DocumentSegment.status != 're_segment').count() DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
documents_status.append(marshal(document, self.document_status_fields)) documents_status.append(marshal(document, document_status_fields))
data = { data = {
'data': documents_status 'data': documents_status
} }
return data return data
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = 'dataset-'
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
all()
return {"items": keys}
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
count()
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code='max_keys_exceeded'
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, api_key_id):
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
key = db.session.query(ApiToken). \
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id). \
first()
if key is None:
flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {'result': 'success'}, 204
class DatasetApiBaseUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
return {
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
else request.host_url.rstrip('/')) + '/v1'
}
api.add_resource(DatasetListApi, '/datasets') api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>') api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries') api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps') api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status') api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')

View File

@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
LLMBadRequestError LLMBadRequestError
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
from libs.helper import TimestampField from libs.helper import TimestampField
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset from models.dataset import DatasetProcessRule, Dataset
@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService
from tasks.add_document_to_index_task import add_document_to_index_task from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'completed_segments': fields.Integer,
'total_segments': fields.Integer
}
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)),
'batch': fields.String
}
@setup_required @setup_required
@login_required @login_required
@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
class DocumentBatchIndexingStatusApi(DocumentResource): class DocumentBatchIndexingStatusApi(DocumentResource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
@setup_required @setup_required
@login_required @login_required
@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:
document.indexing_status = 'paused' document.indexing_status = 'paused'
documents_status.append(marshal(document, self.document_status_fields)) documents_status.append(marshal(document, document_status_fields))
data = { data = {
'data': documents_status 'data': documents_status
} }
@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
class DocumentIndexingStatusApi(DocumentResource): class DocumentIndexingStatusApi(DocumentResource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
@setup_required @setup_required
@login_required @login_required
@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource):
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:
document.indexing_status = 'paused' document.indexing_status = 'paused'
return marshal(document, self.document_status_fields) return marshal(document, document_status_fields)
class DocumentDetailApi(DocumentResource): class DocumentDetailApi(DocumentResource):

View File

@ -3,7 +3,7 @@ import uuid
from datetime import datetime from datetime import datetime
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse, fields, marshal from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import NotFound, Forbidden from werkzeug.exceptions import NotFound, Forbidden
import services import services
@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required from core.login.login import login_required
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
from libs.helper import TimestampField from libs.helper import TimestampField
@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
import pandas as pd import pandas as pd
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField
}
segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean,
'limit': fields.Integer
}
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@setup_required @setup_required

View File

@ -1,28 +1,19 @@
import datetime
import hashlib
import tempfile
import chardet
import time
import uuid
from pathlib import Path
from cachetools import TTLCache from cachetools import TTLCache
from flask import request, current_app from flask import request, current_app
from flask_login import current_user
import services
from core.login.login import login_required from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields from flask_restful import Resource, marshal_with, fields
from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError UnsupportedFileTypeError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.data_loader.file_extractor import FileExtractor from fields.file_fields import upload_config_fields, file_fields
from extensions.ext_storage import storage
from libs.helper import TimestampField from services.file_service import FileService
from extensions.ext_database import db
from models.model import UploadFile
cache = TTLCache(maxsize=None, ttl=30) cache = TTLCache(maxsize=None, ttl=30)
@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource): class FileApi(Resource):
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
}
@setup_required @setup_required
@login_required @login_required
@ -48,16 +35,6 @@ class FileApi(Resource):
'batch_count_limit': batch_count_limit 'batch_count_limit': batch_count_limit
}, 200 }, 200
file_fields = {
'id': fields.String,
'name': fields.String,
'size': fields.Integer,
'extension': fields.String,
'mime_type': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -73,45 +50,13 @@ class FileApi(Resource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
try:
file_content = file.read() upload_file = FileService.upload_file(file)
file_size = len(file_content) except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024 except services.errors.file.UnsupportedFileTypeError:
if file_size > file_size_limit:
message = "({file_size} > {file_size_limit})"
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
# save file to storage
storage.save(file_key, file_content)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
)
db.session.add(upload_file)
db.session.commit()
return upload_file, 201 return upload_file, 201
@ -121,26 +66,7 @@ class FilePreviewApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
text = FileService.get_file_preview(file_id)
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found")
# extract text from file
extension = upload_file.extension
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return {'content': text} return {'content': text}

View File

@ -2,7 +2,7 @@ import logging
from flask_login import current_user from flask_login import current_user
from core.login.login import login_required from core.login.login import login_required
from flask_restful import Resource, reqparse, marshal, fields from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
import services import services
@ -14,48 +14,10 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError LLMBadRequestError
from libs.helper import TimestampField from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
document_fields = {
'id': fields.String,
'data_source_type': fields.String,
'name': fields.String,
'doc_type': fields.String,
}
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'document': fields.Nested(document_fields),
}
hit_testing_record_fields = {
'segment': fields.Nested(segment_fields),
'score': fields.Float,
'tsne_position': fields.Raw
}
class HitTestingApi(Resource): class HitTestingApi(Resource):

View File

@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != 'chat': if app_model.mode != 'chat':

View File

@ -11,32 +11,11 @@ from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.helper import TimestampField from libs.helper import TimestampField
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
app_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
}
installed_app_fields = {
'id': fields.String,
'app': fields.Nested(app_fields),
'app_owner_tenant_id': fields.String,
'is_pinned': fields.Boolean,
'last_used_at': TimestampField,
'editable': fields.Boolean,
'uninstallable': fields.Boolean,
}
installed_app_list_fields = {
'installed_apps': fields.List(fields.Nested(installed_app_fields))
}
class InstalledAppsListApi(Resource): class InstalledAppsListApi(Resource):
@login_required @login_required

View File

@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs.helper import uuid_value, TimestampField from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
@ -26,45 +27,6 @@ from services.message_service import MessageService
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
feedback_fields = {
'rating': fields.String
}
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):

View File

@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource from controllers.console.universal_chat.wraps import UniversalChatResource
from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
conversation_with_model_config_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField,
'model_config': fields.Raw,
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class UniversalChatConversationListApi(UniversalChatResource): class UniversalChatConversationListApi(UniversalChatResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
def get(self, universal_app): def get(self, universal_app):
app_model = universal_app app_model = universal_app
@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
class UniversalChatConversationRenameApi(UniversalChatResource): class UniversalChatConversationRenameApi(UniversalChatResource):
@marshal_with(conversation_fields) @marshal_with(conversation_with_model_config_fields)
def post(self, universal_app, c_id): def post(self, universal_app, c_id):
app_model = universal_app app_model = universal_app
conversation_id = str(c_id) conversation_id = str(c_id)

View File

@ -9,4 +9,4 @@ api = ExternalApi(bp)
from .app import completion, app, conversation, message, audio from .app import completion, app, conversation, message, audio
from .dataset import document from .dataset import document, segment, dataset

View File

@ -8,25 +8,11 @@ from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import AppApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
import services import services
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationApi(AppApiResource): class ConversationApi(AppApiResource):
@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource): class ConversationDetailApi(AppApiResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id): def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
class ConversationRenameApi(AppApiResource): class ConversationRenameApi(AppApiResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()

View File

@ -0,0 +1,84 @@
from flask import request
from flask_restful import reqparse, marshal
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from models.account import Account, TenantAccountJoin
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.provider_service import ProviderService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError('Name must be between 1 to 40 characters.')
return name
class DatasetApi(DatasetApiResource):
"""Resource for get datasets."""
def get(self, tenant_id):
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
provider = request.args.get('provider', default="vendor")
datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
else:
item['embedding_available'] = True
response = {
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
"""Resource for datasets."""
def post(self, tenant_id):
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('indexing_technique', type=str, location='json',
choices=('high_quality', 'economy'),
help='Invalid indexing technique.')
args = parser.parse_args()
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args['name'],
indexing_technique=args['indexing_technique'],
account=current_user
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 200
api.add_resource(DatasetApi, '/datasets')

View File

@ -1,114 +1,291 @@
import datetime import datetime
import json
import uuid import uuid
from flask import current_app from flask import current_app, request
from flask_restful import reqparse from flask_restful import reqparse, marshal
from sqlalchemy import desc
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services.dataset_service import services.dataset_service
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
DatasetNotInitedError NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from core.model_providers.error import ProviderTokenNotInitError from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile from models.model import UploadFile
from services.dataset_service import DocumentService from services.dataset_service import DocumentService
from services.file_service import FileService
class DocumentListApi(DatasetApiResource): class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents.""" """Resource for documents."""
def post(self, dataset): def post(self, tenant_id, dataset_id):
"""Create document.""" """Create document by text."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json') parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('text', type=str, required=True, nullable=False, location='json') parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('doc_type', type=str, location='json') parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('doc_metadata', type=dict, location='json') parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
location='json')
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset.indexing_technique: if not dataset:
raise DatasetNotInitedError("Dataset indexing technique must be set.") raise ValueError('Dataset is not exist.')
doc_type = args.get('doc_type') if not dataset.indexing_technique and not args['indexing_technique']:
doc_metadata = args.get('doc_metadata') raise ValueError('indexing_technique is required.')
if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: upload_file = FileService.upload_text(args.get('text'), args.get('name'))
raise ValueError('Invalid doc_type.') data_source = {
'type': 'upload_file',
# user uuid as file name 'info_list': {
file_uuid = str(uuid.uuid4()) 'data_source_type': 'upload_file',
file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt' 'file_info_list': {
'file_ids': [upload_file.id]
# save file to storage }
storage.save(file_key, args.get('text'))
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=dataset.tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=args.get('name') + '.txt',
size=len(args.get('text')),
extension='txt',
mime_type='text/plain',
created_by=dataset.created_by,
created_at=datetime.datetime.utcnow(),
used=True,
used_by=dataset.created_by,
used_at=datetime.datetime.utcnow()
)
db.session.add(upload_file)
db.session.commit()
document_data = {
'data_source': {
'type': 'upload_file',
'info': [
{
'upload_file_id': upload_file.id
}
]
} }
} }
args['data_source'] = data_source
# validate args
DocumentService.document_create_args_validate(args)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=document_data, document_data=args,
account=dataset.created_by_account, account=current_user,
dataset_process_rule=dataset.latest_process_rule, dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api' created_from='api'
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
document = documents[0] document = documents[0]
if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {} documents_and_batch_fields = {
'document': marshal(document, document_fields),
for key, value_type in metadata_schema.items(): 'batch': batch
value = doc_metadata.get(key) }
if value is not None and isinstance(value, value_type): return documents_and_batch_fields, 200
document.doc_metadata[key] = value
document.doc_type = doc_type
document.updated_at = datetime.datetime.utcnow()
db.session.commit()
return {'id': document.id}
class DocumentApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource):
def delete(self, dataset, document_id): """Resource for update documents."""
def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
if args['text']:
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
data_source = {
'type': 'upload_file',
'info_list': {
'data_source_type': 'upload_file',
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
if not dataset.indexing_technique and not args['indexing_technique']:
raise ValueError('indexing_technique is required.')
# save file info
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
if 'file' in request.files:
# save file info
file = request.files['file']
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document.""" """Delete document."""
document_id = str(document_id) document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError: except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Cannot delete document during indexing.') raise DocumentIndexingError('Cannot delete document during indexing.')
return {'result': 'success'}, 204 return {'result': 'success'}, 200
api.add_resource(DocumentListApi, '/documents') class DocumentListApi(DatasetApiResource):
api.add_resource(DocumentApi, '/documents/<uuid:document_id>') def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
search = request.args.get('keyword', default=None, type=str)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
query = Document.query.filter_by(
dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f'%{search}%'
query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at))
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
'data': marshal(documents, document_fields),
'has_more': len(documents) == limit,
'limit': limit,
'total': paginated_documents.total,
'page': page
}
return response
class DocumentIndexingStatusApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# get documents
documents = DocumentService.get_batch_documents(dataset_id, batch)
if not documents:
raise NotFound('Documents not found.')
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
documents_status.append(marshal(document, document_status_fields))
data = {
'data': documents_status
}
return data
api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')

View File

@ -1,20 +1,73 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = 'high_quality_dataset_only'
description = "Current operation only supports 'high-quality' datasets."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = 'dataset_not_initialized'
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException): class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = 'archived_document_immutable' error_code = 'archived_document_immutable'
description = "Cannot operate when document was archived." description = "The archived document is not editable."
code = 403 code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = 'dataset_name_duplicate'
description = "The dataset name already exists. Please modify your dataset name."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = 'invalid_action'
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = 'document_already_finished'
description = "The document has been processed. Please refresh the page or go to the document details."
code = 400
class DocumentIndexingError(BaseHTTPException): class DocumentIndexingError(BaseHTTPException):
error_code = 'document_indexing' error_code = 'document_indexing'
description = "Cannot operate document during indexing." description = "The document is being processed and cannot be edited."
code = 403 code = 400
class DatasetNotInitedError(BaseHTTPException): class InvalidMetadataError(BaseHTTPException):
error_code = 'dataset_not_inited' error_code = 'invalid_metadata'
description = "The dataset is still being initialized or indexing. Please wait a moment." description = "The metadata content is incorrect. Please check and verify."
code = 403 code = 400

View File

@ -0,0 +1,59 @@
from flask_login import current_user
from flask_restful import reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset
from services.dataset_service import DocumentService, SegmentService
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
def post(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')

View File

@ -2,11 +2,14 @@
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from flask import request from flask import request, current_app
from flask_login import user_logged_in
from flask_restful import Resource from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from core.login.login import _get_user
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
@ -43,12 +46,24 @@ def validate_dataset_token(view=None):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
api_token = validate_and_get_api_token('dataset') api_token = validate_and_get_api_token('dataset')
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first() .filter(Tenant.id == api_token.tenant_id) \
if not dataset: .filter(TenantAccountJoin.tenant_id == Tenant.id) \
raise NotFound() .filter(TenantAccountJoin.role == 'owner') \
.one_or_none()
return view(dataset, *args, **kwargs) if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user())
else:
raise Unauthorized("Tenant owner account is not exist.")
else:
raise Unauthorized("Tenant is not exist.")
return view(api_token.tenant_id, *args, **kwargs)
return decorated return decorated
if view: if view:

View File

@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
from controllers.web import api from controllers.web import api
from controllers.web.error import NotChatAppError from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationListApi(WebApiResource): class ConversationListApi(WebApiResource):
@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
class ConversationRenameApi(WebApiResource): class ConversationRenameApi(WebApiResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()

View File

@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search" SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']

View File

@ -246,11 +246,28 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table) self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
if pre_segment_data['keywords']:
segment.keywords = pre_segment_data['keywords']
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
pre_segment_data['keywords'])
else:
keywords = keyword_table_handler.extract_keywords(segment.content,
self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: List[str]): def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
keyword_table = self._get_dataset_keyword_table() keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table) self._save_dataset_keyword_table(keyword_table)
class KeywordTableRetriever(BaseRetriever, BaseModel): class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict) search_kwargs: dict = Field(default_factory=dict)

0
api/fields/__init__.py Normal file
View File

138
api/fields/app_fields.py Normal file
View File

@ -0,0 +1,138 @@
from flask_restful import fields
from libs.helper import TimestampField
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
app_detail_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'created_at': TimestampField
}
prompt_config_fields = {
'prompt_template': fields.String,
}
model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
'created_at': TimestampField
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
}
template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
}
template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
}
site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
}
app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField
}
app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}

View File

@ -0,0 +1,182 @@
from flask_restful import fields
from libs.helper import TimestampField
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]['text'] if value else ''
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
feedback_stat_fields = {
'like': fields.Integer,
'dislike': fields.Integer
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'model': fields.Raw,
'user_input_form': fields.Raw,
'pre_prompt': fields.String,
'agent_mode': fields.Raw,
}
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
simple_message_detail_fields = {
'inputs': fields.Raw,
'query': fields.String,
'message': MessageTextField,
'answer': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
conversation_message_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'model_config': fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_with_summary_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
conversation_with_summary_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
}
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
simple_conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(simple_conversation_fields))
}
conversation_with_model_config_fields = {
**simple_conversation_fields,
'model_config': fields.Raw,
}
conversation_with_model_config_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}

View File

@ -0,0 +1,65 @@
from flask_restful import fields
from libs.helper import TimestampField
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean,
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}

View File

@ -0,0 +1,43 @@
from flask_restful import fields
from libs.helper import TimestampField
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
}

View File

@ -0,0 +1,76 @@
from flask_restful import fields
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'completed_segments': fields.Integer,
'total_segments': fields.Integer
}
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)),
'batch': fields.String
}
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}

18
api/fields/file_fields.py Normal file
View File

@ -0,0 +1,18 @@
from flask_restful import fields
from libs.helper import TimestampField
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
}
file_fields = {
'id': fields.String,
'name': fields.String,
'size': fields.Integer,
'extension': fields.String,
'mime_type': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}

View File

@ -0,0 +1,41 @@
from flask_restful import fields
from libs.helper import TimestampField
document_fields = {
'id': fields.String,
'data_source_type': fields.String,
'name': fields.String,
'doc_type': fields.String,
}
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'document': fields.Nested(document_fields),
}
hit_testing_record_fields = {
'segment': fields.Nested(segment_fields),
'score': fields.Float,
'tsne_position': fields.Raw
}

View File

@ -0,0 +1,25 @@
from flask_restful import fields
from libs.helper import TimestampField
app_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
}
installed_app_fields = {
'id': fields.String,
'app': fields.Nested(app_fields),
'app_owner_tenant_id': fields.String,
'is_pinned': fields.Boolean,
'last_used_at': TimestampField,
'editable': fields.Boolean,
'uninstallable': fields.Boolean,
}
installed_app_list_fields = {
'installed_apps': fields.List(fields.Nested(installed_app_fields))
}

View File

@ -0,0 +1,43 @@
from flask_restful import fields
from libs.helper import TimestampField
feedback_fields = {
'rating': fields.String
}
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}

View File

@ -0,0 +1,32 @@
from flask_restful import fields
from libs.helper import TimestampField
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField
}
segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean,
'limit': fields.Integer
}

View File

@ -0,0 +1,36 @@
"""add_tenant_id_in_api_token
Revision ID: 2e9819ca5b28
Revises: 6e2cfb077b04
Create Date: 2023-09-22 15:41:01.243183
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
down_revision = 'ab23c11305d4'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
# ### end Alembic commands ###

View File

@ -629,12 +629,13 @@ class ApiToken(db.Model):
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='api_token_pkey'), db.PrimaryKeyConstraint('id', name='api_token_pkey'),
db.Index('api_token_app_id_type_idx', 'app_id', 'type'), db.Index('api_token_app_id_type_idx', 'app_id', 'type'),
db.Index('api_token_token_idx', 'token', 'type') db.Index('api_token_token_idx', 'token', 'type'),
db.Index('api_token_tenant_idx', 'tenant_id', 'type')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=True) app_id = db.Column(UUID, nullable=True)
dataset_id = db.Column(UUID, nullable=True) tenant_id = db.Column(UUID, nullable=True)
type = db.Column(db.String(16), nullable=False) type = db.Column(db.String(16), nullable=False)
token = db.Column(db.String(255), nullable=False) token = db.Column(db.String(255), nullable=False)
last_used_at = db.Column(db.DateTime, nullable=True) last_used_at = db.Column(db.DateTime, nullable=True)

View File

@ -96,7 +96,7 @@ class DatasetService:
embedding_model = None embedding_model = None
if indexing_technique == 'high_quality': if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id tenant_id=tenant_id
) )
dataset = Dataset(name=name, indexing_technique=indexing_technique) dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config) # dataset = Dataset(name=name, provider=provider, config=config)
@ -477,6 +477,7 @@ class DocumentService:
) )
dataset.collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding.id
documents = [] documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
if 'original_document_id' in document_data and document_data["original_document_id"]: if 'original_document_id' in document_data and document_data["original_document_id"]:
@ -626,6 +627,9 @@ class DocumentService:
document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available': if document.display_status != 'available':
raise ValueError("Document is not available") raise ValueError("Document is not available")
# update document name
if 'name' in document_data and document_data['name']:
document.name = document_data['name']
# save process rule # save process rule
if 'process_rule' in document_data and document_data['process_rule']: if 'process_rule' in document_data and document_data['process_rule']:
process_rule = document_data["process_rule"] process_rule = document_data["process_rule"]
@ -767,7 +771,7 @@ class DocumentService:
return dataset, documents, batch return dataset, documents, batch
@classmethod @classmethod
def document_create_args_validate(cls, args: dict): def document_create_args_validate(cls, args: dict):
if 'original_document_id' not in args or not args['original_document_id']: if 'original_document_id' not in args or not args['original_document_id']:
DocumentService.data_source_args_validate(args) DocumentService.data_source_args_validate(args)
DocumentService.process_rule_args_validate(args) DocumentService.process_rule_args_validate(args)
@ -1014,6 +1018,66 @@ class SegmentService:
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
return segment return segment
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
pre_segment_data_list = []
segment_data_list = []
for segment_item in segments:
content = segment_item['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality' and embedding_model:
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status='completed',
indexing_at=datetime.datetime.utcnow(),
completed_at=datetime.datetime.utcnow(),
created_by=current_user.id
)
if document.doc_form == 'qa_model':
segment_document.answer = segment_item['answer']
db.session.add(segment_document)
segment_data_list.append(segment_document)
pre_segment_data = {
'segment': segment_document,
'keywords': segment_item['keywords']
}
pre_segment_data_list.append(pre_segment_data)
try:
# save vector index
VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
except Exception as e:
logging.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = datetime.datetime.utcnow()
segment_document.status = 'error'
segment_document.error = str(e)
db.session.commit()
return segment_data_list
@classmethod @classmethod
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id) indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

View File

@ -1,7 +1,7 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
__all__ = [ __all__ = [
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
'app', 'completion', 'audio' 'app', 'completion', 'audio', 'file'
] ]
from . import * from . import *

View File

@ -3,3 +3,11 @@ from services.errors.base import BaseServiceError
class FileNotExistsError(BaseServiceError): class FileNotExistsError(BaseServiceError):
pass pass
class FileTooLargeError(BaseServiceError):
description = "{message}"
class UnsupportedFileTypeError(BaseServiceError):
pass

View File

@ -0,0 +1,123 @@
import datetime
import hashlib
import time
import uuid
from cachetools import TTLCache
from flask import request, current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_storage import storage
from extensions.ext_database import db
from models.model import UploadFile
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
PREVIEW_WORDS_LIMIT = 3000
cache = TTLCache(maxsize=None, ttl=30)
class FileService:
@staticmethod
def upload_file(file: FileStorage) -> UploadFile:
# read file content
file_content = file.read()
# get file size
file_size = len(file_content)
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if file_size > file_size_limit:
message = f'File size exceeded. {file_size} > {file_size_limit}'
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
# save file to storage
storage.save(file_key, file_content)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
)
db.session.add(upload_file)
db.session.commit()
return upload_file
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
# save file to storage
storage.save(file_key, text.encode('utf-8'))
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=text_name + '.txt',
size=len(text),
extension='txt',
mime_type='text/plain',
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=True,
used_by=current_user.id,
used_at=datetime.datetime.utcnow()
)
db.session.add(upload_file)
db.session.commit()
return upload_file
@staticmethod
def get_file_preview(file_id: str) -> str:
# get file storage key
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found")
# extract text from file
extension = upload_file.extension
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return text

View File

@ -35,6 +35,32 @@ class VectorService:
else: else:
index.add_texts([document]) index.add_texts([document])
@classmethod
def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
documents = []
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
# save keyword index
keyword_index = IndexBuilder.get_index(dataset, 'economy')
if keyword_index:
keyword_index.multi_create_segment_keywords(pre_segment_data_list)
@classmethod @classmethod
def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset): def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
# update segment index task # update segment index task