mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
fix: score_threshold handling in vector search methods (#8356)
This commit is contained in:
parent
a45ac6ab98
commit
08c486452f
|
@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||||
|
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||||
dbinstance_id=self.config.instance_id,
|
dbinstance_id=self.config.instance_id,
|
||||||
region_id=self.config.region_id,
|
region_id=self.config.region_id,
|
||||||
|
@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||||
|
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||||
dbinstance_id=self.config.instance_id,
|
dbinstance_id=self.config.instance_id,
|
||||||
region_id=self.config.region_id,
|
region_id=self.config.region_id,
|
||||||
|
|
|
@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
collection = self._client.get_or_create_collection(self._collection_name)
|
collection = self._client.get_or_create_collection(self._collection_name)
|
||||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
|
|
||||||
ids: list[str] = results["ids"][0]
|
ids: list[str] = results["ids"][0]
|
||||||
documents: list[str] = results["documents"][0]
|
documents: list[str] = results["documents"][0]
|
||||||
|
|
|
@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
for doc, score in docs_and_scores:
|
for doc, score in docs_and_scores:
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if score > score_threshold:
|
if score > score_threshold:
|
||||||
doc.metadata["score"] = score
|
doc.metadata["score"] = score
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|
|
@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
|
||||||
for result in results[0]:
|
for result in results[0]:
|
||||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||||
metadata["score"] = result["distance"]
|
metadata["score"] = result["distance"]
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if result["distance"] > score_threshold:
|
if result["distance"] > score_threshold:
|
||||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|
|
@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
|
||||||
|
|
||||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
where_str = (
|
where_str = (
|
||||||
f"WHERE dist < {1 - score_threshold}"
|
f"WHERE dist < {1 - score_threshold}"
|
||||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||||
|
|
|
@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
metadata["score"] = hit["_score"]
|
metadata["score"] = hit["_score"]
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if hit["_score"] > score_threshold:
|
if hit["_score"] > score_threshold:
|
||||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|
|
@ -200,7 +200,7 @@ class OracleVector(BaseVector):
|
||||||
[numpy.array(query_vector)],
|
[numpy.array(query_vector)],
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
for record in cur:
|
for record in cur:
|
||||||
metadata, text, distance = record
|
metadata, text, distance = record
|
||||||
score = 1 - distance
|
score = 1 - distance
|
||||||
|
@ -212,7 +212,7 @@ class OracleVector(BaseVector):
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
# just not implement fetch by score_threshold now, may be later
|
# just not implement fetch by score_threshold now, may be later
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if len(query) > 0:
|
if len(query) > 0:
|
||||||
# Check which language the query is in
|
# Check which language the query is in
|
||||||
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
||||||
|
|
|
@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
|
||||||
metadata = record.meta
|
metadata = record.meta
|
||||||
score = 1 - dis
|
score = 1 - dis
|
||||||
metadata["score"] = score
|
metadata["score"] = score
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if score > score_threshold:
|
if score > score_threshold:
|
||||||
doc = Document(page_content=record.text, metadata=metadata)
|
doc = Document(page_content=record.text, metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|
|
@ -144,7 +144,7 @@ class PGVector(BaseVector):
|
||||||
(json.dumps(query_vector),),
|
(json.dumps(query_vector),),
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
for record in cur:
|
for record in cur:
|
||||||
metadata, text, distance = record
|
metadata, text, distance = record
|
||||||
score = 1 - distance
|
score = 1 - distance
|
||||||
|
|
|
@ -333,13 +333,13 @@ class QdrantVector(BaseVector):
|
||||||
limit=kwargs.get("top_k", 4),
|
limit=kwargs.get("top_k", 4),
|
||||||
with_payload=True,
|
with_payload=True,
|
||||||
with_vectors=True,
|
with_vectors=True,
|
||||||
score_threshold=kwargs.get("score_threshold", 0.0),
|
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
for result in results:
|
for result in results:
|
||||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||||
# duplicate check score threshold
|
# duplicate check score threshold
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if result.score > score_threshold:
|
if result.score > score_threshold:
|
||||||
metadata["score"] = result.score
|
metadata["score"] = result.score
|
||||||
doc = Document(
|
doc = Document(
|
||||||
|
|
|
@ -230,7 +230,7 @@ class RelytVector(BaseVector):
|
||||||
# Organize results.
|
# Organize results.
|
||||||
docs = []
|
docs = []
|
||||||
for document, score in results:
|
for document, score in results:
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
if 1 - score > score_threshold:
|
if 1 - score > score_threshold:
|
||||||
docs.append(document)
|
docs.append(document)
|
||||||
return docs
|
return docs
|
||||||
|
|
|
@ -153,7 +153,7 @@ class TencentVector(BaseVector):
|
||||||
limit=kwargs.get("top_k", 4),
|
limit=kwargs.get("top_k", 4),
|
||||||
timeout=self._client_config.timeout,
|
timeout=self._client_config.timeout,
|
||||||
)
|
)
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
return self._get_search_res(res, score_threshold)
|
return self._get_search_res(res, score_threshold)
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
|
|
|
@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
|
||||||
|
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
filter = kwargs.get("filter")
|
filter = kwargs.get("filter")
|
||||||
distance = 1 - score_threshold
|
distance = 1 - score_threshold
|
||||||
|
|
||||||
|
|
|
@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
for doc, score in docs_and_scores:
|
for doc, score in docs_and_scores:
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
# check score threshold
|
# check score threshold
|
||||||
if score > score_threshold:
|
if score > score_threshold:
|
||||||
doc.metadata["score"] = score
|
doc.metadata["score"] = score
|
||||||
|
|
Loading…
Reference in New Issue
Block a user