From 724e053732e8d3bc85db4308d20d5186619de35c Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:21:26 +0800 Subject: [PATCH] Fix/qdrant data issue (#1203) Co-authored-by: jyong --- api/commands.py | 151 +++++++++--------- api/core/index/vector_index/base.py | 8 +- api/core/index/vector_index/qdrant.py | 131 +++++++-------- .../index/vector_index/qdrant_vector_index.py | 13 ++ .../clean_when_dataset_deleted.py | 3 +- api/tasks/clean_dataset_task.py | 10 +- api/tasks/deal_dataset_vector_index_task.py | 4 +- 7 files changed, 171 insertions(+), 149 deletions(-) diff --git a/api/commands.py b/api/commands.py index 8b6f5455f1..d5a0e47858 100644 --- a/api/commands.py +++ b/api/commands.py @@ -3,12 +3,13 @@ import json import math import random import string +import threading import time import uuid import click from tqdm import tqdm -from flask import current_app +from flask import current_app, Flask from langchain.embeddings import OpenAIEmbeddings from werkzeug.exceptions import NotFound @@ -456,92 +457,92 @@ def update_qdrant_indexes(): @click.command('normalization-collections', help='restore all collections in one') def normalization_collections(): click.echo(click.style('Start normalization collections.', fg='green')) - normalization_count = 0 - + normalization_count = [] page = 1 while True: try: datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ - .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100) except NotFound: break - + datasets_result = datasets.items page += 1 - for dataset in datasets: - if not dataset.collection_binding_id: - try: - click.echo('restore dataset index: {}'.format(dataset.id)) - try: - embedding_model = ModelFactory.get_embedding_model( - tenant_id=dataset.tenant_id, - model_provider_name=dataset.embedding_model_provider, - model_name=dataset.embedding_model - ) - except Exception: - provider = Provider( - id='provider_id', - tenant_id=dataset.tenant_id, - provider_name='openai', - provider_type=ProviderType.CUSTOM.value, - encrypted_config=json.dumps({'openai_api_key': 'TEST'}), - is_valid=True, - ) - model_provider = OpenAIProvider(provider=provider) - embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", - model_provider=model_provider) - embeddings = CacheEmbedding(embedding_model) - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name, - DatasetCollectionBinding.model_name == embedding_model.name). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + for i in range(0, len(datasets_result), 5): + threads = [] + sub_datasets = datasets_result[i:i + 5] + for dataset in sub_datasets: + document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset': dataset, + 'normalization_count': normalization_count + }) + threads.append(document_format_thread) + document_format_thread.start() + for thread in threads: + thread.join() - if not dataset_collection_binding: - dataset_collection_binding = DatasetCollectionBinding( - provider_name=embedding_model.model_provider.provider_name, - model_name=embedding_model.name, - collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node' - ) - db.session.add(dataset_collection_binding) - db.session.commit() + click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green')) - from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig - index = QdrantVectorIndex( - dataset=dataset, - config=QdrantConfig( - endpoint=current_app.config.get('QDRANT_URL'), - api_key=current_app.config.get('QDRANT_API_KEY'), - root_path=current_app.root_path - ), - embeddings=embeddings - ) - if index: - index.restore_dataset_in_one(dataset, dataset_collection_binding) - else: - click.echo('passed.') +def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list): + with flask_app.app_context(): + try: + click.echo('restore dataset index: {}'.format(dataset.id)) + try: + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except Exception: + provider = Provider( + id='provider_id', + tenant_id=dataset.tenant_id, + provider_name='openai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'openai_api_key': 'TEST'}), + is_valid=True, + ) + model_provider = OpenAIProvider(provider=provider) + embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", + model_provider=model_provider) + embeddings = CacheEmbedding(embedding_model) + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name, + DatasetCollectionBinding.model_name == embedding_model.name). \ + order_by(DatasetCollectionBinding.created_at). \ + first() - original_index = QdrantVectorIndex( - dataset=dataset, - config=QdrantConfig( - endpoint=current_app.config.get('QDRANT_URL'), - api_key=current_app.config.get('QDRANT_API_KEY'), - root_path=current_app.root_path - ), - embeddings=embeddings - ) - if original_index: - original_index.delete_original_collection(dataset, dataset_collection_binding) - normalization_count += 1 - else: - click.echo('passed.') - except Exception as e: - click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) - continue + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=embedding_model.model_provider.provider_name, + model_name=embedding_model.name, + collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node' + ) + db.session.add(dataset_collection_binding) + db.session.commit() - click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green')) + from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig + + index = QdrantVectorIndex( + dataset=dataset, + config=QdrantConfig( + endpoint=current_app.config.get('QDRANT_URL'), + api_key=current_app.config.get('QDRANT_API_KEY'), + root_path=current_app.root_path + ), + embeddings=embeddings + ) + if index: + # index.delete_by_group_id(dataset.id) + index.restore_dataset_in_one(dataset, dataset_collection_binding) + else: + click.echo('passed.') + normalization_count.append(1) + except Exception as e: + click.echo( + click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), + fg='red')) @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.') diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py index 1e59135f37..60f092d409 100644 --- a/api/core/index/vector_index/base.py +++ b/api/core/index/vector_index/base.py @@ -113,8 +113,10 @@ class BaseVectorIndex(BaseIndex): def delete_by_group_id(self, group_id: str) -> None: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.delete() + if self.dataset.collection_binding_id: + vector_store.delete_by_group_id(group_id) + else: + vector_store.delete() def delete(self) -> None: vector_store = self._get_vector_store() @@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex): if documents: try: - self.create_with_collection_name(documents, dataset_collection_binding.collection_name) + self.add_texts(documents) except Exception as e: raise e diff --git a/api/core/index/vector_index/qdrant.py b/api/core/index/vector_index/qdrant.py index 56f7a2ce34..5b9736a0b5 100644 --- a/api/core/index/vector_index/qdrant.py +++ b/api/core/index/vector_index/qdrant.py @@ -1390,70 +1390,12 @@ class Qdrant(VectorStore): path=path, **kwargs, ) - try: - # Skip any validation in case of forced collection recreate. - if force_recreate: - raise ValueError - - # Get the vector configuration of the existing collection and vector, if it - # was specified. If the old configuration does not match the current one, - # an exception is being thrown. - collection_info = client.get_collection(collection_name=collection_name) - current_vector_config = collection_info.config.params.vectors - if isinstance(current_vector_config, dict) and vector_name is not None: - if vector_name not in current_vector_config: - raise QdrantException( - f"Existing Qdrant collection {collection_name} does not " - f"contain vector named {vector_name}. Did you mean one of the " - f"existing vectors: {', '.join(current_vector_config.keys())}? " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - current_vector_config = current_vector_config.get( - vector_name - ) # type: ignore[assignment] - elif isinstance(current_vector_config, dict) and vector_name is None: - raise QdrantException( - f"Existing Qdrant collection {collection_name} uses named vectors. " - f"If you want to reuse it, please set `vector_name` to any of the " - f"existing named vectors: " - f"{', '.join(current_vector_config.keys())}." # noqa - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - elif ( - not isinstance(current_vector_config, dict) and vector_name is not None - ): - raise QdrantException( - f"Existing Qdrant collection {collection_name} doesn't use named " - f"vectors. If you want to reuse it, please set `vector_name` to " - f"`None`. If you want to recreate the collection, set " - f"`force_recreate` parameter to `True`." - ) - - # Check if the vector configuration has the same dimensionality. - if current_vector_config.size != vector_size: # type: ignore[union-attr] - raise QdrantException( - f"Existing Qdrant collection is configured for vectors with " - f"{current_vector_config.size} " # type: ignore[union-attr] - f"dimensions. Selected embeddings are {vector_size}-dimensional. " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - - current_distance_func = ( - current_vector_config.distance.name.upper() # type: ignore[union-attr] - ) - if current_distance_func != distance_func: - raise QdrantException( - f"Existing Qdrant collection is configured for " - f"{current_vector_config.distance} " # type: ignore[union-attr] - f"similarity. Please set `distance_func` parameter to " - f"`{distance_func}` if you want to reuse it. If you want to " - f"recreate the collection, set `force_recreate` parameter to " - f"`True`." - ) - except (UnexpectedResponse, RpcError, ValueError): + all_collection_name = [] + collections_response = client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if collection_name not in all_collection_name: vectors_config = rest.VectorParams( size=vector_size, distance=rest.Distance[distance_func], @@ -1481,6 +1423,67 @@ class Qdrant(VectorStore): timeout=timeout, # type: ignore[arg-type] ) is_new_collection = True + if force_recreate: + raise ValueError + + # Get the vector configuration of the existing collection and vector, if it + # was specified. If the old configuration does not match the current one, + # an exception is being thrown. + collection_info = client.get_collection(collection_name=collection_name) + current_vector_config = collection_info.config.params.vectors + if isinstance(current_vector_config, dict) and vector_name is not None: + if vector_name not in current_vector_config: + raise QdrantException( + f"Existing Qdrant collection {collection_name} does not " + f"contain vector named {vector_name}. Did you mean one of the " + f"existing vectors: {', '.join(current_vector_config.keys())}? " + f"If you want to recreate the collection, set `force_recreate` " + f"parameter to `True`." + ) + current_vector_config = current_vector_config.get( + vector_name + ) # type: ignore[assignment] + elif isinstance(current_vector_config, dict) and vector_name is None: + raise QdrantException( + f"Existing Qdrant collection {collection_name} uses named vectors. " + f"If you want to reuse it, please set `vector_name` to any of the " + f"existing named vectors: " + f"{', '.join(current_vector_config.keys())}." # noqa + f"If you want to recreate the collection, set `force_recreate` " + f"parameter to `True`." + ) + elif ( + not isinstance(current_vector_config, dict) and vector_name is not None + ): + raise QdrantException( + f"Existing Qdrant collection {collection_name} doesn't use named " + f"vectors. If you want to reuse it, please set `vector_name` to " + f"`None`. If you want to recreate the collection, set " + f"`force_recreate` parameter to `True`." + ) + + # Check if the vector configuration has the same dimensionality. + if current_vector_config.size != vector_size: # type: ignore[union-attr] + raise QdrantException( + f"Existing Qdrant collection is configured for vectors with " + f"{current_vector_config.size} " # type: ignore[union-attr] + f"dimensions. Selected embeddings are {vector_size}-dimensional. " + f"If you want to recreate the collection, set `force_recreate` " + f"parameter to `True`." + ) + + current_distance_func = ( + current_vector_config.distance.name.upper() # type: ignore[union-attr] + ) + if current_distance_func != distance_func: + raise QdrantException( + f"Existing Qdrant collection is configured for " + f"{current_vector_config.distance} " # type: ignore[union-attr] + f"similarity. Please set `distance_func` parameter to " + f"`{distance_func}` if you want to reuse it. If you want to " + f"recreate the collection, set `force_recreate` parameter to " + f"`True`." + ) qdrant = cls( client=client, collection_name=collection_name, diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py index 2be77609a6..732a10b0ae 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex): ], )) + def delete(self) -> None: + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + from qdrant_client.http import models + vector_store.del_texts(models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self.dataset.id), + ), + ], + )) def _is_origin(self): if self.dataset.index_struct_dict: diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index e9975c92bc..93181ea161 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task @dataset_was_deleted.connect def handle(sender, **kwargs): dataset = sender - clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct) + clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, + dataset.index_struct, dataset.collection_binding_id) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index dea9059b00..8f5e37f49b 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -13,13 +13,15 @@ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, Datase @shared_task(queue='dataset') -def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str): +def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, + index_struct: str, collection_binding_id: str): """ Clean dataset when dataset deleted. :param dataset_id: dataset id :param tenant_id: tenant id :param indexing_technique: indexing technique :param index_struct: index struct dict + :param collection_binding_id: collection binding id Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ @@ -31,9 +33,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, id=dataset_id, tenant_id=tenant_id, indexing_technique=indexing_technique, - index_struct=index_struct + index_struct=index_struct, + collection_binding_id=collection_binding_id ) - documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() @@ -43,7 +45,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, if dataset.indexing_technique == 'high_quality': vector_index = IndexBuilder.get_default_high_quality_index(dataset) try: - vector_index.delete() + vector_index.delete_by_group_id(dataset.id) except Exception: logging.exception("Delete doc index failed when dataset deleted.") diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 96d1dc9096..6a3b52a40b 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -31,8 +31,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): raise Exception('Dataset not found') if action == "remove": - index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False) - index.delete() + index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) + index.delete_by_group_id(dataset.id) elif action == "add": dataset_documents = db.session.query(DatasetDocument).filter( DatasetDocument.dataset_id == dataset_id,