import asyncio import concurrent import datetime import json import logging import re import threading import time import uuid from concurrent.futures import ThreadPoolExecutor from multiprocessing import Process from typing import Optional, List, cast import openai from billiard.pool import Pool from flask import current_app, Flask from flask_login import current_user from langchain.embeddings import OpenAIEmbeddings from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from core.data_loader.file_extractor import FileExtractor from core.data_loader.loader.notion import NotionLoader from core.docstore.dataset_docstore import DatesetDocumentStore from core.embedding.cached_embedding import CacheEmbedding from core.generator.llm_generator import LLMGenerator from core.index.index import IndexBuilder from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.vector_index.vector_index import VectorIndex from core.llm.error import ProviderTokenNotInitError from core.llm.llm_builder import LLMBuilder from core.llm.streamable_open_ai import StreamableOpenAI from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter from core.llm.token_calculator import TokenCalculator from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper from models.dataset import Document as DatasetDocument from models.dataset import Dataset, DocumentSegment, DatasetProcessRule from models.model import UploadFile from models.source import DataSourceBinding class IndexingRunner: def __init__(self, embedding_model_name: str = "text-embedding-ada-002"): self.storage = storage self.embedding_model_name = embedding_model_name def run(self, dataset_documents: List[DatasetDocument]): """Run the indexing process.""" for dataset_document in dataset_documents: try: # get dataset dataset = Dataset.query.filter_by( id=dataset_document.dataset_id ).first() if not dataset: raise ValueError("no dataset found") # load file text_docs = self._load_data(dataset_document) # get the process rule processing_rule = db.session.query(DatasetProcessRule). \ filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ first() # get splitter splitter = self._get_splitter(processing_rule) # split to documents documents = self._step_split( text_docs=text_docs, splitter=splitter, dataset=dataset, dataset_document=dataset_document, processing_rule=processing_rule ) # new_documents = [] # for document in documents: # response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content) # document_qa_list = self.format_split_text(response) # for result in document_qa_list: # document = Document(page_content=result['question'], metadata={'source': result['answer']}) # new_documents.append(document) # build index self._build_index( dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedException: raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = 'error' dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = 'error' dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() def format_split_text(self, text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" matches = re.findall(regex, text, re.MULTILINE) result = [] for match in matches: q = match[0] a = match[1] if q and a: result.append({ "question": q, "answer": re.sub(r"\n\s*", "\n", a.strip()) }) return result def run_in_splitting_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is splitting.""" try: # get dataset dataset = Dataset.query.filter_by( id=dataset_document.dataset_id ).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( dataset_id=dataset.id, document_id=dataset_document.id ).all() db.session.delete(document_segments) db.session.commit() # load file text_docs = self._load_data(dataset_document) # get the process rule processing_rule = db.session.query(DatasetProcessRule). \ filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ first() # get splitter splitter = self._get_splitter(processing_rule) # split to documents documents = self._step_split( text_docs=text_docs, splitter=splitter, dataset=dataset, dataset_document=dataset_document, processing_rule=processing_rule ) # build index self._build_index( dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedException: raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = 'error' dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = 'error' dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() def run_in_indexing_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is indexing.""" try: # get dataset dataset = Dataset.query.filter_by( id=dataset_document.dataset_id ).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( dataset_id=dataset.id, document_id=dataset_document.id ).all() documents = [] if document_segments: for document_segment in document_segments: # transform segment to node if document_segment.status != "completed": document = Document( page_content=document_segment.content, metadata={ "doc_id": document_segment.index_node_id, "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, } ) documents.append(document) # build index self._build_index( dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedException: raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = 'error' dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = 'error' dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict, doc_form: str = None) -> dict: """ Estimate the indexing for the document. """ tokens = 0 preview_texts = [] total_segments = 0 for file_detail in file_details: # load data from file text_docs = FileExtractor.load(file_detail) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) # get splitter splitter = self._get_splitter(processing_rule) # split to documents documents = self._split_to_documents_for_estimate( text_docs=text_docs, splitter=splitter, processing_rule=processing_rule ) total_segments += len(documents) for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, self.filter_string(document.page_content)) if doc_form and doc_form == 'qa_model': if len(preview_texts) > 0: # qa model document llm: StreamableOpenAI = LLMBuilder.to_llm( tenant_id=current_user.current_tenant_id, model_name='gpt-3.5-turbo', max_tokens=2000 ) response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) document_qa_list = self.format_split_text(response) return { "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), "currency": TokenCalculator.get_currency(self.embedding_model_name), "qa_preview": document_qa_list, "preview": preview_texts } return { "total_segments": total_segments, "tokens": tokens, "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), "currency": TokenCalculator.get_currency(self.embedding_model_name), "preview": preview_texts } def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: """ Estimate the indexing for the document. """ # load data from notion tokens = 0 preview_texts = [] total_segments = 0 for notion_info in notion_info_list: workspace_id = notion_info['workspace_id'] data_source_binding = DataSourceBinding.query.filter( db.and_( DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.provider == 'notion', DataSourceBinding.disabled == False, DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: raise ValueError('Data source binding not found.') for page in notion_info['pages']: loader = NotionLoader( notion_access_token=data_source_binding.access_token, notion_workspace_id=workspace_id, notion_obj_id=page['page_id'], notion_page_type=page['type'] ) documents = loader.load() processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) # get splitter splitter = self._get_splitter(processing_rule) # split to documents documents = self._split_to_documents_for_estimate( text_docs=documents, splitter=splitter, processing_rule=processing_rule ) total_segments += len(documents) for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) if doc_form and doc_form == 'qa_model': if len(preview_texts) > 0: # qa model document llm: StreamableOpenAI = LLMBuilder.to_llm( tenant_id=current_user.current_tenant_id, model_name='gpt-3.5-turbo', max_tokens=2000 ) response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) document_qa_list = self.format_split_text(response) return { "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), "currency": TokenCalculator.get_currency(self.embedding_model_name), "qa_preview": document_qa_list, "preview": preview_texts } return { "total_segments": total_segments, "tokens": tokens, "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), "currency": TokenCalculator.get_currency(self.embedding_model_name), "preview": preview_texts } def _load_data(self, dataset_document: DatasetDocument) -> List[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import"]: return [] data_source_info = dataset_document.data_source_info_dict text_docs = [] if dataset_document.data_source_type == 'upload_file': if not data_source_info or 'upload_file_id' not in data_source_info: raise ValueError("no upload file found") file_detail = db.session.query(UploadFile). \ filter(UploadFile.id == data_source_info['upload_file_id']). \ one_or_none() text_docs = FileExtractor.load(file_detail) elif dataset_document.data_source_type == 'notion_import': loader = NotionLoader.from_document(dataset_document) text_docs = loader.load() # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), DatasetDocument.parsing_completed_at: datetime.datetime.utcnow() } ) # replace doc id to document model id text_docs = cast(List[Document], text_docs) for text_doc in text_docs: # remove invalid symbol text_doc.page_content = self.filter_string(text_doc.page_content) text_doc.metadata['document_id'] = dataset_document.id text_doc.metadata['dataset_id'] = dataset_document.dataset_id return text_docs def filter_string(self, text): text = re.sub(r'<\|', '<', text) text = re.sub(r'\|>', '>', text) text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text) return text def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ if processing_rule.mode == "custom": # The user-defined segmentation rule rules = json.loads(processing_rule.rules) segmentation = rules["segmentation"] if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: raise ValueError("Custom segment length should be between 50 and 1000.") separator = segmentation["separator"] if separator: separator = separator.replace('\\n', '\n') character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder( chunk_size=segmentation["max_tokens"], chunk_overlap=0, fixed_separator=separator, separators=["\n\n", "。", ".", " ", ""] ) else: # Automatic segmentation character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], chunk_overlap=0, separators=["\n\n", "。", ".", " ", ""] ) return character_splitter def _step_split(self, text_docs: List[Document], splitter: TextSplitter, dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ -> List[Document]: """ Split the text documents into documents and save them to the document segment. """ documents = self._split_to_documents( text_docs=text_docs, splitter=splitter, processing_rule=processing_rule, tenant_id=dataset.tenant_id, document_form=dataset_document.doc_form ) # save node to document segment doc_store = DatesetDocumentStore( dataset=dataset, user_id=dataset_document.created_by, embedding_model_name=self.embedding_model_name, document_id=dataset_document.id ) # add document segments doc_store.add_documents(documents) # update document status to indexing cur_time = datetime.datetime.utcnow() self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, } ) # update segment status to indexing self._update_segments_by_document( dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", DocumentSegment.indexing_at: datetime.datetime.utcnow() } ) return documents def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]: """ Split the text documents into nodes. """ all_documents = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) text_doc.page_content = document_text # parse document to nodes documents = splitter.split_documents([text_doc]) split_documents = [] llm: StreamableOpenAI = LLMBuilder.to_llm( tenant_id=tenant_id, model_name='gpt-3.5-turbo', max_tokens=2000 ) threads = [] for doc in documents: document_format_thread = threading.Thread(target=self.format_document, kwargs={ 'llm': llm, 'document_node': doc, 'split_documents': split_documents, 'document_form': document_form}) threads.append(document_format_thread) document_format_thread.start() for thread in threads: thread.join() all_documents.extend(split_documents) return all_documents def format_document(self, llm: StreamableOpenAI, document_node, split_documents: List, document_form: str): print(document_node.page_content) format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): return format_documents if document_form == 'text_model': # text model document doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) document_node.metadata['doc_id'] = doc_id document_node.metadata['doc_hash'] = hash format_documents.append(document_node) elif document_form == 'qa_model': try: # qa model document response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(result['question']) qa_document.metadata['answer'] = result['answer'] qa_document.metadata['doc_id'] = doc_id qa_document.metadata['doc_hash'] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: logging.error(str(e)) split_documents.extend(format_documents) def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule) -> List[Document]: """ Split the text documents into nodes. """ all_documents = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) text_doc.page_content = document_text # parse document to nodes documents = splitter.split_documents([text_doc]) split_documents = [] for document in documents: if document.page_content is None or not document.page_content.strip(): continue doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document.page_content) document.metadata['doc_id'] = doc_id document.metadata['doc_hash'] = hash split_documents.append(document) all_documents.extend(split_documents) return all_documents def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: """ Clean the document text according to the processing rules. """ if processing_rule.mode == "automatic": rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} if 'pre_processing_rules' in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces pattern = r'\n{3,}' text = re.sub(pattern, '\n\n', text) pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' text = re.sub(pattern, ' ', text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' text = re.sub(pattern, '', text) # Remove URL pattern = r'https?://[^\s]+' text = re.sub(pattern, '', text) return text def format_split_text(self, text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # 匹配Q和A的正则表达式 matches = re.findall(regex, text, re.MULTILINE) # 获取所有匹配到的结果 result = [] # 存储最终的结果 for match in matches: q = match[0] a = match[1] if q and a: # 如果Q和A都存在,就将其添加到结果中 result.append({ "question": q, "answer": re.sub(r"\n\s*", "\n", a.strip()) }) return result def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: """ Build the index for the document. """ vector_index = IndexBuilder.get_index(dataset, 'high_quality') keyword_table_index = IndexBuilder.get_index(dataset, 'economy') # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 chunk_size = 100 for i in range(0, len(documents), chunk_size): # check document is paused self._check_document_paused_status(dataset_document.id) chunk_documents = documents[i:i + chunk_size] tokens += sum( TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) for document in chunk_documents ) # save vector index if vector_index: vector_index.add_texts(chunk_documents) # save keyword index keyword_table_index.add_texts(chunk_documents) document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.status == "indexing" ).update({ DocumentSegment.status: "completed", DocumentSegment.completed_at: datetime.datetime.utcnow() }) db.session.commit() indexing_end_at = time.perf_counter() # update document status to completed self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="completed", extra_update_params={ DatasetDocument.tokens: tokens, DatasetDocument.completed_at: datetime.datetime.utcnow(), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, } ) def _check_document_paused_status(self, document_id: str): indexing_cache_key = 'document_{}_is_paused'.format(document_id) result = redis_client.get(indexing_cache_key) if result: raise DocumentIsPausedException() def _update_document_index_status(self, document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None) -> None: """ Update the document indexing status. """ count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() if count > 0: raise DocumentIsPausedException() update_params = { DatasetDocument.indexing_status: after_indexing_status } if extra_update_params: update_params.update(extra_update_params) DatasetDocument.query.filter_by(id=document_id).update(update_params) db.session.commit() def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: """ Update the document segment by document id. """ DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() class DocumentIsPausedException(Exception): pass