mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
fix: score_threshold handling in vector search methods (#8356)
This commit is contained in:
parent
a45ac6ab98
commit
08c486452f
|
@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
|
|||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
|
@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
|
|
|
@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
|
|||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
ids: list[str] = results["ids"][0]
|
||||
documents: list[str] = results["documents"][0]
|
||||
|
|
|
@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
|
|||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score > score_threshold:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
|
|
@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
|
|||
for result in results[0]:
|
||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||
metadata["score"] = result["distance"]
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
|
|
@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
|
|||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
|
|
|
@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
|
|||
metadata = {}
|
||||
|
||||
metadata["score"] = hit["_score"]
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if hit["_score"] > score_threshold:
|
||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
|
|
@ -200,7 +200,7 @@ class OracleVector(BaseVector):
|
|||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
|
@ -212,7 +212,7 @@ class OracleVector(BaseVector):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
# just not implement fetch by score_threshold now, may be later
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if len(query) > 0:
|
||||
# Check which language the query is in
|
||||
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
||||
|
|
|
@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
|
|||
metadata = record.meta
|
||||
score = 1 - dis
|
||||
metadata["score"] = score
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score > score_threshold:
|
||||
doc = Document(page_content=record.text, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
|
|
@ -144,7 +144,7 @@ class PGVector(BaseVector):
|
|||
(json.dumps(query_vector),),
|
||||
)
|
||||
docs = []
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
|
|
|
@ -333,13 +333,13 @@ class QdrantVector(BaseVector):
|
|||
limit=kwargs.get("top_k", 4),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
score_threshold=kwargs.get("score_threshold", 0.0),
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result.score > score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
|
|
|
@ -230,7 +230,7 @@ class RelytVector(BaseVector):
|
|||
# Organize results.
|
||||
docs = []
|
||||
for document, score in results:
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if 1 - score > score_threshold:
|
||||
docs.append(document)
|
||||
return docs
|
||||
|
|
|
@ -153,7 +153,7 @@ class TencentVector(BaseVector):
|
|||
limit=kwargs.get("top_k", 4),
|
||||
timeout=self._client_config.timeout,
|
||||
)
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
|
|
|
@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
|
|||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
filter = kwargs.get("filter")
|
||||
distance = 1 - score_threshold
|
||||
|
||||
|
|
|
@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
|
|||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
doc.metadata["score"] = score
|
||||
|
|
Loading…
Reference in New Issue
Block a user