From 5e66a60f1c938f6db5b5d95fbef7ac64e6855b82 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Tue, 2 Apr 2024 20:46:24 +0800
Subject: [PATCH] add embedding cache and clean embedding cache job (#3087)

Co-authored-by: jyong
---
 api/core/embedding/cached_embedding.py        | 74 ++++++++++++-------
 api/extensions/ext_celery.py                  |  4 +-
 ...d7385a7b66_add_embeddings_provider_name.py | 34 +++++++++
 api/models/dataset.py                         |  9 ++-
 4 files changed, 91 insertions(+), 30 deletions(-)
 create mode 100644 api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py

diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py
index 7498a07559..11dfe8dc15 100644
--- a/api/core/embedding/cached_embedding.py
+++ b/api/core/embedding/cached_embedding.py
@@ -12,6 +12,7 @@ from core.rag.datasource.entity.embedding import Embeddings
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
+from models.dataset import Embedding
 
 logger = logging.getLogger(__name__)
 
@@ -23,32 +24,55 @@ class CacheEmbedding(Embeddings):
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs in batches of 10."""
-        text_embeddings = []
-        try:
-            model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-            model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
-            max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
-                if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
-            for i in range(0, len(texts), max_chunks):
-                batch_texts = texts[i:i + max_chunks]
+        # use doc embedding cache or store if not exists
+        text_embeddings = [None for _ in range(len(texts))]
+        embedding_queue_indices = []
+        for i, text in enumerate(texts):
+            hash = helper.generate_text_hash(text)
+            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
+                                                              hash=hash,
+                                                              provider_name=self._model_instance.provider).first()
+            if embedding:
+                text_embeddings[i] = embedding.get_embedding()
+            else:
+                embedding_queue_indices.append(i)
+        if embedding_queue_indices:
+            embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
+            embedding_queue_embeddings = []
+            try:
+                model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
+                model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
+                max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
+                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
+                for i in range(0, len(embedding_queue_texts), max_chunks):
+                    batch_texts = embedding_queue_texts[i:i + max_chunks]
 
-                embedding_result = self._model_instance.invoke_text_embedding(
-                    texts=batch_texts,
-                    user=self._user
-                )
+                    embedding_result = self._model_instance.invoke_text_embedding(
+                        texts=batch_texts,
+                        user=self._user
+                    )
 
-                for vector in embedding_result.embeddings:
-                    try:
-                        normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
-                        text_embeddings.append(normalized_embedding)
-                    except IntegrityError:
-                        db.session.rollback()
-                    except Exception as e:
-                        logging.exception('Failed to add embedding to redis')
-
-        except Exception as ex:
-            logger.error('Failed to embed documents: ', ex)
-            raise ex
+                    for vector in embedding_result.embeddings:
+                        try:
+                            normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                            embedding_queue_embeddings.append(normalized_embedding)
+                        except IntegrityError:
+                            db.session.rollback()
+                        except Exception as e:
+                            logging.exception('Failed transform embedding: ', e)
+                for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
+                    text_embeddings[i] = embedding
+                    hash = helper.generate_text_hash(texts[i])
+                    embedding_cache = Embedding(model_name=self._model_instance.model,
+                                                hash=hash,
+                                                provider_name=self._model_instance.provider)
+                    embedding_cache.set_embedding(embedding)
+                    db.session.add(embedding_cache)
+                db.session.commit()
+            except Exception as ex:
+                db.session.rollback()
+                logger.error('Failed to embed documents: ', ex)
+                raise ex
 
         return text_embeddings
 
@@ -61,8 +85,6 @@ class CacheEmbedding(Embeddings):
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
             return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
-
-
         try:
             embedding_result = self._model_instance.invoke_text_embedding(
                 texts=[text],
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 89a0924763..fcb99a9e83 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -46,11 +46,11 @@ def init_app(app: Flask) -> Celery:
     beat_schedule = {
         'clean_embedding_cache_task': {
             'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
-            'schedule': timedelta(days=7),
+            'schedule': timedelta(days=1),
         },
         'clean_unused_datasets_task': {
             'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
-            'schedule': timedelta(minutes=3),
+            'schedule': timedelta(days=1),
         }
     }
     celery_app.conf.update(
diff --git a/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py b/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py
new file mode 100644
index 0000000000..1ee01381d8
--- /dev/null
+++ b/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py
@@ -0,0 +1,34 @@
+"""add-embeddings-provider-name
+
+Revision ID: a8d7385a7b66
+Revises: 17b5ab037c40
+Create Date: 2024-04-02 12:17:22.641525
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'a8d7385a7b66'
+down_revision = '17b5ab037c40'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('embeddings', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('provider_name', sa.String(length=40), server_default=sa.text("''::character varying"), nullable=False))
+        batch_op.drop_constraint('embedding_hash_idx', type_='unique')
+        batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash', 'provider_name'])
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('embeddings', schema=None) as batch_op:
+        batch_op.drop_constraint('embedding_hash_idx', type_='unique')
+        batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash'])
+        batch_op.drop_column('provider_name')
+    # ### end Alembic commands ###
diff --git a/api/models/dataset.py b/api/models/dataset.py
index f90fc9abb7..5a893ad009 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -123,6 +123,7 @@ class Dataset(db.Model):
         normalized_dataset_id = dataset_id.replace("-", "_")
         return f'Vector_index_{normalized_dataset_id}_Node'
 
+
 class DatasetProcessRule(db.Model):
     __tablename__ = 'dataset_process_rules'
     __table_args__ = (
@@ -443,7 +444,8 @@ class DatasetKeywordTable(db.Model):
     id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
     dataset_id = db.Column(UUID, nullable=False, unique=True)
     keyword_table = db.Column(db.Text, nullable=False)
-    data_source_type = db.Column(db.String(255), nullable=False, server_default=db.text("'database'::character varying"))
+    data_source_type = db.Column(db.String(255), nullable=False,
+                                 server_default=db.text("'database'::character varying"))
 
     @property
     def keyword_table_dict(self):
@@ -457,6 +459,7 @@ class DatasetKeywordTable(db.Model):
                         if isinstance(node_idxs, list):
                             dct[keyword] = set(node_idxs)
                 return dct
+
         # get dataset
         dataset = Dataset.query.filter_by(
             id=self.dataset_id
@@ -481,7 +484,7 @@ class Embedding(db.Model):
    __tablename__ = 'embeddings'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='embedding_pkey'),
-        db.UniqueConstraint('model_name', 'hash', name='embedding_hash_idx')
+        db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
     )
 
     id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
@@ -490,6 +493,8 @@
     hash = db.Column(db.String(64), nullable=False)
     embedding = db.Column(db.LargeBinary, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    provider_name = db.Column(db.String(40), nullable=False,
+                              server_default=db.text("''::character varying"))
 
     def set_embedding(self, embedding_data: list[float]):
         self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
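
The cached_embedding.py change above turns embed_documents into a read-through cache: hash each input text, reuse any vector already stored under (model_name, hash, provider_name), embed only the misses in max_chunks-sized batches, L2-normalize the new vectors, and write them back in a single commit. The sketch below is a minimal, self-contained illustration of that flow, not code from the patch: fake_embed(), text_hash(), and the in-memory _cache dict are hypothetical stand-ins for the real model invocation, the text-hash helper, and the embeddings table.

import hashlib

import numpy as np


def fake_embed(texts: list[str], dim: int = 8) -> list[list[float]]:
    # Hypothetical stand-in for the real embedding model: deterministic per text.
    vectors = []
    for text in texts:
        seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:4], "big")
        vectors.append(np.random.default_rng(seed).random(dim).tolist())
    return vectors


def text_hash(text: str) -> str:
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


# Stand-in for the embeddings table, keyed by (model_name, hash, provider_name).
_cache: dict[tuple[str, str, str], list[float]] = {}


def embed_documents(texts: list[str], model_name: str = "demo-model",
                    provider_name: str = "demo-provider", max_chunks: int = 10) -> list[list[float]]:
    """Read-through cache: serve hits from storage, embed only the misses, then store them."""
    text_embeddings = [None] * len(texts)
    miss_indices = []

    # 1. Serve every text that already has a stored vector.
    for i, text in enumerate(texts):
        key = (model_name, text_hash(text), provider_name)
        if key in _cache:
            text_embeddings[i] = _cache[key]
        else:
            miss_indices.append(i)

    # 2. Embed the misses in batches and L2-normalize each vector.
    miss_texts = [texts[i] for i in miss_indices]
    miss_embeddings = []
    for start in range(0, len(miss_texts), max_chunks):
        for vector in fake_embed(miss_texts[start:start + max_chunks]):
            arr = np.asarray(vector)
            miss_embeddings.append((arr / np.linalg.norm(arr)).tolist())

    # 3. Fill the gaps and write the new vectors back to the cache.
    for i, embedding in zip(miss_indices, miss_embeddings):
        text_embeddings[i] = embedding
        _cache[(model_name, text_hash(texts[i]), provider_name)] = embedding

    return text_embeddings


if __name__ == "__main__":
    first = embed_documents(["hello", "world"])
    second = embed_documents(["hello", "again"])  # "hello" is now a cache hit
    assert first[0] == second[0]
    print(f"cache holds {len(_cache)} vectors")

Keying the cache on provider_name in addition to model_name and hash is what the new unique constraint in the a8d7385a7b66 migration enforces, so two providers exposing a model under the same name no longer share cached vectors. The ext_celery.py change complements this by running the scheduled clean_embedding_cache_task daily instead of weekly.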