mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat: support Chroma vector store (#5015)
This commit is contained in:
parent
3f18369ad2
commit
cdc08a434f
8
.github/workflows/api-tests.yml
vendored
8
.github/workflows/api-tests.yml
vendored
|
@ -58,7 +58,7 @@ jobs:
|
|||
- name: Run Workflow
|
||||
run: dev/pytest/pytest_workflow.sh
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
compose-file: |
|
||||
|
@ -67,6 +67,7 @@ jobs:
|
|||
docker/docker-compose.milvus.yaml
|
||||
docker/docker-compose.pgvecto-rs.yaml
|
||||
docker/docker-compose.pgvector.yaml
|
||||
docker/docker-compose.chroma.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
|
@ -75,6 +76,7 @@ jobs:
|
|||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: dev/pytest/pytest_vdb.sh
|
||||
|
@ -131,7 +133,7 @@ jobs:
|
|||
- name: Run Workflow
|
||||
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
compose-file: |
|
||||
|
@ -140,6 +142,7 @@ jobs:
|
|||
docker/docker-compose.milvus.yaml
|
||||
docker/docker-compose.pgvecto-rs.yaml
|
||||
docker/docker-compose.pgvector.yaml
|
||||
docker/docker-compose.chroma.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
|
@ -148,6 +151,7 @@ jobs:
|
|||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -149,6 +149,7 @@ docker/volumes/qdrant/*
|
|||
docker/volumes/etcd/*
|
||||
docker/volumes/minio/*
|
||||
docker/volumes/milvus/*
|
||||
docker/volumes/chroma/*
|
||||
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
|
|
|
@ -119,6 +119,14 @@ TIDB_VECTOR_USER=xxx.root
|
|||
TIDB_VECTOR_PASSWORD=xxxxxx
|
||||
TIDB_VECTOR_DATABASE=dify
|
||||
|
||||
# Chroma configuration
|
||||
CHROMA_HOST=127.0.0.1
|
||||
CHROMA_PORT=8000
|
||||
CHROMA_TENANT=default_tenant
|
||||
CHROMA_DATABASE=default_database
|
||||
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
|
||||
CHROMA_AUTH_CREDENTIALS=difyai123456
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
|
|
|
@ -306,6 +306,14 @@ class Config:
|
|||
self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
|
||||
self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
|
||||
|
||||
# chroma settings
|
||||
self.CHROMA_HOST = get_env('CHROMA_HOST')
|
||||
self.CHROMA_PORT = get_env('CHROMA_PORT')
|
||||
self.CHROMA_TENANT = get_env('CHROMA_TENANT')
|
||||
self.CHROMA_DATABASE = get_env('CHROMA_DATABASE')
|
||||
self.CHROMA_AUTH_PROVIDER = get_env('CHROMA_AUTH_PROVIDER')
|
||||
self.CHROMA_AUTH_CREDENTIALS = get_env('CHROMA_AUTH_CREDENTIALS')
|
||||
|
||||
# ------------------------
|
||||
# Mail Configurations.
|
||||
# ------------------------
|
||||
|
|
|
@ -479,7 +479,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
vector_type = current_app.config['VECTOR_STORE']
|
||||
|
||||
match vector_type:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
|
@ -501,7 +501,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
|
|
0
api/core/rag/datasource/vdb/chroma/__init__.py
Normal file
0
api/core/rag/datasource/vdb/chroma/__init__.py
Normal file
147
api/core/rag/datasource/vdb/chroma/chroma_vector.py
Normal file
147
api/core/rag/datasource/vdb/chroma/chroma_vector.py
Normal file
|
@ -0,0 +1,147 @@
|
|||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class ChromaConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
tenant: str
|
||||
database: str
|
||||
auth_provider: Optional[str] = None
|
||||
auth_credentials: Optional[str] = None
|
||||
|
||||
def to_chroma_params(self):
|
||||
settings = Settings(
|
||||
# auth
|
||||
chroma_client_auth_provider=self.auth_provider,
|
||||
chroma_client_auth_credentials=self.auth_credentials
|
||||
)
|
||||
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'ssl': False,
|
||||
'tenant': self.tenant,
|
||||
'database': self.database,
|
||||
'settings': settings,
|
||||
}
|
||||
|
||||
|
||||
class ChromaVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: ChromaConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = chromadb.HttpClient(**self._client_config.to_chroma_params())
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.CHROMA
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
# create collection
|
||||
self.create_collection(self._collection_name)
|
||||
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
self._client.get_or_create_collection(collection_name)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(where={key: {'$eq': value}})
|
||||
|
||||
def delete(self):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(ids=ids)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
response = collection.get(ids=[id])
|
||||
return len(response) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
|
||||
ids: list[str] = results['ids'][0]
|
||||
documents: list[str] = results['documents'][0]
|
||||
metadatas: dict[str, Any] = results['metadatas'][0]
|
||||
distances: list[float] = results['distances'][0]
|
||||
|
||||
docs = []
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = metadatas[index]
|
||||
if distance >= score_threshold:
|
||||
metadata['score'] = distance
|
||||
doc = Document(
|
||||
page_content=documents[index],
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# chroma does not support BM25 full text searching
|
||||
return []
|
||||
|
||||
|
||||
class ChromaVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
index_struct_dict = {
|
||||
"type": VectorType.CHROMA,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
config = current_app.config
|
||||
return ChromaVector(
|
||||
collection_name=collection_name,
|
||||
config=ChromaConfig(
|
||||
host=config.get('CHROMA_HOST'),
|
||||
port=int(config.get('CHROMA_PORT')),
|
||||
tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
|
||||
database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
|
||||
auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
|
||||
auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
|
||||
),
|
||||
)
|
|
@ -52,6 +52,9 @@ class Vector:
|
|||
@staticmethod
|
||||
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
match vector_type:
|
||||
case VectorType.CHROMA:
|
||||
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
|
||||
return ChromaVectorFactory
|
||||
case VectorType.MILVUS:
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
|
||||
return MilvusVectorFactory
|
||||
|
|
|
@ -2,6 +2,7 @@ from enum import Enum
|
|||
|
||||
|
||||
class VectorType(str, Enum):
|
||||
CHROMA = 'chroma'
|
||||
MILVUS = 'milvus'
|
||||
PGVECTOR = 'pgvector'
|
||||
PGVECTO_RS = 'pgvecto-rs'
|
||||
|
|
1249
api/poetry.lock
generated
1249
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
@ -107,7 +107,6 @@ pycryptodome = "3.19.1"
|
|||
python-dotenv = "1.0.0"
|
||||
authlib = "1.2.0"
|
||||
boto3 = "1.28.17"
|
||||
tenacity = "8.2.2"
|
||||
cachetools = "~5.3.0"
|
||||
weaviate-client = "~3.21.0"
|
||||
mailchimp-transactional = "~1.0.50"
|
||||
|
@ -179,6 +178,7 @@ google-cloud-aiplatform = "1.49.0"
|
|||
vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
|
||||
kaleido = "0.2.1"
|
||||
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
|
||||
chromadb = "~0.5.0"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
|
|
@ -16,7 +16,6 @@ pycryptodome==3.19.1
|
|||
python-dotenv==1.0.0
|
||||
Authlib==1.2.0
|
||||
boto3==1.34.123
|
||||
tenacity==8.2.2
|
||||
cachetools~=5.3.0
|
||||
weaviate-client~=3.21.0
|
||||
mailchimp-transactional~=1.0.50
|
||||
|
@ -85,4 +84,5 @@ pymysql==1.1.1
|
|||
tidb-vector==0.0.9
|
||||
google-cloud-aiplatform==1.49.0
|
||||
vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
|
||||
tencentcloud-sdk-python-hunyuan~=3.0.1158
|
||||
tencentcloud-sdk-python-hunyuan~=3.0.1158
|
||||
chromadb~=0.5.0
|
||||
|
|
0
api/tests/integration_tests/vdb/chroma/__init__.py
Normal file
0
api/tests/integration_tests/vdb/chroma/__init__.py
Normal file
33
api/tests/integration_tests/vdb/chroma/test_chroma.py
Normal file
33
api/tests/integration_tests/vdb/chroma/test_chroma.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
import chromadb
|
||||
|
||||
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class ChromaVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = ChromaVector(
|
||||
collection_name=self.collection_name,
|
||||
config=ChromaConfig(
|
||||
host='localhost',
|
||||
port=8000,
|
||||
tenant=chromadb.DEFAULT_TENANT,
|
||||
database=chromadb.DEFAULT_DATABASE,
|
||||
auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
|
||||
auth_credentials="difyai123456",
|
||||
)
|
||||
)
|
||||
|
||||
def search_by_full_text(self):
|
||||
# chroma dos not support full text searching
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
|
||||
def test_chroma_vector(setup_mock_redis):
|
||||
ChromaVectorTest().run_all_tests()
|
14
docker/docker-compose.chroma.yaml
Normal file
14
docker/docker-compose.chroma.yaml
Normal file
|
@ -0,0 +1,14 @@
|
|||
version: '3'
|
||||
services:
|
||||
# Chroma vector store.
|
||||
chroma:
|
||||
image: ghcr.io/chroma-core/chroma:0.5.0
|
||||
restart: always
|
||||
volumes:
|
||||
- ./volumes/chroma:/chroma/chroma
|
||||
environment:
|
||||
CHROMA_SERVER_AUTHN_CREDENTIALS: difyai123456
|
||||
CHROMA_SERVER_AUTHN_PROVIDER: chromadb.auth.token_authn.TokenAuthenticationServerProvider
|
||||
IS_PERSISTENT: TRUE
|
||||
ports:
|
||||
- "8000:8000"
|
|
@ -140,6 +140,13 @@ services:
|
|||
TIDB_VECTOR_USER: xxx.root
|
||||
TIDB_VECTOR_PASSWORD: xxxxxx
|
||||
TIDB_VECTOR_DATABASE: dify
|
||||
# Chroma configuration
|
||||
CHROMA_HOST: 127.0.0.1
|
||||
CHROMA_PORT: 8000
|
||||
CHROMA_TENANT: default_tenant
|
||||
CHROMA_DATABASE: default_database
|
||||
CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
|
||||
CHROMA_AUTH_CREDENTIALS: xxxxxx
|
||||
# Mail configuration, support: resend, smtp
|
||||
MAIL_TYPE: ''
|
||||
# default send from email address, if not specified
|
||||
|
@ -301,6 +308,13 @@ services:
|
|||
TIDB_VECTOR_USER: xxx.root
|
||||
TIDB_VECTOR_PASSWORD: xxxxxx
|
||||
TIDB_VECTOR_DATABASE: dify
|
||||
# Chroma configuration
|
||||
CHROMA_HOST: 127.0.0.1
|
||||
CHROMA_PORT: 8000
|
||||
CHROMA_TENANT: default_tenant
|
||||
CHROMA_DATABASE: default_database
|
||||
CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
|
||||
CHROMA_AUTH_CREDENTIALS: xxxxxx
|
||||
# Notion import configuration, support public and internal
|
||||
NOTION_INTEGRATION_TYPE: public
|
||||
NOTION_CLIENT_SECRET: you-client-secret
|
||||
|
|
Loading…
Reference in New Issue
Block a user