chore: merge main
2
.github/workflows/build-push.yml
vendored
|
@ -125,7 +125,7 @@ jobs:
|
|||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
|
|
|
@ -162,6 +162,8 @@ PGVECTOR_PORT=5433
|
|||
PGVECTOR_USER=postgres
|
||||
PGVECTOR_PASSWORD=postgres
|
||||
PGVECTOR_DATABASE=postgres
|
||||
PGVECTOR_MIN_CONNECTION=1
|
||||
PGVECTOR_MAX_CONNECTION=5
|
||||
|
||||
# Tidb Vector configuration
|
||||
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
|
||||
|
|
|
@ -33,3 +33,13 @@ class PGVectorConfig(BaseSettings):
|
|||
description="Name of the PostgreSQL database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_MIN_CONNECTION: PositiveInt = Field(
|
||||
description="Min connection of the PostgreSQL database",
|
||||
default=1,
|
||||
)
|
||||
|
||||
PGVECTOR_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Max connection of the PostgreSQL database",
|
||||
default=5,
|
||||
)
|
||||
|
|
|
@ -37,7 +37,16 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
|
|||
from .billing import billing
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
|
||||
from .datasets import (
|
||||
data_source,
|
||||
datasets,
|
||||
datasets_document,
|
||||
datasets_segments,
|
||||
external,
|
||||
file,
|
||||
hit_testing,
|
||||
website,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
|
|
@ -49,7 +49,7 @@ class DatasetListApi(Resource):
|
|||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
provider = request.args.get("provider", default="vendor")
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
|
||||
|
@ -57,7 +57,7 @@ class DatasetListApi(Resource):
|
|||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
|
||||
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
|
||||
)
|
||||
|
||||
# check embedding setting
|
||||
|
@ -110,6 +110,26 @@ class DatasetListApi(Resource):
|
|||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
choices=Dataset.PROVIDER_LIST,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
|
@ -123,6 +143,9 @@ class DatasetListApi(Resource):
|
|||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
@ -211,6 +234,33 @@ class DatasetApi(Resource):
|
|||
)
|
||||
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
|
||||
parser.add_argument(
|
||||
"external_retrieval_model",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external retrieval model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge id.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
|
@ -563,10 +613,10 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
@ -577,6 +627,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.PGVECTOR
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
|
239
api/controllers/console/datasets/external.py
Normal file
|
@ -0,0 +1,239 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class ExternalApiTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
page, limit, current_user.current_tenant_id, search
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in external_knowledge_apis],
|
||||
"has_more": len(external_knowledge_apis) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return external_knowledge_api.to_dict(), 201
|
||||
|
||||
|
||||
class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id):
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
|
||||
if external_knowledge_api is None:
|
||||
raise NotFound("API template not found.")
|
||||
|
||||
return external_knowledge_api.to_dict(), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, external_knowledge_api_id):
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=args,
|
||||
)
|
||||
|
||||
return external_knowledge_api.to_dict(), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, external_knowledge_api_id):
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id):
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
external_knowledge_api_id
|
||||
)
|
||||
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
|
||||
|
||||
|
||||
class ExternalDatasetCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
dataset = ExternalDatasetService.create_external_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
args=args,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
|
||||
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
|
||||
api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
|
||||
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
|
||||
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
|
||||
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
|
|
@ -47,6 +47,7 @@ class HitTestingApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
@ -57,6 +58,7 @@ class HitTestingApi(Resource):
|
|||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
|
|
|
@ -14,7 +14,9 @@ class WebsiteCrawlApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json"
|
||||
)
|
||||
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
@ -33,7 +35,7 @@ class WebsiteCrawlStatusApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, job_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
|
||||
parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
# get crawl status
|
||||
try:
|
||||
|
|
|
@ -38,11 +38,52 @@ class VersionApi(Resource):
|
|||
return result
|
||||
|
||||
content = json.loads(response.content)
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
result["can_auto_update"] = content["canAutoUpdate"]
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
result["can_auto_update"] = content["canAutoUpdate"]
|
||||
return result
|
||||
|
||||
|
||||
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
||||
def parse_version(version: str) -> tuple:
|
||||
# Split version into parts and pre-release suffix if any
|
||||
parts = version.split("-")
|
||||
version_parts = parts[0].split(".")
|
||||
pre_release = parts[1] if len(parts) > 1 else None
|
||||
|
||||
# Validate version format
|
||||
if len(version_parts) != 3:
|
||||
raise ValueError(f"Invalid version format: {version}")
|
||||
|
||||
try:
|
||||
# Convert version parts to integers
|
||||
major, minor, patch = map(int, version_parts)
|
||||
return (major, minor, patch, pre_release)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid version format: {version}")
|
||||
|
||||
latest = parse_version(latest_version)
|
||||
current = parse_version(current_version)
|
||||
|
||||
# Compare major, minor, and patch versions
|
||||
for latest_part, current_part in zip(latest[:3], current[:3]):
|
||||
if latest_part > current_part:
|
||||
return True
|
||||
elif latest_part < current_part:
|
||||
return False
|
||||
|
||||
# If versions are equal, check pre-release suffixes
|
||||
if latest[3] is None and current[3] is not None:
|
||||
return True
|
||||
elif latest[3] is not None and current[3] is None:
|
||||
return False
|
||||
elif latest[3] is not None and current[3] is not None:
|
||||
# Simple string comparison for pre-release versions
|
||||
return latest[3] > current[3]
|
||||
|
||||
return False
|
||||
|
||||
|
||||
api.add_resource(VersionApi, "/version")
|
||||
|
|
|
@ -28,11 +28,11 @@ class DatasetListApi(DatasetApiResource):
|
|||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
provider = request.args.get("provider", default="vendor")
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids)
|
||||
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
|
@ -82,6 +82,26 @@ class DatasetListApi(DatasetApiResource):
|
|||
required=False,
|
||||
nullable=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="_validate_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
@ -91,6 +111,9 @@ class DatasetListApi(DatasetApiResource):
|
|||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=args["permission"],
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
|
|
@ -231,7 +231,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
if tts_publisher:
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
|
|
|
@ -212,7 +212,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
if tts_publisher:
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
class VariableError(Exception):
|
||||
class VariableError(ValueError):
|
||||
pass
|
||||
|
|
|
@ -248,7 +248,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
if publisher:
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
|
||||
|
|
|
@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
|
|||
for item in resource:
|
||||
dataset_retriever_resource = DatasetRetrieverResource(
|
||||
message_id=self._message_id,
|
||||
position=item.get("position"),
|
||||
position=item.get("position") or 0,
|
||||
dataset_id=item.get("dataset_id"),
|
||||
dataset_name=item.get("dataset_name"),
|
||||
document_id=item.get("document_id"),
|
||||
|
|
|
@ -119,7 +119,7 @@ class ProviderConfiguration(BaseModel):
|
|||
credentials = model_configuration.credentials
|
||||
break
|
||||
|
||||
if self.custom_configuration.provider:
|
||||
if not credentials and self.custom_configuration.provider:
|
||||
credentials = self.custom_configuration.provider.credentials
|
||||
|
||||
return credentials
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
|
@ -13,7 +14,7 @@ _TEXT_COLOR_MAPPING = {
|
|||
}
|
||||
|
||||
|
||||
class Callback:
|
||||
class Callback(ABC):
|
||||
"""
|
||||
Base class for callbacks.
|
||||
Only for LLM.
|
||||
|
@ -21,6 +22,7 @@ class Callback:
|
|||
|
||||
raise_error: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def on_before_invoke(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
|
@ -48,6 +50,7 @@ class Callback:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_new_chunk(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
|
@ -77,6 +80,7 @@ class Callback:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_after_invoke(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
|
@ -106,6 +110,7 @@ class Callback:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_invoke_error(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
|
|
|
@ -0,0 +1,310 @@
|
|||
## Custom Integration of Pre-defined Models
|
||||
|
||||
### Introduction
|
||||
|
||||
After completing the vendors integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration.
|
||||
|
||||
It is important to note that for custom models, each model connection requires a complete vendor credential.
|
||||
|
||||
Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file.
|
||||
|
||||
![](images/index/image-3.png)
|
||||
|
||||
As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user.
|
||||
|
||||
### Writing the Vendor YAML
|
||||
|
||||
First, we need to identify the types of models supported by the vendor we are integrating.
|
||||
|
||||
Currently supported model types are as follows:
|
||||
|
||||
- `llm` Text Generation Models
|
||||
|
||||
- `text_embedding` Text Embedding Models
|
||||
|
||||
- `rerank` Rerank Models
|
||||
|
||||
- `speech2text` Speech-to-Text
|
||||
|
||||
- `tts` Text-to-Speech
|
||||
|
||||
- `moderation` Moderation
|
||||
|
||||
Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml.
|
||||
|
||||
```yaml
|
||||
provider: xinference #Define the vendor identifier
|
||||
label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
|
||||
en_US: Xorbits Inference
|
||||
icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
|
||||
en_US: icon_s_en.svg
|
||||
icon_large: # Large icon
|
||||
en_US: icon_l_en.svg
|
||||
help: # Help information
|
||||
title:
|
||||
en_US: How to deploy Xinference
|
||||
zh_Hans: 如何部署 Xinference
|
||||
url:
|
||||
en_US: https://github.com/xorbitsai/inference
|
||||
supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
```
|
||||
|
||||
|
||||
Then, we need to determine what credentials are required to define a model in Xinference.
|
||||
|
||||
- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
|
||||
|
||||
```yaml
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: model_type
|
||||
type: select
|
||||
label:
|
||||
en_US: Model type
|
||||
zh_Hans: 模型类型
|
||||
required: true
|
||||
options:
|
||||
- value: text-generation
|
||||
label:
|
||||
en_US: Language Model
|
||||
zh_Hans: 语言模型
|
||||
- value: embeddings
|
||||
label:
|
||||
en_US: Text Embedding
|
||||
- value: reranking
|
||||
label:
|
||||
en_US: Rerank
|
||||
```
|
||||
|
||||
- Next, each model has its own model_name, so we need to define that here:
|
||||
|
||||
```yaml
|
||||
- variable: model_name
|
||||
type: text-input
|
||||
label:
|
||||
en_US: Model name
|
||||
zh_Hans: 模型名称
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 填写模型名称
|
||||
en_US: Input model name
|
||||
```
|
||||
|
||||
- Specify the Xinference local deployment address:
|
||||
|
||||
```yaml
|
||||
- variable: server_url
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
|
||||
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
||||
```
|
||||
|
||||
- Each model has a unique model_uid, so we also need to define that here:
|
||||
|
||||
```yaml
|
||||
- variable: model_uid
|
||||
label:
|
||||
zh_Hans: 模型UID
|
||||
en_US: Model uid
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的Model UID
|
||||
en_US: Enter the model uid
|
||||
```
|
||||
|
||||
Now, we have completed the basic definition of the vendor.
|
||||
|
||||
### Writing the Model Code
|
||||
|
||||
Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`.
|
||||
|
||||
In `llm.py`, create a Xinference LLM class, we name it `XinferenceAILargeLanguageModel` (this can be arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
|
||||
|
||||
- LLM Invocation
|
||||
|
||||
Implement the core method for LLM invocation, supporting both stream and synchronous responses.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool usage
|
||||
:param stop: stop words
|
||||
:param stream: is the response a stream
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
When implementing, ensure to use two functions to return data separately for synchronous and stream responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. Here’s an example (note that the example uses simplified parameters; in real implementation, use the parameter list as defined above):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- Pre-compute Input Tokens
|
||||
|
||||
If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool usage
|
||||
:return: token count
|
||||
"""
|
||||
```
|
||||
|
||||
|
||||
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens. This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
|
||||
|
||||
- Model Credentials Validation
|
||||
|
||||
Similar to vendor credentials validation, this method validates individual model credentials.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: None
|
||||
"""
|
||||
```
|
||||
|
||||
- Model Parameter Schema
|
||||
|
||||
Unlike custom types, since the YAML file does not define which parameters a model supports, we need to dynamically generate the model parameter schema.
|
||||
|
||||
For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
|
||||
|
||||
However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature', type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度', en_US='Temperature'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p', type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P', en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens', type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度', en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# if model is A, add top_k to rules
|
||||
if model == 'A':
|
||||
rules.append(
|
||||
ParameterRule(
|
||||
name='top_k', type=ParameterType.INT,
|
||||
use_template='top_k',
|
||||
min=1,
|
||||
default=50,
|
||||
label=I18nObject(
|
||||
zh_Hans='Top K', en_US='Top K'
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
"""
|
||||
some NOT IMPORTANT code here
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=model_type,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: ModelType.LLM,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
return entity
|
||||
```
|
||||
|
||||
- Exception Error Mapping
|
||||
|
||||
When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Connection error during invocation
|
||||
- `InvokeServerUnavailableError` Service provider unavailable
|
||||
- `InvokeRateLimitError` Rate limit reached
|
||||
- `InvokeAuthorizationError` Authorization failure
|
||||
- `InvokeBadRequestError` Invalid request parameters
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
BIN
api/core/model_runtime/docs/en_US/images/index/image-1.png
Normal file
After Width: | Height: | Size: 230 KiB |
BIN
api/core/model_runtime/docs/en_US/images/index/image-2.png
Normal file
After Width: | Height: | Size: 205 KiB |
BIN
api/core/model_runtime/docs/en_US/images/index/image-3.png
Normal file
After Width: | Height: | Size: 44 KiB |
BIN
api/core/model_runtime/docs/en_US/images/index/image.png
Normal file
After Width: | Height: | Size: 262 KiB |
173
api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
Normal file
|
@ -0,0 +1,173 @@
|
|||
## Predefined Model Integration
|
||||
|
||||
After completing the vendor integration, the next step is to integrate the models from the vendor.
|
||||
|
||||
First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
|
||||
|
||||
Currently supported model types are:
|
||||
|
||||
- `llm` Text Generation Model
|
||||
- `text_embedding` Text Embedding Model
|
||||
- `rerank` Rerank Model
|
||||
- `speech2text` Speech-to-Text
|
||||
- `tts` Text-to-Speech
|
||||
- `moderation` Moderation
|
||||
|
||||
Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.
|
||||
|
||||
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
|
||||
|
||||
### Prepare Model YAML
|
||||
|
||||
```yaml
|
||||
model: claude-2.1 # Model identifier
|
||||
# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
|
||||
# This can also be omitted, in which case the model identifier will be used as the label
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm # Model type, claude-2.1 is an LLM
|
||||
features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
|
||||
- agent-thought
|
||||
model_properties: # Model properties
|
||||
mode: chat # LLM mode, complete for text completion models, chat for conversation models
|
||||
context_size: 200000 # Maximum context size
|
||||
parameter_rules: # Parameter rules for the model call; only LLM requires this
|
||||
- name: temperature # Parameter variable name
|
||||
# Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||
# The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
# Additional configuration parameters will override the default configuration if set
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label: # Display name of the parameter
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int # Parameter type, supports float/int/string/boolean
|
||||
help: # Help information, describing the parameter's function
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false # Whether the parameter is mandatory; can be omitted
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
default: 4096 # Default value of the parameter
|
||||
min: 1 # Minimum value of the parameter, applicable to float/int only
|
||||
max: 4096 # Maximum value of the parameter, applicable to float/int only
|
||||
pricing: # Pricing information
|
||||
input: '8.00' # Input unit price, i.e., prompt price
|
||||
output: '24.00' # Output unit price, i.e., response content price
|
||||
unit: '0.000001' # Price unit, meaning the above prices are per 100K
|
||||
currency: USD # Price currency
|
||||
```
|
||||
|
||||
It is recommended to prepare all model configurations before starting the implementation of the model code.
|
||||
|
||||
You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
|
||||
|
||||
### Implement the Model Call Code
|
||||
|
||||
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
|
||||
|
||||
Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
|
||||
|
||||
- LLM Call
|
||||
|
||||
Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
Ensure to use two functions for returning data, one for synchronous returns and the other for streaming returns, because Python identifies functions containing the `yield` keyword as generator functions, fixing the return type to `Generator`. Thus, synchronous and streaming returns need to be implemented separately, as shown below (note that the example uses simplified parameters, for actual implementation follow the above parameter list):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- Pre-compute Input Tokens
|
||||
|
||||
If the model does not provide an interface to precompute tokens, return 0 directly.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Validate Model Credentials
|
||||
|
||||
Similar to vendor credential validation, but specific to a single model.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Map Invoke Errors
|
||||
|
||||
When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Connection error
|
||||
|
||||
- `InvokeServerUnavailableError` Service provider unavailable
|
||||
- `InvokeRateLimitError` Rate limit reached
|
||||
- `InvokeAuthorizationError` Authorization failed
|
||||
- `InvokeBadRequestError` Parameter error
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
|
@ -58,7 +58,7 @@ provider_credential_schema: # Provider credential rules, as Anthropic only supp
|
|||
en_US: Enter your API URL
|
||||
```
|
||||
|
||||
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider).
|
||||
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider).
|
||||
|
||||
### Implementing Provider Code
|
||||
|
||||
|
|
|
@ -117,7 +117,7 @@ model_credential_schema:
|
|||
en_US: Enter your API Base
|
||||
```
|
||||
|
||||
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。
|
||||
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
|
||||
|
||||
#### 实现供应商代码
|
||||
|
||||
|
|
|
@ -40,3 +40,4 @@
|
|||
- fireworks
|
||||
- mixedbread
|
||||
- nomic
|
||||
- voyage
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
- anthropic.claude-v2:1
|
||||
- anthropic.claude-3-sonnet-v1:0
|
||||
- anthropic.claude-3-haiku-v1:0
|
||||
- ai21.jamba-1-5-large-v1:0
|
||||
- ai21.jamba-1-5-mini-v1:0
|
||||
- cohere.command-light-text-v14
|
||||
- cohere.command-text-v14
|
||||
- cohere.command-r-plus-v1.0
|
||||
|
@ -15,6 +17,10 @@
|
|||
- meta.llama3-1-405b-instruct-v1:0
|
||||
- meta.llama3-8b-instruct-v1:0
|
||||
- meta.llama3-70b-instruct-v1:0
|
||||
- us.meta.llama3-2-1b-instruct-v1:0
|
||||
- us.meta.llama3-2-3b-instruct-v1:0
|
||||
- us.meta.llama3-2-11b-instruct-v1:0
|
||||
- us.meta.llama3-2-90b-instruct-v1:0
|
||||
- meta.llama2-13b-chat-v1
|
||||
- meta.llama2-70b-chat-v1
|
||||
- mistral.mistral-large-2407-v1:0
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
model: ai21.jamba-1-5-large-v1:0
|
||||
label:
|
||||
en_US: Jamba 1.5 Large
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 256000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
pricing:
|
||||
input: '0.002'
|
||||
output: '0.008'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -0,0 +1,26 @@
|
|||
model: ai21.jamba-1-5-mini-v1:0
|
||||
label:
|
||||
en_US: Jamba 1.5 Mini
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 256000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
pricing:
|
||||
input: '0.0002'
|
||||
output: '0.0004'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -63,6 +63,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
{"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "us.meta.llama3-2", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
|
||||
{"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False},
|
||||
|
@ -70,6 +71,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
{"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
|
||||
{"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
model: us.meta.llama3-2-11b-instruct-v1:0
|
||||
label:
|
||||
en_US: US Meta Llama 3.2 11B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- vision
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 2048
|
||||
pricing:
|
||||
input: '0.00035'
|
||||
output: '0.00035'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -0,0 +1,26 @@
|
|||
model: us.meta.llama3-2-1b-instruct-v1:0
|
||||
label:
|
||||
en_US: US Meta Llama 3.2 1B Instruct
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 2048
|
||||
pricing:
|
||||
input: '0.0001'
|
||||
output: '0.0001'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -0,0 +1,26 @@
|
|||
model: us.meta.llama3-2-3b-instruct-v1:0
|
||||
label:
|
||||
en_US: US Meta Llama 3.2 3B Instruct
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 2048
|
||||
pricing:
|
||||
input: '0.00015'
|
||||
output: '0.00015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -0,0 +1,31 @@
|
|||
model: us.meta.llama3-2-90b-instruct-v1:0
|
||||
label:
|
||||
en_US: US Meta Llama 3.2 90B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.9
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 2048
|
||||
pricing:
|
||||
input: '0.002'
|
||||
output: '0.002'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -1,24 +1,23 @@
|
|||
- Qwen2.5-72B-Instruct
|
||||
- Qwen2.5-7B-Instruct
|
||||
- Qwen2-72B-Instruct
|
||||
- Qwen2-72B-Instruct-AWQ-int4
|
||||
- Qwen2-72B-Instruct-GPTQ-Int4
|
||||
- Qwen2-7B-Instruct
|
||||
- Qwen2-7B
|
||||
- Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-72B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-7B
|
||||
- Qwen-14B-Chat-Int4
|
||||
- Yi-Coder-1.5B-Chat
|
||||
- Yi-Coder-9B-Chat
|
||||
- Qwen2-72B-Instruct-AWQ-int4
|
||||
- Yi-1_5-9B-Chat-16K
|
||||
- Qwen2-7B-Instruct
|
||||
- Reflection-Llama-3.1-70B
|
||||
- Qwen2-72B-Instruct
|
||||
- Meta-Llama-3.1-8B-Instruct
|
||||
|
||||
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
|
||||
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
|
||||
- chatglm3-6b
|
||||
- Meta-Llama-3-8B-Instruct
|
||||
- Llama3-Chinese_v2
|
||||
- deepseek-v2-lite-chat
|
||||
- Qwen2-72B-Instruct-GPTQ-Int4
|
||||
- Qwen2-7B
|
||||
- Qwen-14B-Chat-Int4
|
||||
- Qwen1.5-72B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-7B
|
||||
- Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
- deepseek-v2-chat
|
||||
- chatglm3-6b
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
- gte-Qwen2-7B-instruct
|
||||
- BAAI/bge-large-en-v1.5
|
||||
- BAAI/bge-large-zh-v1.5
|
||||
- BAAI/bge-m3
|
|
@ -2,3 +2,4 @@ model: gte-Qwen2-7B-instruct
|
|||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 2048
|
||||
deprecated: true
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
- Qwen/Qwen2.5-72B-Instruct
|
||||
- Qwen/Qwen2.5-Math-72B-Instruct
|
||||
- Qwen/Qwen2.5-32B-Instruct
|
||||
- Qwen/Qwen2.5-14B-Instruct
|
||||
- Qwen/Qwen2.5-7B-Instruct
|
||||
- Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
- deepseek-ai/DeepSeek-V2.5
|
||||
- Qwen/Qwen2.5-Math-72B-Instruct
|
||||
- Qwen/Qwen2-72B-Instruct
|
||||
- Qwen/Qwen2-57B-A14B-Instruct
|
||||
- Qwen/Qwen2-7B-Instruct
|
||||
- Qwen/Qwen2-1.5B-Instruct
|
||||
- deepseek-ai/DeepSeek-V2.5
|
||||
- deepseek-ai/DeepSeek-V2-Chat
|
||||
- deepseek-ai/DeepSeek-Coder-V2-Instruct
|
||||
- THUDM/glm-4-9b-chat
|
||||
- THUDM/chatglm3-6b
|
||||
- 01-ai/Yi-1.5-34B-Chat-16K
|
||||
- 01-ai/Yi-1.5-9B-Chat-16K
|
||||
- 01-ai/Yi-1.5-6B-Chat
|
||||
|
@ -26,13 +25,4 @@
|
|||
- google/gemma-2-27b-it
|
||||
- google/gemma-2-9b-it
|
||||
- mistralai/Mistral-7B-Instruct-v0.2
|
||||
- Pro/Qwen/Qwen2-7B-Instruct
|
||||
- Pro/Qwen/Qwen2-1.5B-Instruct
|
||||
- Pro/THUDM/glm-4-9b-chat
|
||||
- Pro/THUDM/chatglm3-6b
|
||||
- Pro/01-ai/Yi-1.5-9B-Chat-16K
|
||||
- Pro/01-ai/Yi-1.5-6B-Chat
|
||||
- Pro/internlm/internlm2_5-7b-chat
|
||||
- Pro/meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
- Pro/meta-llama/Meta-Llama-3-8B-Instruct
|
||||
- Pro/google/gemma-2-9b-it
|
||||
- mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
model: internlm/internlm2_5-20b-chat
|
||||
label:
|
||||
en_US: internlm/internlm2_5-20b-chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '1'
|
||||
output: '1'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
|
@ -0,0 +1,74 @@
|
|||
model: Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
label:
|
||||
en_US: Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
zh_Hans: 重复惩罚
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0'
|
||||
output: '0'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
|
@ -0,0 +1,74 @@
|
|||
model: Qwen/Qwen2.5-Math-72B-Instruct
|
||||
label:
|
||||
en_US: Qwen/Qwen2.5-Math-72B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: seed
|
||||
required: false
|
||||
type: int
|
||||
default: 1234
|
||||
label:
|
||||
zh_Hans: 随机种子
|
||||
en_US: Random seed
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
zh_Hans: 重复惩罚
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '4.13'
|
||||
output: '4.13'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
|
@ -1,7 +1,7 @@
|
|||
# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models
|
||||
model: qwen2.5-7b-instruct
|
||||
model: qwen2.5-coder-7b-instruct
|
||||
label:
|
||||
en_US: qwen2.5-7b-instruct
|
||||
en_US: qwen2.5-coder-7b-instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
<svg version="1.0" xmlns="http://www.w3.org/2000/svg" width="100.000000pt" height="19.000000pt" viewBox="0 0 300.000000 57.000000" preserveAspectRatio="xMidYMid meet"><g transform="translate(0.000000,57.000000) scale(0.100000,-0.100000)" fill="#000000" stroke="none"><path d="M2505 368 c-38 -84 -86 -188 -106 -230 l-38 -78 27 0 c24 0 30 7 55
|
||||
75 l28 75 100 0 100 0 25 -55 c13 -31 24 -64 24 -75 0 -17 7 -20 44 -20 l43 0
|
||||
-37 73 c-20 39 -68 143 -106 229 -38 87 -74 158 -80 158 -5 0 -41 -69 -79
|
||||
-152z m110 -30 c22 -51 41 -95 42 -98 2 -3 -36 -6 -83 -7 -76 -1 -85 0 -81 15
|
||||
12 40 72 182 77 182 3 0 24 -41 45 -92z"/><path d="M63 493 c19 -61 197 -438 209 -440 10 -2 147 282 216 449 2 4 -10 8
|
||||
-27 8 -23 0 -31 -5 -31 -17 0 -16 -142 -365 -146 -360 -8 11 -144 329 -149
|
||||
350 -6 23 -12 27 -42 27 -29 0 -34 -3 -30 -17z"/><path d="M2855 285 l0 -225 30 0 30 0 0 225 0 225 -30 0 -30 0 0 -225z"/><path d="M588 380 c-55 -30 -82 -74 -86 -145 -3 -50 0 -66 20 -95 39 -58 82
|
||||
-80 153 -80 68 0 110 21 149 73 32 43 30 150 -3 196 -47 66 -158 90 -233 51z
|
||||
m133 -16 c59 -30 89 -156 54 -224 -45 -87 -162 -78 -201 16 -18 44 -18 128 1
|
||||
164 28 55 90 73 146 44z"/><path d="M935 303 l76 -98 -7 -72 -6 -73 33 0 34 0 -3 78 -4 77 71 93 c65 85
|
||||
68 92 46 92 -15 0 -29 -9 -36 -22 -18 -33 -90 -128 -98 -128 -6 1 -67 85 -88
|
||||
122 -8 15 -24 23 -53 25 l-41 4 76 -98z"/><path d="M1257 230 c-82 -169 -83 -170 -57 -170 17 0 27 6 27 15 0 8 7 31 17
|
||||
52 l17 38 79 0 78 1 16 -34 c9 -18 16 -42 16 -52 0 -17 7 -20 41 -20 22 0 39
|
||||
3 37 8 -2 4 -39 80 -83 170 -43 89 -84 162 -92 162 -7 0 -50 -76 -96 -170z
|
||||
m90 -38 c-33 -2 -61 -1 -63 1 -2 2 10 34 26 71 l31 68 33 -68 33 -69 -60 -3z"/><path d="M1665 386 c-37 -16 -84 -63 -97 -96 -13 -35 -12 -104 2 -132 49 -94
|
||||
182 -134 280 -83 24 12 29 22 32 64 3 49 3 49 -30 53 l-33 4 3 -45 c4 -61 -5
|
||||
-71 -60 -71 -93 0 -142 57 -142 164 0 44 5 60 25 85 47 55 136 65 184 20 30
|
||||
-28 35 -20 11 19 -19 31 -22 32 -82 32 -35 -1 -76 -7 -93 -14z"/><path d="M1955 230 l0 -170 91 0 c76 0 93 3 98 16 4 9 5 18 4 20 -2 1 -31 -1
|
||||
-66 -5 -34 -4 -64 -5 -67 -3 -3 3 -5 36 -5 73 l0 68 55 -6 c49 -5 55 -4 55 13
|
||||
0 17 -6 19 -55 16 l-55 -4 0 61 0 61 64 0 c48 0 65 4 70 15 4 13 -10 15 -92
|
||||
15 l-97 0 0 -170z"/></g></svg>
|
After Width: | Height: | Size: 2.2 KiB |
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="64px" height="64px" viewBox="0 0 64 64" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>voyage</title>
|
||||
<g id="voyage" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<rect id="矩形" fill="#333333" x="0" y="0" width="64" height="64" rx="12"></rect>
|
||||
<path d="M12.1128004,51.4376727 C13.8950799,45.8316747 30.5922254,11.1847688 31.7178757,11.0009656 C32.6559176,10.8171624 45.5070913,36.9172188 51.9795803,52.2647871 C52.1671887,52.6323936 51.0415384,53 49.4468672,53 C47.2893709,53 46.5389374,52.540492 46.5389374,51.4376727 C46.5389374,49.967247 33.2187427,17.8935861 32.8435259,18.3530942 C32.0930924,19.3640118 19.3357228,48.5887229 18.8667019,50.5186566 C18.3038768,52.6323936 17.7410516,53 14.926926,53 C12.2066045,53 11.7375836,52.7242952 12.1128004,51.4376727 Z" id="路径" fill="#FFFFFF" transform="translate(32, 32) scale(1, -1) translate(-32, -32)"></path>
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 1.0 KiB |
|
@ -0,0 +1,4 @@
|
|||
model: rerank-1
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 8000
|
|
@ -0,0 +1,4 @@
|
|||
model: rerank-lite-1
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 4000
|
123
api/core/model_runtime/model_providers/voyage/rerank/rerank.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class VoyageRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Voyage rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n documents to return
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
|
||||
base_url = base_url.removesuffix("/")
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
base_url + "/rerank",
|
||||
json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
|
||||
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
rerank_documents = []
|
||||
for result in results["data"]:
|
||||
rerank_document = RerankDocument(
|
||||
index=result["index"],
|
||||
text=result["document"],
|
||||
score=result["relevance_score"],
|
||||
)
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [httpx.ConnectError],
|
||||
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
|
||||
)
|
||||
|
||||
return entity
|
|
@ -0,0 +1,172 @@
|
|||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class VoyageTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Voyage text embedding model.
|
||||
"""
|
||||
|
||||
api_base: str = "https://api.voyageai.com/v1"
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials["api_key"]
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
|
||||
base_url = credentials.get("base_url", self.api_base)
|
||||
base_url = base_url.removesuffix("/")
|
||||
|
||||
url = base_url + "/embeddings"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
voyage_input_type = "null"
|
||||
if input_type is not None:
|
||||
voyage_input_type = input_type.value
|
||||
data = {"model": model, "input": texts, "input_type": voyage_input_type}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
msg = resp["detail"]
|
||||
if response.status_code == 401:
|
||||
raise InvokeAuthorizationError(msg)
|
||||
elif response.status_code == 429:
|
||||
raise InvokeRateLimitError(msg)
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
else:
|
||||
raise InvokeBadRequestError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
)
|
||||
|
||||
return entity
|
|
@ -0,0 +1,8 @@
|
|||
model: voyage-3-lite
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32000
|
||||
pricing:
|
||||
input: '0.00002'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -0,0 +1,8 @@
|
|||
model: voyage-3
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32000
|
||||
pricing:
|
||||
input: '0.00006'
|
||||
unit: '0.001'
|
||||
currency: USD
|
28
api/core/model_runtime/model_providers/voyage/voyage.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VoyageProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
|
||||
|
||||
# Use `voyage-3` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(model="voyage-3", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
31
api/core/model_runtime/model_providers/voyage/voyage.yaml
Normal file
|
@ -0,0 +1,31 @@
|
|||
provider: voyage
|
||||
label:
|
||||
en_US: Voyage
|
||||
description:
|
||||
en_US: Embedding and Rerank Model Supported
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#EFFDFD"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Voyage AI
|
||||
zh_Hans: 从 Voyage 获取 API Key
|
||||
url:
|
||||
en_US: https://dash.voyageai.com/
|
||||
supported_model_types:
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
|
@ -59,6 +59,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
|||
from core.model_runtime.model_providers.xinference.xinference_helper import (
|
||||
XinferenceHelper,
|
||||
XinferenceModelExtraParameter,
|
||||
validate_model_uid,
|
||||
)
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
|
@ -114,7 +115,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||
}
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||
|
|
|
@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
|
||||
|
||||
|
||||
class XinferenceRerankModel(RerankModel):
|
||||
|
@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel):
|
|||
)
|
||||
|
||||
# score threshold check
|
||||
if score_threshold is not None:
|
||||
if result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
|
@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel):
|
|||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
credentials["server_url"] = credentials["server_url"].removesuffix("/")
|
||||
|
|
|
@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
|
||||
|
||||
|
||||
class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
|
@ -42,7 +43,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
|||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
credentials["server_url"] = credentials["server_url"].removesuffix("/")
|
||||
|
|
|
@ -17,7 +17,7 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
|
||||
|
||||
|
||||
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
@ -110,7 +110,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
server_url = credentials["server_url"]
|
||||
|
|
|
@ -15,7 +15,7 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
|
||||
|
||||
|
||||
class XinferenceText2SpeechModel(TTSModel):
|
||||
|
@ -70,7 +70,7 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
credentials["server_url"] = credentials["server_url"].removesuffix("/")
|
||||
|
|
|
@ -132,3 +132,16 @@ class XinferenceHelper:
|
|||
context_length=context_length,
|
||||
model_family=model_family,
|
||||
)
|
||||
|
||||
|
||||
def validate_model_uid(credentials: dict) -> bool:
|
||||
"""
|
||||
Validate the model_uid within the credentials dictionary to ensure it does not
|
||||
contain forbidden characters ("/", "?", "#").
|
||||
|
||||
param credentials: model credentials
|
||||
:return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False.
|
||||
"""
|
||||
forbidden_characters = ["/", "?", "#"]
|
||||
model_uid = credentials.get("model_uid", "")
|
||||
return not any(char in forbidden_characters for char in model_uid)
|
||||
|
|
|
@ -48,7 +48,7 @@ from ._utils import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
|
||||
from pydantic_core.core_schema import ModelField
|
||||
|
||||
__all__ = ["BaseModel", "GenericModel"]
|
||||
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
|
||||
|
|
|
@ -248,7 +248,7 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
|
|||
@functools.wraps(func)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
given_params: set[str] = set()
|
||||
for i, _ in enumerate(args):
|
||||
for i in range(len(args)):
|
||||
try:
|
||||
given_params.add(positional[i])
|
||||
except IndexError:
|
||||
|
|
|
@ -18,8 +18,12 @@ class KeywordsModeration(Moderation):
|
|||
if not config.get("keywords"):
|
||||
raise ValueError("keywords is required")
|
||||
|
||||
if len(config.get("keywords")) > 1000:
|
||||
raise ValueError("keywords length must be less than 1000")
|
||||
if len(config.get("keywords")) > 10000:
|
||||
raise ValueError("keywords length must be less than 10000")
|
||||
|
||||
keywords_row_len = config["keywords"].split("\n")
|
||||
if len(keywords_row_len) > 100:
|
||||
raise ValueError("the number of rows for the keywords must be less than 100")
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
|
|
|
@ -45,7 +45,7 @@ class Jieba(BaseKeyword):
|
|||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keywords_list = kwargs.get("keywords_list", None)
|
||||
keywords_list = kwargs.get("keywords_list")
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if keywords_list:
|
||||
|
|
|
@ -10,6 +10,7 @@ from core.rag.rerank.constants.rerank_mode import RerankMode
|
|||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
|
@ -34,6 +35,9 @@ class RetrievalService:
|
|||
weights: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
all_documents = []
|
||||
|
@ -108,6 +112,16 @@ class RetrievalService:
|
|||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id, dataset_id, query, external_retrieval_model
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
|
|
|
@ -23,6 +23,8 @@ class PGVectorConfig(BaseModel):
|
|||
user: str
|
||||
password: str
|
||||
database: str
|
||||
min_connection: int
|
||||
max_connection: int
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
@ -37,6 +39,12 @@ class PGVectorConfig(BaseModel):
|
|||
raise ValueError("config PGVECTOR_PASSWORD is required")
|
||||
if not values["database"]:
|
||||
raise ValueError("config PGVECTOR_DATABASE is required")
|
||||
if not values["min_connection"]:
|
||||
raise ValueError("config PGVECTOR_MIN_CONNECTION is required")
|
||||
if not values["max_connection"]:
|
||||
raise ValueError("config PGVECTOR_MAX_CONNECTION is required")
|
||||
if values["min_connection"] > values["max_connection"]:
|
||||
raise ValueError("config PGVECTOR_MIN_CONNECTION should less than PGVECTOR_MAX_CONNECTION")
|
||||
return values
|
||||
|
||||
|
||||
|
@ -61,8 +69,8 @@ class PGVector(BaseVector):
|
|||
|
||||
def _create_connection_pool(self, config: PGVectorConfig):
|
||||
return psycopg2.pool.SimpleConnectionPool(
|
||||
1,
|
||||
5,
|
||||
config.min_connection,
|
||||
config.max_connection,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
user=config.user,
|
||||
|
@ -158,7 +166,7 @@ class PGVector(BaseVector):
|
|||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
ORDER BY score DESC
|
||||
|
@ -213,5 +221,7 @@ class PGVectorFactory(AbstractVectorFactory):
|
|||
user=dify_config.PGVECTOR_USER,
|
||||
password=dify_config.PGVECTOR_PASSWORD,
|
||||
database=dify_config.PGVECTOR_DATABASE,
|
||||
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
|
||||
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -56,7 +56,7 @@ class TencentVector(BaseVector):
|
|||
return self._client.create_database(database_name=self._client_config.database)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "tencent"
|
||||
return VectorType.TENCENT
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
|
10
api/core/rag/entities/context_entities.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DocumentContext(BaseModel):
|
||||
"""
|
||||
Model class for document context.
|
||||
"""
|
||||
|
||||
content: str
|
||||
score: float
|
|
@ -12,6 +12,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
|
|||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
|
@ -171,6 +172,15 @@ class ExtractProcessor:
|
|||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.website_info.provider == "jinareader":
|
||||
extractor = JinaReaderWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
|
||||
else:
|
||||
|
|
35
api/core/rag/extractor/jina_reader_extractor.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from services.website_service import WebsiteService
|
||||
|
||||
|
||||
class JinaReaderWebExtractor(BaseExtractor):
|
||||
"""
|
||||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.mode = mode
|
||||
self.only_main_content = only_main_content
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == "crawl":
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(
|
||||
page_content=crawl_data.get("content", ""),
|
||||
metadata={
|
||||
"source_url": crawl_data.get("url"),
|
||||
"description": crawl_data.get("description"),
|
||||
"title": crawl_data.get("title"),
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
return documents
|
|
@ -17,6 +17,8 @@ class Document(BaseModel):
|
|||
"""
|
||||
metadata: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
provider: Optional[str] = "dify"
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC):
|
||||
"""Abstract base class for document transformation systems.
|
||||
|
|
|
@ -28,11 +28,16 @@ class RerankModelRunner:
|
|||
docs = []
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
dify_documents = [item for item in documents if item.provider == "dify"]
|
||||
external_documents = [item for item in documents if item.provider == "external"]
|
||||
for document in dify_documents:
|
||||
if document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.append(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
for document in external_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
|
@ -46,14 +51,10 @@ class RerankModelRunner:
|
|||
# format document
|
||||
rerank_document = Document(
|
||||
page_content=result.text,
|
||||
metadata={
|
||||
"doc_id": documents[result.index].metadata["doc_id"],
|
||||
"doc_hash": documents[result.index].metadata["doc_hash"],
|
||||
"document_id": documents[result.index].metadata["document_id"],
|
||||
"dataset_id": documents[result.index].metadata["dataset_id"],
|
||||
"score": result.score,
|
||||
},
|
||||
metadata=documents[result.index].metadata,
|
||||
provider=documents[result.index].provider,
|
||||
)
|
||||
rerank_document.metadata["score"] = result.score
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return rerank_documents
|
||||
|
|
|
@ -20,6 +20,7 @@ from core.ops.utils import measure_time
|
|||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
|
@ -30,6 +31,7 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
|
|||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
|
@ -110,7 +112,7 @@ class DatasetRetrieval:
|
|||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if dataset and dataset.available_document_count == 0:
|
||||
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
|
@ -146,69 +148,93 @@ class DatasetRetrieval:
|
|||
message_id,
|
||||
)
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
retrieval_resource_list = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
|
||||
source = {
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": item.metadata.get("score"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
retrieval_resource_list.append(source)
|
||||
document_score_list = {}
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
for item in dify_documents:
|
||||
if item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if show_retrieve_source:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(dataset_ids),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
}
|
||||
if segment.answer:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||
score=document_score_list.get(segment.index_node_id, None),
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=segment.get_sign_content(),
|
||||
score=document_score_list.get(segment.index_node_id, None),
|
||||
)
|
||||
)
|
||||
if show_retrieve_source:
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
}
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
if hit_callback:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
if invoke_from.to_source() == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
if hit_callback and retrieval_resource_list:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
|
||||
def single_retrieve(
|
||||
|
@ -256,36 +282,58 @@ class DatasetRetrieval:
|
|||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrieval_method = "keyword_search"
|
||||
else:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
reranking_model = (
|
||||
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
|
||||
)
|
||||
# get score threshold
|
||||
score_threshold = 0.0
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
results = []
|
||||
if dataset.provider == "external":
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = Document(
|
||||
page_content=external_document.get("content"),
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrieval_method = "keyword_search"
|
||||
else:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
reranking_model = (
|
||||
retrieval_model_config["reranking_model"]
|
||||
if retrieval_model_config["reranking_enable"]
|
||||
else None
|
||||
)
|
||||
# get score threshold
|
||||
score_threshold = 0.0
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
if results:
|
||||
|
@ -356,7 +404,8 @@ class DatasetRetrieval:
|
|||
self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
|
||||
) -> None:
|
||||
"""Handle retrieval end."""
|
||||
for document in documents:
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
@ -409,35 +458,54 @@ class DatasetRetrieval:
|
|||
if not dataset:
|
||||
return []
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
|
||||
if dataset.provider == "external":
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
for external_document in external_documents:
|
||||
document = Document(
|
||||
page_content=external_document.get("content"),
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
all_documents.append(document)
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
all_documents.extend(documents)
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
|
||||
def to_dataset_retriever_tool(
|
||||
self,
|
||||
|
|
|
@ -34,5 +34,9 @@
|
|||
- feishu_base
|
||||
- feishu_document
|
||||
- feishu_message
|
||||
- feishu_wiki
|
||||
- feishu_task
|
||||
- feishu_calendar
|
||||
- feishu_spreadsheet
|
||||
- slack
|
||||
- tianditu
|
||||
|
|
BIN
api/core/tools/provider/builtin/feishu_calendar/_assets/icon.png
Normal file
After Width: | Height: | Size: 5.4 KiB |
|
@ -0,0 +1,7 @@
|
|||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.utils.feishu_api_utils import auth
|
||||
|
||||
|
||||
class FeishuCalendarProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
auth(credentials)
|
|
@ -0,0 +1,36 @@
|
|||
identity:
|
||||
author: Doug Lea
|
||||
name: feishu_calendar
|
||||
label:
|
||||
en_US: Feishu Calendar
|
||||
zh_Hans: 飞书日历
|
||||
description:
|
||||
en_US: |
|
||||
Feishu calendar, requires the following permissions: calendar:calendar:read、calendar:calendar、contact:user.id:readonly.
|
||||
zh_Hans: |
|
||||
飞书日历,需要开通以下权限: calendar:calendar:read、calendar:calendar、contact:user.id:readonly。
|
||||
icon: icon.png
|
||||
tags:
|
||||
- social
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
app_id:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: APP ID
|
||||
placeholder:
|
||||
en_US: Please input your feishu app id
|
||||
zh_Hans: 请输入你的飞书 app id
|
||||
help:
|
||||
en_US: Get your app_id and app_secret from Feishu
|
||||
zh_Hans: 从飞书获取您的 app_id 和 app_secret
|
||||
url: https://open.larkoffice.com/app
|
||||
app_secret:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: APP Secret
|
||||
placeholder:
|
||||
en_US: Please input your app secret
|
||||
zh_Hans: 请输入你的飞书 app secret
|
|
@ -0,0 +1,20 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class AddEventAttendeesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
|
||||
res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)
|
||||
|
||||
return self.create_json_message(res)
|
|
@ -0,0 +1,54 @@
|
|||
identity:
|
||||
name: add_event_attendees
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Add Event Attendees
|
||||
zh_Hans: 添加日程参会人
|
||||
description:
|
||||
human:
|
||||
en_US: Add Event Attendees
|
||||
zh_Hans: 添加日程参会人
|
||||
llm: A tool for adding attendees to events in Feishu. (在飞书中添加日程参会人)
|
||||
parameters:
|
||||
- name: event_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Event ID
|
||||
zh_Hans: 日程 ID
|
||||
human_description:
|
||||
en_US: |
|
||||
The ID of the event, which will be returned when the event is created. For example: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0.
|
||||
zh_Hans: |
|
||||
创建日程时会返回日程 ID。例如: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0。
|
||||
llm_description: |
|
||||
日程 ID,创建日程时会返回日程 ID。例如: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0。
|
||||
form: llm
|
||||
|
||||
- name: need_notification
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
label:
|
||||
en_US: Need Notification
|
||||
zh_Hans: 是否需要通知
|
||||
human_description:
|
||||
en_US: |
|
||||
Whether to send a Bot notification to attendees. true: send, false: do not send.
|
||||
zh_Hans: |
|
||||
是否给参与人发送 Bot 通知,true: 发送,false: 不发送。
|
||||
llm_description: |
|
||||
是否给参与人发送 Bot 通知,true: 发送,false: 不发送。
|
||||
form: form
|
||||
|
||||
- name: attendee_phone_or_email
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Attendee Phone or Email
|
||||
zh_Hans: 参会人电话或邮箱
|
||||
human_description:
|
||||
en_US: The list of attendee emails or phone numbers, separated by commas.
|
||||
zh_Hans: 日程参会人邮箱或者手机号列表,使用逗号分隔。
|
||||
llm_description: 日程参会人邮箱或者手机号列表,使用逗号分隔。
|
||||
form: llm
|
|
@ -0,0 +1,26 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class CreateEventTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
summary = tool_parameters.get("summary")
|
||||
description = tool_parameters.get("description")
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
attendee_ability = tool_parameters.get("attendee_ability")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
auto_record = tool_parameters.get("auto_record", False)
|
||||
|
||||
res = client.create_event(
|
||||
summary, description, start_time, end_time, attendee_ability, need_notification, auto_record
|
||||
)
|
||||
|
||||
return self.create_json_message(res)
|
|
@ -0,0 +1,119 @@
|
|||
identity:
|
||||
name: create_event
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Create Event
|
||||
zh_Hans: 创建日程
|
||||
description:
|
||||
human:
|
||||
en_US: Create Event
|
||||
zh_Hans: 创建日程
|
||||
llm: A tool for creating events in Feishu.(创建飞书日程)
|
||||
parameters:
|
||||
- name: summary
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Summary
|
||||
zh_Hans: 日程标题
|
||||
human_description:
|
||||
en_US: The title of the event. If not filled, the event title will display (No Subject).
|
||||
zh_Hans: 日程标题,若不填则日程标题显示 (无主题)。
|
||||
llm_description: 日程标题,若不填则日程标题显示 (无主题)。
|
||||
form: llm
|
||||
|
||||
- name: description
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Description
|
||||
zh_Hans: 日程描述
|
||||
human_description:
|
||||
en_US: The description of the event.
|
||||
zh_Hans: 日程描述。
|
||||
llm_description: 日程描述。
|
||||
form: llm
|
||||
|
||||
- name: need_notification
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
label:
|
||||
en_US: Need Notification
|
||||
zh_Hans: 是否发送通知
|
||||
human_description:
|
||||
en_US: |
|
||||
Whether to send a bot message when the event is created, true: send, false: do not send.
|
||||
zh_Hans: 创建日程时是否发送 bot 消息,true:发送,false:不发送。
|
||||
llm_description: 创建日程时是否发送 bot 消息,true:发送,false:不发送。
|
||||
form: form
|
||||
|
||||
- name: start_time
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Start Time
|
||||
zh_Hans: 开始时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The start time of the event, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 日程开始时间,格式:2006-01-02 15:04:05。
|
||||
llm_description: 日程开始时间,格式:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: end_time
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: End Time
|
||||
zh_Hans: 结束时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The end time of the event, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 日程结束时间,格式:2006-01-02 15:04:05。
|
||||
llm_description: 日程结束时间,格式:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: attendee_ability
|
||||
type: select
|
||||
required: false
|
||||
options:
|
||||
- value: none
|
||||
label:
|
||||
en_US: none
|
||||
zh_Hans: 无
|
||||
- value: can_see_others
|
||||
label:
|
||||
en_US: can_see_others
|
||||
zh_Hans: 可以查看参与人列表
|
||||
- value: can_invite_others
|
||||
label:
|
||||
en_US: can_invite_others
|
||||
zh_Hans: 可以邀请其它参与人
|
||||
- value: can_modify_event
|
||||
label:
|
||||
en_US: can_modify_event
|
||||
zh_Hans: 可以编辑日程
|
||||
default: "none"
|
||||
label:
|
||||
en_US: attendee_ability
|
||||
zh_Hans: 参会人权限
|
||||
human_description:
|
||||
en_US: Attendee ability, optional values are none, can_see_others, can_invite_others, can_modify_event, with a default value of none.
|
||||
zh_Hans: 参会人权限,可选值有无、可以查看参与人列表、可以邀请其它参与人、可以编辑日程,默认值为无。
|
||||
llm_description: 参会人权限,可选值有无、可以查看参与人列表、可以邀请其它参与人、可以编辑日程,默认值为无。
|
||||
form: form
|
||||
|
||||
- name: auto_record
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
label:
|
||||
en_US: Auto Record
|
||||
zh_Hans: 自动录制
|
||||
human_description:
|
||||
en_US: |
|
||||
Whether to enable automatic recording, true: enabled, automatically record when the meeting starts; false: not enabled.
|
||||
zh_Hans: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。
|
||||
llm_description: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。
|
||||
form: form
|
|
@ -0,0 +1,19 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class DeleteEventTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
|
||||
res = client.delete_event(event_id, need_notification)
|
||||
|
||||
return self.create_json_message(res)
|
|
@ -0,0 +1,38 @@
|
|||
identity:
|
||||
name: delete_event
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Delete Event
|
||||
zh_Hans: 删除日程
|
||||
description:
|
||||
human:
|
||||
en_US: Delete Event
|
||||
zh_Hans: 删除日程
|
||||
llm: A tool for deleting events in Feishu.(在飞书中删除日程)
|
||||
parameters:
|
||||
- name: event_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Event ID
|
||||
zh_Hans: 日程 ID
|
||||
human_description:
|
||||
en_US: |
|
||||
The ID of the event, for example: e8b9791c-39ae-4908-8ad8-66b13159b9fb_0.
|
||||
zh_Hans: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。
|
||||
llm_description: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。
|
||||
form: llm
|
||||
|
||||
- name: need_notification
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
label:
|
||||
en_US: Need Notification
|
||||
zh_Hans: 是否需要通知
|
||||
human_description:
|
||||
en_US: |
|
||||
Indicates whether to send bot notifications to event participants upon deletion. true: send, false: do not send.
|
||||
zh_Hans: 删除日程是否给日程参与人发送 bot 通知,true:发送,false:不发送。
|
||||
llm_description: 删除日程是否给日程参与人发送 bot 通知,true:发送,false:不发送。
|
||||
form: form
|
|
@ -0,0 +1,18 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class GetPrimaryCalendarTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
|
||||
res = client.get_primary_calendar(user_id_type)
|
||||
|
||||
return self.create_json_message(res)
|
|
@ -0,0 +1,37 @@
|
|||
identity:
|
||||
name: get_primary_calendar
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Get Primary Calendar
|
||||
zh_Hans: 查询主日历信息
|
||||
description:
|
||||
human:
|
||||
en_US: Get Primary Calendar
|
||||
zh_Hans: 查询主日历信息
|
||||
llm: A tool for querying primary calendar information in Feishu.(在飞书中查询主日历信息)
|
||||
parameters:
|
||||
- name: user_id_type
|
||||
type: select
|
||||
required: false
|
||||
options:
|
||||
- value: open_id
|
||||
label:
|
||||
en_US: open_id
|
||||
zh_Hans: open_id
|
||||
- value: union_id
|
||||
label:
|
||||
en_US: union_id
|
||||
zh_Hans: union_id
|
||||
- value: user_id
|
||||
label:
|
||||
en_US: user_id
|
||||
zh_Hans: user_id
|
||||
default: "open_id"
|
||||
label:
|
||||
en_US: user_id_type
|
||||
zh_Hans: 用户 ID 类型
|
||||
human_description:
|
||||
en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id.
|
||||
zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
form: form
|
|
@ -0,0 +1,21 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class ListEventsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
page_token = tool_parameters.get("page_token")
|
||||
page_size = tool_parameters.get("page_size")
|
||||
|
||||
res = client.list_events(start_time, end_time, page_token, page_size)
|
||||
|
||||
return self.create_json_message(res)
|
|
@ -0,0 +1,62 @@
|
|||
identity:
|
||||
name: list_events
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: List Events
|
||||
zh_Hans: 获取日程列表
|
||||
description:
|
||||
human:
|
||||
en_US: List Events
|
||||
zh_Hans: 获取日程列表
|
||||
llm: A tool for listing events in Feishu.(在飞书中获取日程列表)
|
||||
parameters:
|
||||
- name: start_time
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Start Time
|
||||
zh_Hans: 开始时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The start time, defaults to 0:00 of the current day if not provided, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。
|
||||
llm_description: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: end_time
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: End Time
|
||||
zh_Hans: 结束时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The end time, defaults to 23:59 of the current day if not provided, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。
|
||||
llm_description: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: page_size
|
||||
type: number
|
||||
required: false
|
||||
default: 50
|
||||
label:
|
||||
en_US: Page Size
|
||||
zh_Hans: 分页大小
|
||||
human_description:
|
||||
en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 50, and the value range is [50,1000].
|
||||
zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。
|
||||
llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。
|
||||
form: llm
|
||||
|
||||
- name: page_token
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Page Token
|
||||
zh_Hans: 分页标记
|
||||
human_description:
|
||||
en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal.
|
||||
zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。
|
||||
llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。
|
||||
form: llm
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class SearchEventsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
query = tool_parameters.get("query")
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
page_token = tool_parameters.get("page_token")
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
page_size = tool_parameters.get("page_size", 20)
|
||||
|
||||
res = client.search_events(query, start_time, end_time, page_token, user_id_type, page_size)
|
||||
|
||||
return self.create_json_message(res)
|
|
@ -0,0 +1,100 @@
|
|||
identity:
|
||||
name: search_events
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Search Events
|
||||
zh_Hans: 搜索日程
|
||||
description:
|
||||
human:
|
||||
en_US: Search Events
|
||||
zh_Hans: 搜索日程
|
||||
llm: A tool for searching events in Feishu.(在飞书中搜索日程)
|
||||
parameters:
|
||||
- name: user_id_type
|
||||
type: select
|
||||
required: false
|
||||
options:
|
||||
- value: open_id
|
||||
label:
|
||||
en_US: open_id
|
||||
zh_Hans: open_id
|
||||
- value: union_id
|
||||
label:
|
||||
en_US: union_id
|
||||
zh_Hans: union_id
|
||||
- value: user_id
|
||||
label:
|
||||
en_US: user_id
|
||||
zh_Hans: user_id
|
||||
default: "open_id"
|
||||
label:
|
||||
en_US: user_id_type
|
||||
zh_Hans: 用户 ID 类型
|
||||
human_description:
|
||||
en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id.
|
||||
zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
form: form
|
||||
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query
|
||||
zh_Hans: 搜索关键字
|
||||
human_description:
|
||||
en_US: The search keyword used for fuzzy searching event names, with a maximum input of 200 characters.
|
||||
zh_Hans: 用于模糊查询日程名称的搜索关键字,最大输入 200 字符。
|
||||
llm_description: 用于模糊查询日程名称的搜索关键字,最大输入 200 字符。
|
||||
form: llm
|
||||
|
||||
- name: start_time
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Start Time
|
||||
zh_Hans: 开始时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The start time, defaults to 0:00 of the current day if not provided, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。
|
||||
llm_description: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: end_time
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: End Time
|
||||
zh_Hans: 结束时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The end time, defaults to 23:59 of the current day if not provided, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。
|
||||
llm_description: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: page_size
|
||||
type: number
|
||||
required: false
|
||||
default: 20
|
||||
label:
|
||||
en_US: Page Size
|
||||
zh_Hans: 分页大小
|
||||
human_description:
|
||||
en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [10,100].
|
||||
zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。
|
||||
llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。
|
||||
form: llm
|
||||
|
||||
- name: page_token
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Page Token
|
||||
zh_Hans: 分页标记
|
||||
human_description:
|
||||
en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal.
|
||||
zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。
|
||||
llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。
|
||||
form: llm
|
|
@ -0,0 +1,24 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class UpdateEventTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
summary = tool_parameters.get("summary")
|
||||
description = tool_parameters.get("description")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
auto_record = tool_parameters.get("auto_record", False)
|
||||
|
||||
res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)
|
||||
|
||||
return self.create_json_message(res)
|
|
@ -0,0 +1,100 @@
|
|||
identity:
|
||||
name: update_event
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Update Event
|
||||
zh_Hans: 更新日程
|
||||
description:
|
||||
human:
|
||||
en_US: Update Event
|
||||
zh_Hans: 更新日程
|
||||
llm: A tool for updating events in Feishu.(更新飞书中的日程)
|
||||
parameters:
|
||||
- name: event_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Event ID
|
||||
zh_Hans: 日程 ID
|
||||
human_description:
|
||||
en_US: |
|
||||
The ID of the event, for example: e8b9791c-39ae-4908-8ad8-66b13159b9fb_0.
|
||||
zh_Hans: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。
|
||||
llm_description: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。
|
||||
form: llm
|
||||
|
||||
- name: summary
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Summary
|
||||
zh_Hans: 日程标题
|
||||
human_description:
|
||||
en_US: The title of the event.
|
||||
zh_Hans: 日程标题。
|
||||
llm_description: 日程标题。
|
||||
form: llm
|
||||
|
||||
- name: description
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Description
|
||||
zh_Hans: 日程描述
|
||||
human_description:
|
||||
en_US: The description of the event.
|
||||
zh_Hans: 日程描述。
|
||||
llm_description: 日程描述。
|
||||
form: llm
|
||||
|
||||
- name: need_notification
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Need Notification
|
||||
zh_Hans: 是否发送通知
|
||||
human_description:
|
||||
en_US: |
|
||||
Whether to send a bot message when the event is updated, true: send, false: do not send.
|
||||
zh_Hans: 更新日程时是否发送 bot 消息,true:发送,false:不发送。
|
||||
llm_description: 更新日程时是否发送 bot 消息,true:发送,false:不发送。
|
||||
form: form
|
||||
|
||||
- name: start_time
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Start Time
|
||||
zh_Hans: 开始时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The start time of the event, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 日程开始时间,格式:2006-01-02 15:04:05。
|
||||
llm_description: 日程开始时间,格式:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: end_time
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: End Time
|
||||
zh_Hans: 结束时间
|
||||
human_description:
|
||||
en_US: |
|
||||
The end time of the event, format: 2006-01-02 15:04:05.
|
||||
zh_Hans: 日程结束时间,格式:2006-01-02 15:04:05。
|
||||
llm_description: 日程结束时间,格式:2006-01-02 15:04:05。
|
||||
form: llm
|
||||
|
||||
- name: auto_record
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Auto Record
|
||||
zh_Hans: 自动录制
|
||||
human_description:
|
||||
en_US: |
|
||||
Whether to enable automatic recording, true: enabled, automatically record when the meeting starts; false: not enabled.
|
||||
zh_Hans: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。
|
||||
llm_description: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。
|
||||
form: form
|
|
@ -1,15 +1,7 @@
|
|||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
from core.tools.utils.feishu_api_utils import auth
|
||||
|
||||
|
||||
class FeishuDocumentProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
app_id = credentials.get("app_id")
|
||||
app_secret = credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
|
||||
try:
|
||||
assert FeishuRequest(app_id, app_secret).tenant_access_token is not None
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
auth(credentials)
|
||||
|
|
|
@ -5,8 +5,10 @@ identity:
|
|||
en_US: Lark Cloud Document
|
||||
zh_Hans: 飞书云文档
|
||||
description:
|
||||
en_US: Lark Cloud Document
|
||||
zh_Hans: 飞书云文档
|
||||
en_US: |
|
||||
Lark cloud document, requires the following permissions: docx:document、drive:drive、docs:document.content:read.
|
||||
zh_Hans: |
|
||||
飞书云文档,需要开通以下权限: docx:document、drive:drive、docs:document.content:read。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- social
|
||||
|
@ -23,7 +25,7 @@ credentials_for_provider:
|
|||
help:
|
||||
en_US: Get your app_id and app_secret from Feishu
|
||||
zh_Hans: 从飞书获取您的 app_id 和 app_secret
|
||||
url: https://open.feishu.cn
|
||||
url: https://open.larkoffice.com/app
|
||||
app_secret:
|
||||
type: secret-input
|
||||
required: true
|
||||
|
|
|
@ -7,7 +7,7 @@ identity:
|
|||
description:
|
||||
human:
|
||||
en_US: Create Lark document
|
||||
zh_Hans: 创建飞书文档,支持创建空文档和带内容的文档,支持 markdown 语法创建。
|
||||
zh_Hans: 创建飞书文档,支持创建空文档和带内容的文档,支持 markdown 语法创建。应用需要开启机器人能力(https://open.feishu.cn/document/faq/trouble-shooting/how-to-enable-bot-ability)。
|
||||
llm: A tool for creating Feishu documents.
|
||||
parameters:
|
||||
- name: title
|
||||
|
@ -41,7 +41,8 @@ parameters:
|
|||
en_US: folder_token
|
||||
zh_Hans: 文档所在文件夹的 Token
|
||||
human_description:
|
||||
en_US: The token of the folder where the document is located. If it is not passed or is empty, it means the root directory.
|
||||
zh_Hans: 文档所在文件夹的 Token,不传或传空表示根目录。
|
||||
llm_description: 文档所在文件夹的 Token,不传或传空表示根目录。
|
||||
en_US: |
|
||||
The token of the folder where the document is located. If it is not passed or is empty, it means the root directory. For Example: https://svi136aogf123.feishu.cn/drive/folder/JgR9fiG9AlPt8EdsSNpcGjIInbf
|
||||
zh_Hans: 文档所在文件夹的 Token,不传或传空表示根目录。例如:https://svi136aogf123.feishu.cn/drive/folder/JgR9fiG9AlPt8EdsSNpcGjIInbf。
|
||||
llm_description: 文档所在文件夹的 Token,不传或传空表示根目录。例如:https://svi136aogf123.feishu.cn/drive/folder/JgR9fiG9AlPt8EdsSNpcGjIInbf。
|
||||
form: llm
|
||||
|
|
|
@ -12,8 +12,8 @@ class GetDocumentRawContentTool(BuiltinTool):
|
|||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
document_id = tool_parameters.get("document_id")
|
||||
mode = tool_parameters.get("mode")
|
||||
lang = tool_parameters.get("lang", 0)
|
||||
mode = tool_parameters.get("mode", "markdown")
|
||||
lang = tool_parameters.get("lang", "0")
|
||||
|
||||
res = client.get_document_content(document_id, mode, lang)
|
||||
return self.create_json_message(res)
|
||||
|
|
|
@ -23,8 +23,18 @@ parameters:
|
|||
form: llm
|
||||
|
||||
- name: mode
|
||||
type: string
|
||||
type: select
|
||||
required: false
|
||||
options:
|
||||
- value: text
|
||||
label:
|
||||
en_US: text
|
||||
zh_Hans: text
|
||||
- value: markdown
|
||||
label:
|
||||
en_US: markdown
|
||||
zh_Hans: markdown
|
||||
default: "markdown"
|
||||
label:
|
||||
en_US: mode
|
||||
zh_Hans: 文档返回格式
|
||||
|
@ -32,18 +42,29 @@ parameters:
|
|||
en_US: Format of the document return, optional values are text, markdown, can be empty, default is markdown.
|
||||
zh_Hans: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。
|
||||
llm_description: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。
|
||||
form: llm
|
||||
form: form
|
||||
|
||||
- name: lang
|
||||
type: number
|
||||
type: select
|
||||
required: false
|
||||
default: 0
|
||||
options:
|
||||
- value: "0"
|
||||
label:
|
||||
en_US: User's default name
|
||||
zh_Hans: 用户的默认名称
|
||||
- value: "1"
|
||||
label:
|
||||
en_US: User's English name
|
||||
zh_Hans: 用户的英文名称
|
||||
default: "0"
|
||||
label:
|
||||
en_US: lang
|
||||
zh_Hans: 指定@用户的语言
|
||||
human_description:
|
||||
en_US: |
|
||||
Specifies the language for MentionUser, optional values are [0, 1]. 0: User's default name, 1: User's English name, default is 0.
|
||||
zh_Hans: 指定返回的 MentionUser,即 @用户 的语言,可选值有 [0,1]。0:该用户的默认名称,1:该用户的英文名称,默认值为 0。
|
||||
llm_description: 指定返回的 MentionUser,即 @用户 的语言,可选值有 [0,1]。0:该用户的默认名称,1:该用户的英文名称,默认值为 0。
|
||||
form: llm
|
||||
zh_Hans: |
|
||||
指定返回的 MentionUser,即@用户的语言,可选值有 [0,1]。0: 该用户的默认名称,1: 该用户的英文名称,默认值为 0。
|
||||
llm_description: |
|
||||
指定返回的 MentionUser,即@用户的语言,可选值有 [0,1]。0: 该用户的默认名称,1: 该用户的英文名称,默认值为 0。
|
||||
form: form
|
||||
|
|
|
@ -12,8 +12,9 @@ class ListDocumentBlockTool(BuiltinTool):
|
|||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
document_id = tool_parameters.get("document_id")
|
||||
page_size = tool_parameters.get("page_size", 500)
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
page_size = tool_parameters.get("page_size", 500)
|
||||
|
||||
res = client.list_document_blocks(document_id, page_token, page_size)
|
||||
res = client.list_document_blocks(document_id, page_token, user_id_type, page_size)
|
||||
return self.create_json_message(res)
|
||||
|
|
|
@ -46,12 +46,12 @@ parameters:
|
|||
en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id.
|
||||
zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
form: llm
|
||||
form: form
|
||||
|
||||
- name: page_size
|
||||
type: number
|
||||
required: false
|
||||
default: "500"
|
||||
default: 500
|
||||
label:
|
||||
en_US: page_size
|
||||
zh_Hans: 分页大小
|
||||
|
|
|
@ -13,7 +13,7 @@ class CreateDocumentTool(BuiltinTool):
|
|||
|
||||
document_id = tool_parameters.get("document_id")
|
||||
content = tool_parameters.get("content")
|
||||
position = tool_parameters.get("position")
|
||||
position = tool_parameters.get("position", "end")
|
||||
|
||||
res = client.write_document(document_id, content, position)
|
||||
return self.create_json_message(res)
|
||||
|
|