Tolerate unique index conflicts when saving embedding cache entries (#3183)

This commit is contained in:
Jyong 2024-04-09 02:16:19 +08:00 committed by GitHub
parent ca3e2e6cc0
commit 2e4dec365d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -41,7 +41,8 @@ class CacheEmbedding(Embeddings):
embedding_queue_embeddings = [] embedding_queue_embeddings = []
try: try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials) model_schema = model_type_instance.get_model_schema(self._model_instance.model,
self._model_instance.credentials)
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
for i in range(0, len(embedding_queue_texts), max_chunks): for i in range(0, len(embedding_queue_texts), max_chunks):
@ -61,17 +62,20 @@ class CacheEmbedding(Embeddings):
except Exception as e: except Exception as e:
logging.exception('Failed transform embedding: ', e) logging.exception('Failed transform embedding: ', e)
cache_embeddings = [] cache_embeddings = []
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): try:
text_embeddings[i] = embedding for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
hash = helper.generate_text_hash(texts[i]) text_embeddings[i] = embedding
if hash not in cache_embeddings: hash = helper.generate_text_hash(texts[i])
embedding_cache = Embedding(model_name=self._model_instance.model, if hash not in cache_embeddings:
hash=hash, embedding_cache = Embedding(model_name=self._model_instance.model,
provider_name=self._model_instance.provider) hash=hash,
embedding_cache.set_embedding(embedding) provider_name=self._model_instance.provider)
db.session.add(embedding_cache) embedding_cache.set_embedding(embedding)
cache_embeddings.append(hash) db.session.add(embedding_cache)
db.session.commit() cache_embeddings.append(hash)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex: except Exception as ex:
db.session.rollback() db.session.rollback()
logger.error('Failed to embed documents: ', ex) logger.error('Failed to embed documents: ', ex)