fix: score_threshold handling in vector search methods (#8356)

This commit is contained in:
-LAN- 2024-09-13 14:24:35 +08:00 committed by GitHub
parent a45ac6ab98
commit 08c486452f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 17 additions and 17 deletions

View File

@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,

View File

@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]

View File

@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

View File

@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
metadata = {}
metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -200,7 +200,7 @@ class OracleVector(BaseVector):
[numpy.array(query_vector)],
)
docs = []
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
@ -212,7 +212,7 @@ class OracleVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")

View File

@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
metadata = record.meta
score = 1 - dis
metadata["score"] = score
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)

View File

@ -144,7 +144,7 @@ class PGVector(BaseVector):
(json.dumps(query_vector),),
)
docs = []
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance

View File

@ -333,13 +333,13 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", 0.0),
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(

View File

@ -230,7 +230,7 @@ class RelytVector(BaseVector):
# Organize results.
docs = []
for document, score in results:
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold:
docs.append(document)
return docs

View File

@ -153,7 +153,7 @@ class TencentVector(BaseVector):
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

View File

@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
filter = kwargs.get("filter")
distance = 1 - score_threshold

View File

@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score