fix: score_threshold handling in vector search methods (#8356)

This commit is contained in:
-LAN- 2024-09-13 14:24:35 +08:00 committed by GitHub
parent a45ac6ab98
commit 08c486452f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 17 additions and 17 deletions

View File

@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest( request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id, dbinstance_id=self.config.instance_id,
region_id=self.config.region_id, region_id=self.config.region_id,
@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest( request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id, dbinstance_id=self.config.instance_id,
region_id=self.config.region_id, region_id=self.config.region_id,

View File

@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name) collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
ids: list[str] = results["ids"][0] ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0] documents: list[str] = results["documents"][0]

View File

@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold: if score > score_threshold:
doc.metadata["score"] = score doc.metadata["score"] = score
docs.append(doc) docs.append(doc)

View File

@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
for result in results[0]: for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value) metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"] metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold: if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc) docs.append(doc)

View File

@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = ( where_str = (
f"WHERE dist < {1 - score_threshold}" f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

View File

@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
metadata = {} metadata = {}
metadata["score"] = hit["_score"] metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold: if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc) docs.append(doc)

View File

@ -200,7 +200,7 @@ class OracleVector(BaseVector):
[numpy.array(query_vector)], [numpy.array(query_vector)],
) )
docs = [] docs = []
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur: for record in cur:
metadata, text, distance = record metadata, text, distance = record
score = 1 - distance score = 1 - distance
@ -212,7 +212,7 @@ class OracleVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later # just not implement fetch by score_threshold now, may be later
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0: if len(query) > 0:
# Check which language the query is in # Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+") zh_pattern = re.compile("[\u4e00-\u9fa5]+")

View File

@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
metadata = record.meta metadata = record.meta
score = 1 - dis score = 1 - dis
metadata["score"] = score metadata["score"] = score
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold: if score > score_threshold:
doc = Document(page_content=record.text, metadata=metadata) doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc) docs.append(doc)

View File

@ -144,7 +144,7 @@ class PGVector(BaseVector):
(json.dumps(query_vector),), (json.dumps(query_vector),),
) )
docs = [] docs = []
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur: for record in cur:
metadata, text, distance = record metadata, text, distance = record
score = 1 - distance score = 1 - distance

View File

@ -333,13 +333,13 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4), limit=kwargs.get("top_k", 4),
with_payload=True, with_payload=True,
with_vectors=True, with_vectors=True,
score_threshold=kwargs.get("score_threshold", 0.0), score_threshold=float(kwargs.get("score_threshold") or 0.0),
) )
docs = [] docs = []
for result in results: for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {} metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold # duplicate check score threshold
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold: if result.score > score_threshold:
metadata["score"] = result.score metadata["score"] = result.score
doc = Document( doc = Document(

View File

@ -230,7 +230,7 @@ class RelytVector(BaseVector):
# Organize results. # Organize results.
docs = [] docs = []
for document, score in results: for document, score in results:
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold: if 1 - score > score_threshold:
docs.append(document) docs.append(document)
return docs return docs

View File

@ -153,7 +153,7 @@ class TencentVector(BaseVector):
limit=kwargs.get("top_k", 4), limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout, timeout=self._client_config.timeout,
) )
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold) return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

View File

@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
filter = kwargs.get("filter") filter = kwargs.get("filter")
distance = 1 - score_threshold distance = 1 - score_threshold

View File

@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold # check score threshold
if score > score_threshold: if score > score_threshold:
doc.metadata["score"] = score doc.metadata["score"] = score