import logging
import time

import click
from celery import shared_task

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str):
    """
    Asynchronously add, update, or remove a dataset's vector index.

    :param dataset_id: dataset_id
    :param action: one of "add", "update", or "remove"

    Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
    """
    logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = Dataset.query.filter_by(id=dataset_id).first()

        if not dataset:
            raise Exception("Dataset not found")

        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()

        if action == "remove":
            # drop the dataset's vectors; the keyword index is left untouched
            index_processor.clean(dataset, None, with_keywords=False)
        elif action == "add":
            dataset_documents = (
                db.session.query(DatasetDocument)
                .filter(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )

            if dataset_documents:
                # mark all affected documents as indexing before rebuilding
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                for dataset_document in dataset_documents:
                    try:
                        # add from vector index
                        segments = (
                            db.session.query(DocumentSegment)
                            .filter(
                                DocumentSegment.document_id == dataset_document.id,
                                DocumentSegment.enabled == True,
                            )
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )
                                documents.append(document)

                            # save vector index
                            index_processor.load(dataset, documents, with_keywords=False)

                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "completed"}, synchronize_session=False
                        )
                        db.session.commit()
                    except Exception as e:
                        # record the failure on the document instead of aborting the batch
                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()
        elif action == "update":
            dataset_documents = (
                db.session.query(DatasetDocument)
                .filter(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )

            # add new index
            if dataset_documents:
                # update document status
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                # clean index
                index_processor.clean(dataset, None, with_keywords=False)

                for dataset_document in dataset_documents:
                    # update from vector index
                    try:
                        segments = (
                            db.session.query(DocumentSegment)
                            .filter(
                                DocumentSegment.document_id == dataset_document.id,
                                DocumentSegment.enabled == True,
                            )
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )
                                documents.append(document)

                            # save vector index
                            index_processor.load(dataset, documents, with_keywords=False)

                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "completed"}, synchronize_session=False
                        )
                        db.session.commit()
                    except Exception as e:
                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()

        end_at = time.perf_counter()
        logging.info(
            click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
        )
    except Exception:
        logging.exception("Deal dataset vector index failed")
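

# Example dispatch (a minimal sketch, not part of this module): how a caller
# might enqueue this task. The action strings mirror the branches handled
# above; `dataset` is assumed to be a Dataset row already loaded in the
# caller's own session, e.g. after its embedding model or index settings
# change:
#
#     deal_dataset_vector_index_task.delay(dataset.id, "update")
#
# Because the task runs on the "dataset" queue, a worker must be started with
# that queue enabled for the job to be picked up.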