mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Merge branch 'main' into feat/new-login
* main: (35 commits) fix https://github.com/langgenius/dify/issues/9409 (#9433) update dataset clean rule (#9426) add clean 7 days datasets (#9424) fix: resolve overlap issue with API Extension selector and modal (#9407) refactor: update the default values of top-k parameter in vdb to be consistent (#9367) fix: incorrect webapp image displayed (#9401) Fix/economical knowledge retrieval (#9396) feat: add timezone conversion for time tool (#9393) fix: Deprecated gemma2-9b model in Fireworks AI Provider (#9373) feat: storybook (#9324) fix: use gpt-4o-mini for validating credentials (#9387) feat: Enable baiduvector intergration test (#9369) fix: remove the stream option of zhipu and gemini (#9319) fix: add missing vikingdb param in docker .env.example (#9334) feat: add minimax abab6.5t support (#9365) fix: (#9336 followup) skip poetry preperation in style workflow when no change in api folder (#9362) feat: add glm-4-flashx, deprecated chatglm_turbo (#9357) fix: Azure OpenAI o1 max_completion_token and get_num_token_from_messages error (#9326) fix: In the output, the order of 'ta' is sometimes reversed as 'at'. #8015 (#8791) refactor: Add an enumeration type and use the factory pattern to obtain the corresponding class (#9356) ...
This commit is contained in:
commit
abe6a1bb99
7
.github/workflows/api-tests.yml
vendored
7
.github/workflows/api-tests.yml
vendored
|
@ -27,18 +27,17 @@ jobs:
|
|||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache-dependency-path: |
|
||||
api/pyproject.toml
|
||||
api/poetry.lock
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Check Poetry lockfile
|
||||
run: |
|
||||
poetry check -C api --lock
|
||||
|
|
7
.github/workflows/db-migration-test.yml
vendored
7
.github/workflows/db-migration-test.yml
vendored
|
@ -23,18 +23,17 @@ jobs:
|
|||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache-dependency-path: |
|
||||
api/pyproject.toml
|
||||
api/poetry.lock
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install -C api
|
||||
|
||||
|
|
7
.github/workflows/style.yml
vendored
7
.github/workflows/style.yml
vendored
|
@ -24,15 +24,16 @@ jobs:
|
|||
with:
|
||||
files: api/**
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Python dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: poetry install -C api --only lint
|
||||
|
|
|
@ -85,3 +85,4 @@
|
|||
cd ../
|
||||
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
||||
```
|
||||
|
||||
|
|
|
@ -531,11 +531,16 @@ class DataSetConfig(BaseSettings):
|
|||
Configuration for dataset management
|
||||
"""
|
||||
|
||||
CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="Interval in days for dataset cleanup operations",
|
||||
PLAN_SANDBOX_CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="Interval in days for dataset cleanup operations - plan: sandbox",
|
||||
default=30,
|
||||
)
|
||||
|
||||
PLAN_PRO_CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="Interval in days for dataset cleanup operations - plan: pro and team",
|
||||
default=7,
|
||||
)
|
||||
|
||||
DATASET_OPERATOR_ENABLED: bool = Field(
|
||||
description="Enable or disable dataset operator functionality",
|
||||
default=False,
|
||||
|
|
|
@ -14,7 +14,7 @@ class OracleConfig(BaseSettings):
|
|||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_PORT: Optional[PositiveInt] = Field(
|
||||
ORACLE_PORT: PositiveInt = Field(
|
||||
description="Port number on which the Oracle database server is listening (default is 1521)",
|
||||
default=1521,
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings):
|
|||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_PORT: Optional[PositiveInt] = Field(
|
||||
PGVECTOR_PORT: PositiveInt = Field(
|
||||
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
||||
default=5433,
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings):
|
|||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
|
||||
PGVECTO_RS_PORT: PositiveInt = Field(
|
||||
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
||||
default=5431,
|
||||
)
|
||||
|
|
|
@ -11,27 +11,39 @@ class VikingDBConfig(BaseModel):
|
|||
"""
|
||||
|
||||
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
||||
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
|
||||
description="The Access Key provided by Volcengine VikingDB for API authentication."
|
||||
"Refer to the following documentation for details on obtaining credentials:"
|
||||
"https://www.volcengine.com/docs/6291/65568",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
||||
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
|
||||
description="The Secret Key provided by Volcengine VikingDB for API authentication.",
|
||||
default=None,
|
||||
)
|
||||
VIKINGDB_REGION: Optional[str] = Field(
|
||||
default="cn-shanghai",
|
||||
|
||||
VIKINGDB_REGION: str = Field(
|
||||
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
||||
default="cn-shanghai",
|
||||
)
|
||||
VIKINGDB_HOST: Optional[str] = Field(
|
||||
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
||||
|
||||
VIKINGDB_HOST: str = Field(
|
||||
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
||||
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
||||
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
||||
)
|
||||
VIKINGDB_SCHEME: Optional[str] = Field(
|
||||
default="http",
|
||||
|
||||
VIKINGDB_SCHEME: str = Field(
|
||||
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
||||
default="http",
|
||||
)
|
||||
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
|
||||
default=30, description="The connection timeout of the Volcengine VikingDB service."
|
||||
|
||||
VIKINGDB_CONNECTION_TIMEOUT: int = Field(
|
||||
description="The connection timeout of the Volcengine VikingDB service.",
|
||||
default=30,
|
||||
)
|
||||
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
|
||||
default=30, description="The socket timeout of the Volcengine VikingDB service."
|
||||
|
||||
VIKINGDB_SOCKET_TIMEOUT: int = Field(
|
||||
description="The socket timeout of the Volcengine VikingDB service.",
|
||||
default=30,
|
||||
)
|
||||
|
|
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
|||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.9.1",
|
||||
default="0.9.2",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
|
|
@ -1,88 +1,24 @@
|
|||
import logging
|
||||
from flask_restful import Resource
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class HitTestingApi(Resource):
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
|
|
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
|
@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
|
|||
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import index
|
||||
from .app import app, audio, completion, conversation, file, message, workflow
|
||||
from .dataset import dataset, document, segment
|
||||
from .dataset import dataset, document, hit_testing, segment
|
||||
|
|
|
@ -4,7 +4,6 @@ from flask_restful import Resource, reqparse
|
|||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from constants import UUID_NIL
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
|
@ -108,7 +107,6 @@ class ChatApi(Resource):
|
|||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
|
||||
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
|
@ -62,6 +62,8 @@ class CotAgentOutputParser:
|
|||
thought_str = "thought:"
|
||||
thought_idx = 0
|
||||
|
||||
last_character = ""
|
||||
|
||||
for response in llm_response:
|
||||
if response.delta.usage:
|
||||
usage_dict["usage"] = response.delta.usage
|
||||
|
@ -74,35 +76,38 @@ class CotAgentOutputParser:
|
|||
while index < len(response):
|
||||
steps = 1
|
||||
delta = response[index : index + steps]
|
||||
last_character = response[index - 1] if index > 0 else ""
|
||||
yield_delta = False
|
||||
|
||||
if delta == "`":
|
||||
last_character = delta
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count += 1
|
||||
else:
|
||||
if not in_code_block:
|
||||
if code_block_delimiter_count > 0:
|
||||
last_character = delta
|
||||
yield code_block_cache
|
||||
code_block_cache = ""
|
||||
else:
|
||||
last_character = delta
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
yield_delta = True
|
||||
else:
|
||||
last_character = delta
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
elif delta.lower() == action_str[action_idx] and action_idx > 0:
|
||||
last_character = delta
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
|
@ -112,24 +117,25 @@ class CotAgentOutputParser:
|
|||
continue
|
||||
else:
|
||||
if action_cache:
|
||||
last_character = delta
|
||||
yield action_cache
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
|
||||
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
yield_delta = True
|
||||
else:
|
||||
last_character = delta
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
|
||||
last_character = delta
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
|
@ -139,12 +145,20 @@ class CotAgentOutputParser:
|
|||
continue
|
||||
else:
|
||||
if thought_cache:
|
||||
last_character = delta
|
||||
yield thought_cache
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
|
||||
if yield_delta:
|
||||
index += steps
|
||||
last_character = delta
|
||||
yield delta
|
||||
continue
|
||||
|
||||
if code_block_delimiter_count == 3:
|
||||
if in_code_block:
|
||||
last_character = delta
|
||||
yield from extra_json_from_code_block(code_block_cache)
|
||||
code_block_cache = ""
|
||||
|
||||
|
@ -156,8 +170,10 @@ class CotAgentOutputParser:
|
|||
if delta == "{":
|
||||
json_quote_count += 1
|
||||
in_json = True
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
elif delta == "}":
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
if json_quote_count > 0:
|
||||
json_quote_count -= 1
|
||||
|
@ -168,16 +184,19 @@ class CotAgentOutputParser:
|
|||
continue
|
||||
else:
|
||||
if in_json:
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
|
||||
if got_json:
|
||||
got_json = False
|
||||
last_character = delta
|
||||
yield parse_action(json_cache)
|
||||
json_cache = ""
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
last_character = delta
|
||||
yield delta.replace("`", "")
|
||||
|
||||
index += steps
|
||||
|
|
|
@ -10,6 +10,7 @@ from flask import Flask, current_app
|
|||
from pydantic import ValidationError
|
||||
|
||||
import contexts
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
|
@ -122,7 +123,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
|||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
|
@ -127,7 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
|||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
|
@ -128,7 +129,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
|
|
@ -2,8 +2,9 @@ from collections.abc import Mapping
|
|||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file.file_obj import FileVar
|
||||
|
@ -116,13 +117,36 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
|||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
Base entity for conversation-based app generation.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
|
||||
"For service API, we need to ensure its forward compatibility, "
|
||||
"so passing in the parent_message_id as request arg is not supported for now. "
|
||||
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("parent_message_id")
|
||||
@classmethod
|
||||
def validate_parent_message_id(cls, v, info: ValidationInfo):
|
||||
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
|
||||
raise ValueError("parent_message_id should be UUID_NIL for service API")
|
||||
return v
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
pass
|
||||
|
||||
|
||||
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
|
@ -133,16 +157,15 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
|||
pass
|
||||
|
||||
|
||||
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
Agent Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
pass
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
Advanced Chat Application Generate Entity.
|
||||
"""
|
||||
|
@ -150,8 +173,6 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
|||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
workflow_run_id: Optional[str] = None
|
||||
query: str
|
||||
|
||||
|
|
|
@ -1098,6 +1098,14 @@ LLM_BASE_MODELS = [
|
|||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||
|
@ -1135,6 +1143,14 @@ LLM_BASE_MODELS = [
|
|||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||
|
|
|
@ -119,7 +119,15 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
try:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
if model.startswith("o1"):
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
model=model,
|
||||
temperature=1,
|
||||
max_completion_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
|
|
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00025'
|
||||
output: '0.00125'
|
||||
|
|
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
|
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
|
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00025'
|
||||
output: '0.00125'
|
||||
|
|
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.075'
|
||||
|
|
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
|
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
|
|
@ -18,6 +18,7 @@ supported_model_types:
|
|||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: fireworks_api_key
|
||||
|
@ -28,3 +29,75 @@ provider_credential_schema:
|
|||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model URL
|
||||
zh_Hans: 模型URL
|
||||
placeholder:
|
||||
en_US: Enter your Model URL
|
||||
zh_Hans: 输入模型URL
|
||||
credential_form_schemas:
|
||||
- variable: model_label_zh_Hanns
|
||||
label:
|
||||
zh_Hans: 模型中文名称
|
||||
en_US: The zh_Hans of Model
|
||||
required: true
|
||||
type: text-input
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型中文名称
|
||||
en_US: Enter your zh_Hans of Model
|
||||
- variable: model_label_en_US
|
||||
label:
|
||||
zh_Hans: 模型英文名称
|
||||
en_US: The en_US of Model
|
||||
required: true
|
||||
type: text-input
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型英文名称
|
||||
en_US: Enter your en_US of Model
|
||||
- variable: fireworks_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 模型上下文长度
|
||||
en_US: Model context size
|
||||
required: true
|
||||
type: text-input
|
||||
default: '4096'
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
- variable: max_tokens
|
||||
label:
|
||||
zh_Hans: 最大 token 上限
|
||||
en_US: Upper bound for max tokens
|
||||
default: '4096'
|
||||
type: text-input
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- variable: function_calling_type
|
||||
label:
|
||||
en_US: Function calling
|
||||
type: select
|
||||
required: false
|
||||
default: no_call
|
||||
options:
|
||||
- value: no_call
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- value: function_call
|
||||
label:
|
||||
en_US: Support
|
||||
zh_Hans: 支持
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
|
|
|
@ -43,3 +43,4 @@ pricing:
|
|||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
|
|
|
@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
|
|||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
|
@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
|
|||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
||||
|
@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
|||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=credentials.get("model_label_en_US", model),
|
||||
zh_Hans=credentials.get("model_label_zh_Hanns", model),
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get("function_calling_type") == "function_call"
|
||||
else [],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
use_template="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="max_tokens",
|
||||
use_template="max_tokens",
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
type=ParameterType.INT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
use_template="top_p",
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
use_template="top_k",
|
||||
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
model: accounts/fireworks/models/qwen2p5-72b-instruct
|
||||
label:
|
||||
zh_Hans: Qwen2.5 72B Instruct
|
||||
en_US: Qwen2.5 72B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.9'
|
||||
output: '0.9'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -27,15 +27,6 @@ parameter_rules:
|
|||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -31,15 +31,6 @@ parameter_rules:
|
|||
max: 2048
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
model: abab6.5t-chat
|
||||
label:
|
||||
en_US: Abab6.5t-Chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.9
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.95
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 3072
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: mask_sensitive_info
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
zh_Hans: 隐私保护
|
||||
en_US: Moderate
|
||||
help:
|
||||
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
|
||||
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '0.005'
|
||||
output: '0.005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
|
@ -61,7 +61,8 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
|||
url = f"{self.api_base}?GroupId={group_id}"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": "embo-01", "texts": texts, "type": "db"}
|
||||
embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query"
|
||||
data = {"model": "embo-01", "texts": texts, "type": embedding_type}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
|
|
|
@ -19,9 +19,9 @@ class OpenAIProvider(ModelProvider):
|
|||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `gpt-3.5-turbo` model for validate,
|
||||
# Use `gpt-4o-mini` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials)
|
||||
model_instance.validate_credentials(model="gpt-4o-mini", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
|
|
|
@ -28,15 +28,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: return_type
|
||||
label:
|
||||
zh_Hans: 回复类型
|
||||
|
@ -49,3 +40,4 @@ parameter_rules:
|
|||
options:
|
||||
- text
|
||||
- json_string
|
||||
deprecated: true
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
model: glm-4-flashx
|
||||
label:
|
||||
en_US: glm-4-flashx
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.95
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.7
|
||||
help:
|
||||
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: do_sample
|
||||
label:
|
||||
zh_Hans: 采样策略
|
||||
en_US: Sampling strategy
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 4095
|
||||
- name: web_search
|
||||
type: boolean
|
||||
label:
|
||||
zh_Hans: 联网搜索
|
||||
en_US: Web Search
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
pricing:
|
||||
input: '0'
|
||||
output: '0'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -35,15 +35,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -30,15 +30,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -30,15 +30,6 @@ parameter_rules:
|
|||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai.types.chat.chat_completion import Completion
|
||||
from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
|
@ -16,9 +20,6 @@ from core.model_runtime.entities.message_entities import (
|
|||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
|
||||
|
||||
class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
from .__version__ import __version__
|
||||
from ._client import ZhipuAI
|
||||
from .core import (
|
||||
APIAuthenticationError,
|
||||
APIConnectionError,
|
||||
APIInternalError,
|
||||
APIReachLimitError,
|
||||
APIRequestFailedError,
|
||||
APIResponseError,
|
||||
APIResponseValidationError,
|
||||
APIServerFlowExceedError,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
ZhipuAIError,
|
||||
)
|
|
@ -1 +0,0 @@
|
|||
__version__ = "v2.1.0"
|
|
@ -1,82 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
from httpx import Timeout
|
||||
from typing_extensions import override
|
||||
|
||||
from . import api_resource
|
||||
from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token
|
||||
|
||||
|
||||
class ZhipuAI(HttpClient):
|
||||
chat: api_resource.chat.Chat
|
||||
api_key: str
|
||||
_disable_token_cache: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
|
||||
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
|
||||
http_client: httpx.Client | None = None,
|
||||
custom_headers: Mapping[str, str] | None = None,
|
||||
disable_token_cache: bool = True,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None:
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("ZHIPUAI_API_KEY")
|
||||
if api_key is None:
|
||||
raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
|
||||
self.api_key = api_key
|
||||
self._disable_token_cache = disable_token_cache
|
||||
|
||||
if base_url is None:
|
||||
base_url = os.environ.get("ZHIPUAI_BASE_URL")
|
||||
if base_url is None:
|
||||
base_url = "https://open.bigmodel.cn/api/paas/v4"
|
||||
from .__version__ import __version__
|
||||
|
||||
super().__init__(
|
||||
version=__version__,
|
||||
base_url=base_url,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
custom_httpx_client=http_client,
|
||||
custom_headers=custom_headers,
|
||||
_strict_response_validation=_strict_response_validation,
|
||||
)
|
||||
self.chat = api_resource.chat.Chat(self)
|
||||
self.images = api_resource.images.Images(self)
|
||||
self.embeddings = api_resource.embeddings.Embeddings(self)
|
||||
self.files = api_resource.files.Files(self)
|
||||
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
|
||||
self.batches = api_resource.Batches(self)
|
||||
self.knowledge = api_resource.Knowledge(self)
|
||||
self.tools = api_resource.Tools(self)
|
||||
self.videos = api_resource.Videos(self)
|
||||
self.assistant = api_resource.Assistant(self)
|
||||
|
||||
@property
|
||||
@override
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
api_key = self.api_key
|
||||
if self._disable_token_cache:
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
else:
|
||||
return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):
|
||||
# if the '__init__' method raised an error, self would not have client attr
|
||||
return
|
||||
|
||||
if self._has_custom_http_client:
|
||||
return
|
||||
|
||||
self.close()
|
|
@ -1,34 +0,0 @@
|
|||
from .assistant import (
|
||||
Assistant,
|
||||
)
|
||||
from .batches import Batches
|
||||
from .chat import (
|
||||
AsyncCompletions,
|
||||
Chat,
|
||||
Completions,
|
||||
)
|
||||
from .embeddings import Embeddings
|
||||
from .files import Files, FilesWithRawResponse
|
||||
from .fine_tuning import FineTuning
|
||||
from .images import Images
|
||||
from .knowledge import Knowledge
|
||||
from .tools import Tools
|
||||
from .videos import (
|
||||
Videos,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Videos",
|
||||
"AsyncCompletions",
|
||||
"Chat",
|
||||
"Completions",
|
||||
"Images",
|
||||
"Embeddings",
|
||||
"Files",
|
||||
"FilesWithRawResponse",
|
||||
"FineTuning",
|
||||
"Batches",
|
||||
"Knowledge",
|
||||
"Tools",
|
||||
"Assistant",
|
||||
]
|
|
@ -1,3 +0,0 @@
|
|||
from .assistant import Assistant
|
||||
|
||||
__all__ = ["Assistant"]
|
|
@ -1,122 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
StreamResponse,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.assistant import AssistantCompletion
|
||||
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
|
||||
from ...types.assistant.assistant_support_resp import AssistantSupportResp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
from ...types.assistant import assistant_conversation_params, assistant_create_params
|
||||
|
||||
__all__ = ["Assistant"]
|
||||
|
||||
|
||||
class Assistant(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def conversation(
|
||||
self,
|
||||
assistant_id: str,
|
||||
model: str,
|
||||
messages: list[assistant_create_params.ConversationMessage],
|
||||
*,
|
||||
stream: bool = True,
|
||||
conversation_id: Optional[str] = None,
|
||||
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
|
||||
metadata: dict | None = None,
|
||||
request_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> StreamResponse[AssistantCompletion]:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"assistant_id": assistant_id,
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"conversation_id": conversation_id,
|
||||
"attachments": attachments,
|
||||
"metadata": metadata,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/assistant",
|
||||
body=maybe_transform(body, assistant_create_params.AssistantParameters),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=AssistantCompletion,
|
||||
stream=stream or True,
|
||||
stream_cls=StreamResponse[AssistantCompletion],
|
||||
)
|
||||
|
||||
def query_support(
|
||||
self,
|
||||
*,
|
||||
assistant_id_list: Optional[list[str]] = None,
|
||||
request_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> AssistantSupportResp:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"assistant_id_list": assistant_id_list,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/assistant/list",
|
||||
body=body,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=AssistantSupportResp,
|
||||
)
|
||||
|
||||
def query_conversation_usage(
|
||||
self,
|
||||
assistant_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
*,
|
||||
request_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ConversationUsageListResp:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"assistant_id": assistant_id,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/assistant/conversation/list",
|
||||
body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=ConversationUsageListResp,
|
||||
)
|
|
@ -1,146 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
|
||||
from ..core.pagination import SyncCursorPage
|
||||
from ..types import batch_create_params, batch_list_params
|
||||
from ..types.batch import Batch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Batches(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
completion_window: str | None = None,
|
||||
endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
|
||||
input_file_id: str,
|
||||
metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
|
||||
auto_delete_input_file: bool = True,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Batch:
|
||||
return self._post(
|
||||
"/batches",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"completion_window": completion_window,
|
||||
"endpoint": endpoint,
|
||||
"input_file_id": input_file_id,
|
||||
"metadata": metadata,
|
||||
"auto_delete_input_file": auto_delete_input_file,
|
||||
},
|
||||
batch_create_params.BatchCreateParams,
|
||||
),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Batch,
|
||||
)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
batch_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Batch:
|
||||
"""
|
||||
Retrieves a batch.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not batch_id:
|
||||
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
||||
return self._get(
|
||||
f"/batches/{batch_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Batch,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> SyncCursorPage[Batch]:
|
||||
"""List your organization's batches.
|
||||
|
||||
Args:
|
||||
after: A cursor for use in pagination.
|
||||
|
||||
`after` is an object ID that defines your place
|
||||
in the list. For instance, if you make a list request and receive 100 objects,
|
||||
ending with obj_foo, your subsequent call can include after=obj_foo in order to
|
||||
fetch the next page of the list.
|
||||
|
||||
limit: A limit on the number of objects to be returned. Limit can range between 1 and
|
||||
100, and the default is 20.
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
return self._get_api_list(
|
||||
"/batches",
|
||||
page=SyncCursorPage[Batch],
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query=maybe_transform(
|
||||
{
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
batch_list_params.BatchListParams,
|
||||
),
|
||||
),
|
||||
model=Batch,
|
||||
)
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Batch:
|
||||
"""
|
||||
Cancels an in-progress batch.
|
||||
|
||||
Args:
|
||||
batch_id: The ID of the batch to cancel.
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
"""
|
||||
if not batch_id:
|
||||
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
||||
return self._post(
|
||||
f"/batches/{batch_id}/cancel",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Batch,
|
||||
)
|
|
@ -1,5 +0,0 @@
|
|||
from .async_completions import AsyncCompletions
|
||||
from .chat import Chat
|
||||
from .completions import Completions
|
||||
|
||||
__all__ = ["AsyncCompletions", "Chat", "Completions"]
|
|
@ -1,115 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
drop_prefix_image_data,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
|
||||
from ...types.chat.code_geex import code_geex_params
|
||||
from ...types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class AsyncCompletions(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
temperature: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
top_p: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
max_tokens: int | NotGiven = NOT_GIVEN,
|
||||
seed: int | NotGiven = NOT_GIVEN,
|
||||
messages: Union[str, list[str], list[int], list[list[int]], None],
|
||||
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||
tools: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tool_choice: str | NotGiven = NOT_GIVEN,
|
||||
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
|
||||
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> AsyncTaskStatus:
|
||||
_cast_type = AsyncTaskStatus
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if temperature is not None and temperature != NOT_GIVEN:
|
||||
if temperature <= 0:
|
||||
do_sample = False
|
||||
temperature = 0.01
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501
|
||||
if temperature >= 1:
|
||||
temperature = 0.99
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
|
||||
if top_p is not None and top_p != NOT_GIVEN:
|
||||
if top_p >= 1:
|
||||
top_p = 0.99
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
if top_p <= 0:
|
||||
top_p = 0.01
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if isinstance(messages, list):
|
||||
for item in messages:
|
||||
if item.get("content"):
|
||||
item["content"] = drop_prefix_image_data(item["content"])
|
||||
|
||||
body = {
|
||||
"model": model,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": do_sample,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"messages": messages,
|
||||
"stop": stop,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
"meta": meta,
|
||||
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
|
||||
}
|
||||
return self._post(
|
||||
"/async/chat/completions",
|
||||
body=body,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_cast_type,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def retrieve_completion_result(
|
||||
self,
|
||||
id: str,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Union[AsyncCompletion, AsyncTaskStatus]:
|
||||
_cast_type = Union[AsyncCompletion, AsyncTaskStatus]
|
||||
return self._get(
|
||||
path=f"/async-result/{id}",
|
||||
cast_type=_cast_type,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
)
|
|
@ -1,18 +0,0 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...core import BaseAPI, cached_property
|
||||
from .async_completions import AsyncCompletions
|
||||
from .completions import Completions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class Chat(BaseAPI):
|
||||
@cached_property
|
||||
def completions(self) -> Completions:
|
||||
return Completions(self._client)
|
||||
|
||||
@cached_property
|
||||
def asyncCompletions(self) -> AsyncCompletions: # noqa: N802
|
||||
return AsyncCompletions(self._client)
|
|
@ -1,108 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
StreamResponse,
|
||||
deepcopy_minimal,
|
||||
drop_prefix_image_data,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.chat.chat_completion import Completion
|
||||
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from ...types.chat.code_geex import code_geex_params
|
||||
from ...types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class Completions(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
temperature: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
top_p: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
max_tokens: int | NotGiven = NOT_GIVEN,
|
||||
seed: int | NotGiven = NOT_GIVEN,
|
||||
messages: Union[str, list[str], list[int], object, None],
|
||||
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||
tools: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tool_choice: str | NotGiven = NOT_GIVEN,
|
||||
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
|
||||
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Completion | StreamResponse[ChatCompletionChunk]:
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if temperature is not None and temperature != NOT_GIVEN:
|
||||
if temperature <= 0:
|
||||
do_sample = False
|
||||
temperature = 0.01
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501
|
||||
if temperature >= 1:
|
||||
temperature = 0.99
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
|
||||
if top_p is not None and top_p != NOT_GIVEN:
|
||||
if top_p >= 1:
|
||||
top_p = 0.99
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
if top_p <= 0:
|
||||
top_p = 0.01
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if isinstance(messages, list):
|
||||
for item in messages:
|
||||
if item.get("content"):
|
||||
item["content"] = drop_prefix_image_data(item["content"])
|
||||
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"model": model,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": do_sample,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"messages": messages,
|
||||
"stop": stop,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"stream": stream,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
"meta": meta,
|
||||
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/chat/completions",
|
||||
body=body,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Completion,
|
||||
stream=stream or False,
|
||||
stream_cls=StreamResponse[ChatCompletionChunk],
|
||||
)
|
|
@ -1,50 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
|
||||
from ..types.embeddings import EmbeddingsResponded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Embeddings(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, list[str], list[int], list[list[int]]],
|
||||
model: Union[str],
|
||||
dimensions: Union[int] | NotGiven = NOT_GIVEN,
|
||||
encoding_format: str | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> EmbeddingsResponded:
|
||||
_cast_type = EmbeddingsResponded
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._post(
|
||||
"/embeddings",
|
||||
body={
|
||||
"input": input,
|
||||
"model": model,
|
||||
"dimensions": dimensions,
|
||||
"encoding_format": encoding_format,
|
||||
"user": user,
|
||||
"request_id": request_id,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
},
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_cast_type,
|
||||
stream=False,
|
||||
)
|
|
@ -1,194 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Literal, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
FileTypes,
|
||||
Headers,
|
||||
NotGiven,
|
||||
_legacy_binary_response,
|
||||
_legacy_response,
|
||||
deepcopy_minimal,
|
||||
extract_files,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
__all__ = ["Files", "FilesWithRawResponse"]
|
||||
|
||||
|
||||
class Files(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
file: Optional[FileTypes] = None,
|
||||
upload_detail: Optional[list[UploadDetail]] = None,
|
||||
purpose: Literal["fine-tune", "retrieval", "batch"],
|
||||
knowledge_id: Optional[str] = None,
|
||||
sentence_size: Optional[int] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FileObject:
|
||||
if not file and not upload_detail:
|
||||
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"file": file,
|
||||
"upload_detail": upload_detail,
|
||||
"purpose": purpose,
|
||||
"knowledge_id": knowledge_id,
|
||||
"sentence_size": sentence_size,
|
||||
}
|
||||
)
|
||||
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
|
||||
if files:
|
||||
# It should be noted that the actual Content-Type header that will be
|
||||
# sent to the server will contain a `boundary` parameter, e.g.
|
||||
# multipart/form-data; boundary=---abc--
|
||||
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
|
||||
return self._post(
|
||||
"/files",
|
||||
body=maybe_transform(body, file_create_params.FileCreateParams),
|
||||
files=files,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FileObject,
|
||||
)
|
||||
|
||||
# def retrieve(
|
||||
# self,
|
||||
# file_id: str,
|
||||
# *,
|
||||
# extra_headers: Headers | None = None,
|
||||
# extra_body: Body | None = None,
|
||||
# timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
# ) -> FileObject:
|
||||
# """
|
||||
# Returns information about a specific file.
|
||||
#
|
||||
# Args:
|
||||
# file_id: The ID of the file to retrieve information about
|
||||
# extra_headers: Send extra headers
|
||||
#
|
||||
# extra_body: Add additional JSON properties to the request
|
||||
#
|
||||
# timeout: Override the client-level default timeout for this request, in seconds
|
||||
# """
|
||||
# if not file_id:
|
||||
# raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
|
||||
# return self._get(
|
||||
# f"/files/{file_id}",
|
||||
# options=make_request_options(
|
||||
# extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
|
||||
# ),
|
||||
# cast_type=FileObject,
|
||||
# )
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
purpose: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
order: str | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ListOfFileObject:
|
||||
return self._get(
|
||||
"/files",
|
||||
cast_type=ListOfFileObject,
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"purpose": purpose,
|
||||
"limit": limit,
|
||||
"after": after,
|
||||
"order": order,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
file_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FileDeleted:
|
||||
"""
|
||||
Delete a file.
|
||||
|
||||
Args:
|
||||
file_id: The ID of the file to delete
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not file_id:
|
||||
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
|
||||
return self._delete(
|
||||
f"/files/{file_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FileDeleted,
|
||||
)
|
||||
|
||||
def content(
|
||||
self,
|
||||
file_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> _legacy_response.HttpxBinaryResponseContent:
|
||||
"""
|
||||
Returns the contents of the specified file.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not file_id:
|
||||
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
|
||||
extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
|
||||
return self._get(
|
||||
f"/files/{file_id}/content",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
|
||||
)
|
||||
|
||||
|
||||
class FilesWithRawResponse:
|
||||
def __init__(self, files: Files) -> None:
|
||||
self._files = files
|
||||
|
||||
self.create = _legacy_response.to_raw_response_wrapper(
|
||||
files.create,
|
||||
)
|
||||
self.list = _legacy_response.to_raw_response_wrapper(
|
||||
files.list,
|
||||
)
|
||||
self.content = _legacy_response.to_raw_response_wrapper(
|
||||
files.content,
|
||||
)
|
|
@ -1,5 +0,0 @@
|
|||
from .fine_tuning import FineTuning
|
||||
from .jobs import Jobs
|
||||
from .models import FineTunedModels
|
||||
|
||||
__all__ = ["Jobs", "FineTunedModels", "FineTuning"]
|
|
@ -1,18 +0,0 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...core import BaseAPI, cached_property
|
||||
from .jobs import Jobs
|
||||
from .models import FineTunedModels
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class FineTuning(BaseAPI):
|
||||
@cached_property
|
||||
def jobs(self) -> Jobs:
|
||||
return Jobs(self._client)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> FineTunedModels:
|
||||
return FineTunedModels(self._client)
|
|
@ -1,3 +0,0 @@
|
|||
from .jobs import Jobs
|
||||
|
||||
__all__ = ["Jobs"]
|
|
@ -1,152 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
make_request_options,
|
||||
)
|
||||
from ....types.fine_tuning import (
|
||||
FineTuningJob,
|
||||
FineTuningJobEvent,
|
||||
ListOfFineTuningJob,
|
||||
job_create_params,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["Jobs"]
|
||||
|
||||
|
||||
class Jobs(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
training_file: str,
|
||||
hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
|
||||
suffix: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
return self._post(
|
||||
"/fine_tuning/jobs",
|
||||
body={
|
||||
"model": model,
|
||||
"training_file": training_file,
|
||||
"hyperparameters": hyperparameters,
|
||||
"suffix": suffix,
|
||||
"validation_file": validation_file,
|
||||
"request_id": request_id,
|
||||
},
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
return self._get(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ListOfFineTuningJob:
|
||||
return self._get(
|
||||
"/fine_tuning/jobs",
|
||||
cast_type=ListOfFineTuningJob,
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
if not fine_tuning_job_id:
|
||||
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
|
||||
return self._post(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def list_events(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJobEvent:
|
||||
return self._get(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
|
||||
cast_type=FineTuningJobEvent,
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
if not fine_tuning_job_id:
|
||||
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
|
||||
return self._delete(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
|
@ -1,3 +0,0 @@
|
|||
from .fine_tuned_models import FineTunedModels
|
||||
|
||||
__all__ = ["FineTunedModels"]
|
|
@ -1,41 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
make_request_options,
|
||||
)
|
||||
from ....types.fine_tuning.models import FineTunedModelsStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["FineTunedModels"]
|
||||
|
||||
|
||||
class FineTunedModels(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
fine_tuned_model: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTunedModelsStatus:
|
||||
if not fine_tuned_model:
|
||||
raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
|
||||
return self._delete(
|
||||
f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTunedModelsStatus,
|
||||
)
|
|
@ -1,59 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
|
||||
from ..types.image import ImagesResponded
|
||||
from ..types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Images(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def generations(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str | NotGiven = NOT_GIVEN,
|
||||
n: Optional[int] | NotGiven = NOT_GIVEN,
|
||||
quality: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
response_format: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
size: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
style: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ImagesResponded:
|
||||
_cast_type = ImagesResponded
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._post(
|
||||
"/images/generations",
|
||||
body={
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"n": n,
|
||||
"quality": quality,
|
||||
"response_format": response_format,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"size": size,
|
||||
"style": style,
|
||||
"user": user,
|
||||
"user_id": user_id,
|
||||
"request_id": request_id,
|
||||
},
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_cast_type,
|
||||
stream=False,
|
||||
)
|
|
@ -1,3 +0,0 @@
|
|||
from .knowledge import Knowledge
|
||||
|
||||
__all__ = ["Knowledge"]
|
|
@ -1,3 +0,0 @@
|
|||
from .document import Document
|
||||
|
||||
__all__ = ["Document"]
|
|
@ -1,217 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Literal, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
FileTypes,
|
||||
Headers,
|
||||
NotGiven,
|
||||
deepcopy_minimal,
|
||||
extract_files,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ....types.files import UploadDetail, file_create_params
|
||||
from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params
|
||||
from ....types.knowledge.document.document_list_resp import DocumentPage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["Document"]
|
||||
|
||||
|
||||
class Document(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
file: Optional[FileTypes] = None,
|
||||
custom_separator: Optional[list[str]] = None,
|
||||
upload_detail: Optional[list[UploadDetail]] = None,
|
||||
purpose: Literal["retrieval"],
|
||||
knowledge_id: Optional[str] = None,
|
||||
sentence_size: Optional[int] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> DocumentObject:
|
||||
if not file and not upload_detail:
|
||||
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"file": file,
|
||||
"upload_detail": upload_detail,
|
||||
"purpose": purpose,
|
||||
"custom_separator": custom_separator,
|
||||
"knowledge_id": knowledge_id,
|
||||
"sentence_size": sentence_size,
|
||||
}
|
||||
)
|
||||
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
|
||||
if files:
|
||||
# It should be noted that the actual Content-Type header that will be
|
||||
# sent to the server will contain a `boundary` parameter, e.g.
|
||||
# multipart/form-data; boundary=---abc--
|
||||
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
|
||||
return self._post(
|
||||
"/files",
|
||||
body=maybe_transform(body, file_create_params.FileCreateParams),
|
||||
files=files,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=DocumentObject,
|
||||
)
|
||||
|
||||
def edit(
|
||||
self,
|
||||
document_id: str,
|
||||
knowledge_type: str,
|
||||
*,
|
||||
custom_separator: Optional[list[str]] = None,
|
||||
sentence_size: Optional[int] = None,
|
||||
callback_url: Optional[str] = None,
|
||||
callback_header: Optional[dict[str, str]] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
|
||||
Args:
|
||||
document_id: 知识id
|
||||
knowledge_type: 知识类型:
|
||||
1:文章知识: 支持pdf,url,docx
|
||||
2.问答知识-文档: 支持pdf,url,docx
|
||||
3.问答知识-表格: 支持xlsx
|
||||
4.商品库-表格: 支持xlsx
|
||||
5.自定义: 支持pdf,url,docx
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
:param knowledge_type:
|
||||
:param document_id:
|
||||
:param timeout:
|
||||
:param extra_body:
|
||||
:param callback_header:
|
||||
:param sentence_size:
|
||||
:param extra_headers:
|
||||
:param callback_url:
|
||||
:param custom_separator:
|
||||
"""
|
||||
if not document_id:
|
||||
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
|
||||
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"id": document_id,
|
||||
"knowledge_type": knowledge_type,
|
||||
"custom_separator": custom_separator,
|
||||
"sentence_size": sentence_size,
|
||||
"callback_url": callback_url,
|
||||
"callback_header": callback_header,
|
||||
}
|
||||
)
|
||||
|
||||
return self._put(
|
||||
f"/document/{document_id}",
|
||||
body=maybe_transform(body, document_edit_params.DocumentEditParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
knowledge_id: str,
|
||||
*,
|
||||
purpose: str | NotGiven = NOT_GIVEN,
|
||||
page: str | NotGiven = NOT_GIVEN,
|
||||
limit: str | NotGiven = NOT_GIVEN,
|
||||
order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> DocumentPage:
|
||||
return self._get(
|
||||
"/files",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query=maybe_transform(
|
||||
{
|
||||
"knowledge_id": knowledge_id,
|
||||
"purpose": purpose,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"order": order,
|
||||
},
|
||||
document_list_params.DocumentListParams,
|
||||
),
|
||||
),
|
||||
cast_type=DocumentPage,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
document_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Delete a file.
|
||||
|
||||
Args:
|
||||
|
||||
document_id: 知识id
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not document_id:
|
||||
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
|
||||
|
||||
return self._delete(
|
||||
f"/document/{document_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
document_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> DocumentData:
|
||||
"""
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not document_id:
|
||||
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
|
||||
|
||||
return self._get(
|
||||
f"/document/{document_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=DocumentData,
|
||||
)
|
|
@ -1,173 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
cached_property,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params
|
||||
from ...types.knowledge.knowledge_list_resp import KnowledgePage
|
||||
from .document import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Knowledge"]
|
||||
|
||||
|
||||
class Knowledge(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
@cached_property
|
||||
def document(self) -> Document:
|
||||
return Document(self._client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
embedding_id: int,
|
||||
name: str,
|
||||
*,
|
||||
customer_identifier: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
|
||||
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
|
||||
bucket_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> KnowledgeInfo:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"embedding_id": embedding_id,
|
||||
"name": name,
|
||||
"customer_identifier": customer_identifier,
|
||||
"description": description,
|
||||
"background": background,
|
||||
"icon": icon,
|
||||
"bucket_id": bucket_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/knowledge",
|
||||
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=KnowledgeInfo,
|
||||
)
|
||||
|
||||
def modify(
|
||||
self,
|
||||
knowledge_id: str,
|
||||
embedding_id: int,
|
||||
*,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
|
||||
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"id": knowledge_id,
|
||||
"embedding_id": embedding_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"background": background,
|
||||
"icon": icon,
|
||||
}
|
||||
)
|
||||
return self._put(
|
||||
f"/knowledge/{knowledge_id}",
|
||||
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
page: int | NotGiven = 1,
|
||||
size: int | NotGiven = 10,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> KnowledgePage:
|
||||
return self._get(
|
||||
"/knowledge",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query=maybe_transform(
|
||||
{
|
||||
"page": page,
|
||||
"size": size,
|
||||
},
|
||||
knowledge_list_params.KnowledgeListParams,
|
||||
),
|
||||
),
|
||||
cast_type=KnowledgePage,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
knowledge_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Delete a file.
|
||||
|
||||
Args:
|
||||
knowledge_id: 知识库ID
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not knowledge_id:
|
||||
raise ValueError("Expected a non-empty value for `knowledge_id`")
|
||||
|
||||
return self._delete(
|
||||
f"/knowledge/{knowledge_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def used(
|
||||
self,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> KnowledgeUsed:
|
||||
"""
|
||||
Returns the contents of the specified file.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
return self._get(
|
||||
"/knowledge/capacity",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=KnowledgeUsed,
|
||||
)
|
|
@ -1,3 +0,0 @@
|
|||
from .tools import Tools
|
||||
|
||||
__all__ = ["Tools"]
|
|
@ -1,65 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
StreamResponse,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Tools"]
|
||||
|
||||
|
||||
class Tools(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def web_search(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
messages: Union[str, list[str], list[int], object, None],
|
||||
scope: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
location: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
recent_days: Optional[int] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> WebSearch | StreamResponse[WebSearchChunk]:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"model": model,
|
||||
"request_id": request_id,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"scope": scope,
|
||||
"location": location,
|
||||
"recent_days": recent_days,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/tools",
|
||||
body=maybe_transform(body, tools_web_search_params.WebSearchParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=WebSearch,
|
||||
stream=stream or False,
|
||||
stream_cls=StreamResponse[WebSearchChunk],
|
||||
)
|
|
@ -1,7 +0,0 @@
|
|||
from .videos import (
|
||||
Videos,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Videos",
|
||||
]
|
|
@ -1,77 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
from ...types.video import VideoObject, video_create_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Videos"]
|
||||
|
||||
|
||||
class Videos(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def generations(
|
||||
self,
|
||||
model: str,
|
||||
*,
|
||||
prompt: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> VideoObject:
|
||||
if not model and not model:
|
||||
raise ValueError("At least one of `model` and `prompt` must be provided.")
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"image_url": image_url,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/videos/generations",
|
||||
body=maybe_transform(body, video_create_params.VideoCreateParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=VideoObject,
|
||||
)
|
||||
|
||||
def retrieve_videos_result(
|
||||
self,
|
||||
id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> VideoObject:
|
||||
if not id:
|
||||
raise ValueError("At least one of `id` must be provided.")
|
||||
|
||||
return self._get(
|
||||
f"/async-result/{id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=VideoObject,
|
||||
)
|
|
@ -1,108 +0,0 @@
|
|||
from ._base_api import BaseAPI
|
||||
from ._base_compat import (
|
||||
PYDANTIC_V2,
|
||||
ConfigDict,
|
||||
GenericModel,
|
||||
cached_property,
|
||||
field_get_default,
|
||||
get_args,
|
||||
get_model_config,
|
||||
get_model_fields,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
parse_obj,
|
||||
)
|
||||
from ._base_models import BaseModel, construct_type
|
||||
from ._base_type import (
|
||||
NOT_GIVEN,
|
||||
Body,
|
||||
FileTypes,
|
||||
Headers,
|
||||
IncEx,
|
||||
ModelT,
|
||||
NotGiven,
|
||||
Query,
|
||||
)
|
||||
from ._constants import (
|
||||
ZHIPUAI_DEFAULT_LIMITS,
|
||||
ZHIPUAI_DEFAULT_MAX_RETRIES,
|
||||
ZHIPUAI_DEFAULT_TIMEOUT,
|
||||
)
|
||||
from ._errors import (
|
||||
APIAuthenticationError,
|
||||
APIConnectionError,
|
||||
APIInternalError,
|
||||
APIReachLimitError,
|
||||
APIRequestFailedError,
|
||||
APIResponseError,
|
||||
APIResponseValidationError,
|
||||
APIServerFlowExceedError,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
ZhipuAIError,
|
||||
)
|
||||
from ._files import is_file_content
|
||||
from ._http_client import HttpClient, make_request_options
|
||||
from ._sse_client import StreamResponse
|
||||
from ._utils import (
|
||||
deepcopy_minimal,
|
||||
drop_prefix_image_data,
|
||||
extract_files,
|
||||
is_given,
|
||||
is_list,
|
||||
is_mapping,
|
||||
maybe_transform,
|
||||
parse_date,
|
||||
parse_datetime,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"construct_type",
|
||||
"BaseAPI",
|
||||
"NOT_GIVEN",
|
||||
"Headers",
|
||||
"NotGiven",
|
||||
"Body",
|
||||
"IncEx",
|
||||
"ModelT",
|
||||
"Query",
|
||||
"FileTypes",
|
||||
"PYDANTIC_V2",
|
||||
"ConfigDict",
|
||||
"GenericModel",
|
||||
"get_args",
|
||||
"is_union",
|
||||
"parse_obj",
|
||||
"get_origin",
|
||||
"is_literal_type",
|
||||
"get_model_config",
|
||||
"get_model_fields",
|
||||
"field_get_default",
|
||||
"is_file_content",
|
||||
"ZhipuAIError",
|
||||
"APIStatusError",
|
||||
"APIRequestFailedError",
|
||||
"APIAuthenticationError",
|
||||
"APIReachLimitError",
|
||||
"APIInternalError",
|
||||
"APIServerFlowExceedError",
|
||||
"APIResponseError",
|
||||
"APIResponseValidationError",
|
||||
"APITimeoutError",
|
||||
"make_request_options",
|
||||
"HttpClient",
|
||||
"ZHIPUAI_DEFAULT_TIMEOUT",
|
||||
"ZHIPUAI_DEFAULT_MAX_RETRIES",
|
||||
"ZHIPUAI_DEFAULT_LIMITS",
|
||||
"is_list",
|
||||
"is_mapping",
|
||||
"parse_date",
|
||||
"parse_datetime",
|
||||
"is_given",
|
||||
"maybe_transform",
|
||||
"deepcopy_minimal",
|
||||
"extract_files",
|
||||
"StreamResponse",
|
||||
]
|
|
@ -1,19 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class BaseAPI:
|
||||
_client: ZhipuAI
|
||||
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
self._client = client
|
||||
self._delete = client.delete
|
||||
self._get = client.get
|
||||
self._post = client.post
|
||||
self._put = client.put
|
||||
self._patch = client.patch
|
||||
self._get_api_list = client.get_api_list
|
|
@ -1,209 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import date, datetime
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload
|
||||
|
||||
import pydantic
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._base_type import StrBytesIntFloat
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
|
||||
|
||||
# --------------- Pydantic v2 compatibility ---------------
|
||||
|
||||
# Pyright incorrectly reports some of our functions as overriding a method when they don't
|
||||
# pyright: reportIncompatibleMethodOverride=false
|
||||
|
||||
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
|
||||
|
||||
# v1 re-exports
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def parse_date(value: date | StrBytesIntFloat) -> date: ...
|
||||
|
||||
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...
|
||||
|
||||
def get_args(t: type[Any]) -> tuple[Any, ...]: ...
|
||||
|
||||
def is_union(tp: type[Any] | None) -> bool: ...
|
||||
|
||||
def get_origin(t: type[Any]) -> type[Any] | None: ...
|
||||
|
||||
def is_literal_type(type_: type[Any]) -> bool: ...
|
||||
|
||||
def is_typeddict(type_: type[Any]) -> bool: ...
|
||||
|
||||
else:
|
||||
if PYDANTIC_V2:
|
||||
from pydantic.v1.typing import ( # noqa: I001
|
||||
get_args as get_args, # noqa: PLC0414
|
||||
is_union as is_union, # noqa: PLC0414
|
||||
get_origin as get_origin, # noqa: PLC0414
|
||||
is_typeddict as is_typeddict, # noqa: PLC0414
|
||||
is_literal_type as is_literal_type, # noqa: PLC0414
|
||||
)
|
||||
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
|
||||
else:
|
||||
from pydantic.typing import ( # noqa: I001
|
||||
get_args as get_args, # noqa: PLC0414
|
||||
is_union as is_union, # noqa: PLC0414
|
||||
get_origin as get_origin, # noqa: PLC0414
|
||||
is_typeddict as is_typeddict, # noqa: PLC0414
|
||||
is_literal_type as is_literal_type, # noqa: PLC0414
|
||||
)
|
||||
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
|
||||
|
||||
|
||||
# refactored config
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import ConfigDict
|
||||
else:
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import ConfigDict
|
||||
else:
|
||||
# TODO: provide an error message here?
|
||||
ConfigDict = None
|
||||
|
||||
|
||||
# renamed methods / properties
|
||||
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_validate(value)
|
||||
else:
|
||||
# pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||
return cast(_ModelT, model.parse_obj(value))
|
||||
|
||||
|
||||
def field_is_required(field: FieldInfo) -> bool:
|
||||
if PYDANTIC_V2:
|
||||
return field.is_required()
|
||||
return field.required # type: ignore
|
||||
|
||||
|
||||
def field_get_default(field: FieldInfo) -> Any:
|
||||
value = field.get_default()
|
||||
if PYDANTIC_V2:
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
if value == PydanticUndefined:
|
||||
return None
|
||||
return value
|
||||
return value
|
||||
|
||||
|
||||
def field_outer_type(field: FieldInfo) -> Any:
|
||||
if PYDANTIC_V2:
|
||||
return field.annotation
|
||||
return field.outer_type_ # type: ignore
|
||||
|
||||
|
||||
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_config
|
||||
return model.__config__ # type: ignore
|
||||
|
||||
|
||||
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_fields
|
||||
return model.__fields__ # type: ignore
|
||||
|
||||
|
||||
def model_copy(model: _ModelT) -> _ModelT:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_copy()
|
||||
return model.copy() # type: ignore
|
||||
|
||||
|
||||
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_dump_json(indent=indent)
|
||||
return model.json(indent=indent) # type: ignore
|
||||
|
||||
|
||||
def model_dump(
|
||||
model: pydantic.BaseModel,
|
||||
*,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_dump(
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
)
|
||||
return cast(
|
||||
"dict[str, Any]",
|
||||
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_validate(data)
|
||||
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
|
||||
|
||||
|
||||
# generic models
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class GenericModel(pydantic.BaseModel): ...
|
||||
|
||||
else:
|
||||
if PYDANTIC_V2:
|
||||
# there no longer needs to be a distinction in v2 but
|
||||
# we still have to create our own subclass to avoid
|
||||
# inconsistent MRO ordering errors
|
||||
class GenericModel(pydantic.BaseModel): ...
|
||||
|
||||
else:
|
||||
import pydantic.generics
|
||||
|
||||
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
|
||||
|
||||
|
||||
# cached properties
|
||||
if TYPE_CHECKING:
|
||||
cached_property = property
|
||||
|
||||
# we define a separate type (copied from typeshed)
|
||||
# that represents that `cached_property` is `set`able
|
||||
# at runtime, which differs from `@property`.
|
||||
#
|
||||
# this is a separate type as editors likely special case
|
||||
# `@property` and we don't want to cause issues just to have
|
||||
# more helpful internal types.
|
||||
|
||||
class typed_cached_property(Generic[_T]): # noqa: N801
|
||||
func: Callable[[Any], _T]
|
||||
attrname: str | None
|
||||
|
||||
def __init__(self, func: Callable[[Any], _T]) -> None: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
|
||||
|
||||
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __set_name__(self, owner: type[Any], name: str) -> None: ...
|
||||
|
||||
# __set__ is not defined at runtime, but @cached_property is designed to be settable
|
||||
def __set__(self, instance: object, value: _T) -> None: ...
|
||||
else:
|
||||
try:
|
||||
from functools import cached_property
|
||||
except ImportError:
|
||||
from cached_property import cached_property
|
||||
|
||||
typed_cached_property = cached_property
|
|
@ -1,670 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from datetime import date, datetime
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast
|
||||
|
||||
import pydantic
|
||||
import pydantic.generics
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import (
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
override,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ._base_compat import (
|
||||
PYDANTIC_V2,
|
||||
ConfigDict,
|
||||
field_get_default,
|
||||
get_args,
|
||||
get_model_config,
|
||||
get_model_fields,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
parse_obj,
|
||||
)
|
||||
from ._base_compat import (
|
||||
GenericModel as BaseGenericModel,
|
||||
)
|
||||
from ._base_type import (
|
||||
IncEx,
|
||||
ModelT,
|
||||
)
|
||||
from ._utils import (
|
||||
PropertyInfo,
|
||||
coerce_boolean,
|
||||
extract_type_arg,
|
||||
is_annotated_type,
|
||||
is_list,
|
||||
is_mapping,
|
||||
parse_date,
|
||||
parse_datetime,
|
||||
strip_annotated_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core.core_schema import ModelField
|
||||
|
||||
__all__ = ["BaseModel", "GenericModel"]
|
||||
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _ConfigProtocol(Protocol):
|
||||
allow_population_by_field_name: bool
|
||||
|
||||
|
||||
class BaseModel(pydantic.BaseModel):
|
||||
if PYDANTIC_V2:
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
|
||||
)
|
||||
else:
|
||||
|
||||
@property
|
||||
@override
|
||||
def model_fields_set(self) -> set[str]:
|
||||
# a forwards-compat shim for pydantic v2
|
||||
return self.__fields_set__ # type: ignore
|
||||
|
||||
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
|
||||
extra: Any = pydantic.Extra.allow # type: ignore
|
||||
|
||||
def to_dict(
|
||||
self,
|
||||
*,
|
||||
mode: Literal["json", "python"] = "python",
|
||||
use_api_names: bool = True,
|
||||
exclude_unset: bool = True,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
warnings: bool = True,
|
||||
) -> dict[str, object]:
|
||||
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
|
||||
|
||||
By default, fields that were not set by the API will not be included,
|
||||
and keys will match the API response, *not* the property names from the model.
|
||||
|
||||
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
|
||||
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
|
||||
|
||||
Args:
|
||||
mode:
|
||||
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
|
||||
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
|
||||
|
||||
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
|
||||
exclude_unset: Whether to exclude fields that have not been explicitly set.
|
||||
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
|
||||
exclude_none: Whether to exclude fields that have a value of `None` from the output.
|
||||
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
|
||||
""" # noqa: E501
|
||||
return self.model_dump(
|
||||
mode=mode,
|
||||
by_alias=use_api_names,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
def to_json(
|
||||
self,
|
||||
*,
|
||||
indent: int | None = 2,
|
||||
use_api_names: bool = True,
|
||||
exclude_unset: bool = True,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
warnings: bool = True,
|
||||
) -> str:
|
||||
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
|
||||
|
||||
By default, fields that were not set by the API will not be included,
|
||||
and keys will match the API response, *not* the property names from the model.
|
||||
|
||||
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
|
||||
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
|
||||
|
||||
Args:
|
||||
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
|
||||
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
|
||||
exclude_unset: Whether to exclude fields that have not been explicitly set.
|
||||
exclude_defaults: Whether to exclude fields that have the default value.
|
||||
exclude_none: Whether to exclude fields that have a value of `None`.
|
||||
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
|
||||
""" # noqa: E501
|
||||
return self.model_dump_json(
|
||||
indent=indent,
|
||||
by_alias=use_api_names,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
# mypy complains about an invalid self arg
|
||||
return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
|
||||
|
||||
# Override the 'construct' method in a way that supports recursive parsing without validation.
|
||||
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
|
||||
@classmethod
|
||||
@override
|
||||
def construct(
|
||||
cls: type[ModelT],
|
||||
_fields_set: set[str] | None = None,
|
||||
**values: object,
|
||||
) -> ModelT:
|
||||
m = cls.__new__(cls)
|
||||
fields_values: dict[str, object] = {}
|
||||
|
||||
config = get_model_config(cls)
|
||||
populate_by_name = (
|
||||
config.allow_population_by_field_name
|
||||
if isinstance(config, _ConfigProtocol)
|
||||
else config.get("populate_by_name")
|
||||
)
|
||||
|
||||
if _fields_set is None:
|
||||
_fields_set = set()
|
||||
|
||||
model_fields = get_model_fields(cls)
|
||||
for name, field in model_fields.items():
|
||||
key = field.alias
|
||||
if key is None or (key not in values and populate_by_name):
|
||||
key = name
|
||||
|
||||
if key in values:
|
||||
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
|
||||
_fields_set.add(name)
|
||||
else:
|
||||
fields_values[name] = field_get_default(field)
|
||||
|
||||
_extra = {}
|
||||
for key, value in values.items():
|
||||
if key not in model_fields:
|
||||
if PYDANTIC_V2:
|
||||
_extra[key] = value
|
||||
else:
|
||||
_fields_set.add(key)
|
||||
fields_values[key] = value
|
||||
|
||||
object.__setattr__(m, "__dict__", fields_values) # noqa: PLC2801
|
||||
|
||||
if PYDANTIC_V2:
|
||||
# these properties are copied from Pydantic's `model_construct()` method
|
||||
object.__setattr__(m, "__pydantic_private__", None) # noqa: PLC2801
|
||||
object.__setattr__(m, "__pydantic_extra__", _extra) # noqa: PLC2801
|
||||
object.__setattr__(m, "__pydantic_fields_set__", _fields_set) # noqa: PLC2801
|
||||
else:
|
||||
# init_private_attributes() does not exist in v2
|
||||
m._init_private_attributes() # type: ignore
|
||||
|
||||
# copied from Pydantic v1's `construct()` method
|
||||
object.__setattr__(m, "__fields_set__", _fields_set) # noqa: PLC2801
|
||||
|
||||
return m
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# type checkers incorrectly complain about this assignment
|
||||
# because the type signatures are technically different
|
||||
# although not in practice
|
||||
model_construct = construct
|
||||
|
||||
if not PYDANTIC_V2:
|
||||
# we define aliases for some of the new pydantic v2 methods so
|
||||
# that we can just document these methods without having to specify
|
||||
# a specific pydantic version as some users may not know which
|
||||
# pydantic version they are currently using
|
||||
|
||||
@override
|
||||
def model_dump(
|
||||
self,
|
||||
*,
|
||||
mode: Literal["json", "python"] | str = "python",
|
||||
include: IncEx = None,
|
||||
exclude: IncEx = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal["none", "warn", "error"] = True,
|
||||
context: dict[str, Any] | None = None,
|
||||
serialize_as_any: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
|
||||
|
||||
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
|
||||
|
||||
Args:
|
||||
mode: The mode in which `to_python` should run.
|
||||
If mode is 'json', the dictionary will only contain JSON serializable types.
|
||||
If mode is 'python', the dictionary may contain any Python objects.
|
||||
include: A list of fields to include in the output.
|
||||
exclude: A list of fields to exclude from the output.
|
||||
by_alias: Whether to use the field's alias in the dictionary key if defined.
|
||||
exclude_unset: Whether to exclude fields that are unset or None from the output.
|
||||
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
|
||||
exclude_none: Whether to exclude fields that have a value of `None` from the output.
|
||||
round_trip: Whether to enable serialization and deserialization round-trip support.
|
||||
warnings: Whether to log warnings when invalid fields are encountered.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the model.
|
||||
"""
|
||||
if mode != "python":
|
||||
raise ValueError("mode is only supported in Pydantic v2")
|
||||
if round_trip != False:
|
||||
raise ValueError("round_trip is only supported in Pydantic v2")
|
||||
if warnings != True:
|
||||
raise ValueError("warnings is only supported in Pydantic v2")
|
||||
if context is not None:
|
||||
raise ValueError("context is only supported in Pydantic v2")
|
||||
if serialize_as_any != False:
|
||||
raise ValueError("serialize_as_any is only supported in Pydantic v2")
|
||||
return super().dict( # pyright: ignore[reportDeprecated]
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
|
||||
@override
|
||||
def model_dump_json(
|
||||
self,
|
||||
*,
|
||||
indent: int | None = None,
|
||||
include: IncEx = None,
|
||||
exclude: IncEx = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal["none", "warn", "error"] = True,
|
||||
context: dict[str, Any] | None = None,
|
||||
serialize_as_any: bool = False,
|
||||
) -> str:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
|
||||
|
||||
Generates a JSON representation of the model using Pydantic's `to_json` method.
|
||||
|
||||
Args:
|
||||
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
|
||||
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
|
||||
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
|
||||
by_alias: Whether to serialize using field aliases.
|
||||
exclude_unset: Whether to exclude fields that have not been explicitly set.
|
||||
exclude_defaults: Whether to exclude fields that have the default value.
|
||||
exclude_none: Whether to exclude fields that have a value of `None`.
|
||||
round_trip: Whether to use serialization/deserialization between JSON and class instance.
|
||||
warnings: Whether to show any warnings that occurred during serialization.
|
||||
|
||||
Returns:
|
||||
A JSON string representation of the model.
|
||||
"""
|
||||
if round_trip != False:
|
||||
raise ValueError("round_trip is only supported in Pydantic v2")
|
||||
if warnings != True:
|
||||
raise ValueError("warnings is only supported in Pydantic v2")
|
||||
if context is not None:
|
||||
raise ValueError("context is only supported in Pydantic v2")
|
||||
if serialize_as_any != False:
|
||||
raise ValueError("serialize_as_any is only supported in Pydantic v2")
|
||||
return super().json( # type: ignore[reportDeprecated]
|
||||
indent=indent,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
|
||||
|
||||
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
|
||||
if value is None:
|
||||
return field_get_default(field)
|
||||
|
||||
if PYDANTIC_V2:
|
||||
type_ = field.annotation
|
||||
else:
|
||||
type_ = cast(type, field.outer_type_) # type: ignore
|
||||
|
||||
if type_ is None:
|
||||
raise RuntimeError(f"Unexpected field type is None for {key}")
|
||||
|
||||
return construct_type(value=value, type_=type_)
|
||||
|
||||
|
||||
def is_basemodel(type_: type) -> bool:
|
||||
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
|
||||
if is_union(type_):
|
||||
return any(is_basemodel(variant) for variant in get_args(type_))
|
||||
|
||||
return is_basemodel_type(type_)
|
||||
|
||||
|
||||
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
|
||||
origin = get_origin(type_) or type_
|
||||
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
|
||||
|
||||
|
||||
def build(
|
||||
base_model_cls: Callable[P, _BaseModelT],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> _BaseModelT:
|
||||
"""Construct a BaseModel class without validation.
|
||||
|
||||
This is useful for cases where you need to instantiate a `BaseModel`
|
||||
from an API response as this provides type-safe params which isn't supported
|
||||
by helpers like `construct_type()`.
|
||||
|
||||
```py
|
||||
build(MyModel, my_field_a="foo", my_field_b=123)
|
||||
```
|
||||
"""
|
||||
if args:
|
||||
raise TypeError(
|
||||
"Received positional arguments which are not supported; Keyword arguments must be used instead",
|
||||
)
|
||||
|
||||
return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
|
||||
|
||||
|
||||
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
|
||||
"""Loose coercion to the expected type with construction of nested values.
|
||||
|
||||
Note: the returned value from this function is not guaranteed to match the
|
||||
given type.
|
||||
"""
|
||||
return cast(_T, construct_type(value=value, type_=type_))
|
||||
|
||||
|
||||
def construct_type(*, value: object, type_: type) -> object:
|
||||
"""Loose coercion to the expected type with construction of nested values.
|
||||
|
||||
If the given value does not match the expected type then it is returned as-is.
|
||||
"""
|
||||
# we allow `object` as the input type because otherwise, passing things like
|
||||
# `Literal['value']` will be reported as a type error by type checkers
|
||||
type_ = cast("type[object]", type_)
|
||||
|
||||
# unwrap `Annotated[T, ...]` -> `T`
|
||||
if is_annotated_type(type_):
|
||||
meta: tuple[Any, ...] = get_args(type_)[1:]
|
||||
type_ = extract_type_arg(type_, 0)
|
||||
else:
|
||||
meta = ()
|
||||
# we need to use the origin class for any types that are subscripted generics
|
||||
# e.g. Dict[str, object]
|
||||
origin = get_origin(type_) or type_
|
||||
args = get_args(type_)
|
||||
|
||||
if is_union(origin):
|
||||
try:
|
||||
return validate_type(type_=cast("type[object]", type_), value=value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if the type is a discriminated union then we want to construct the right variant
|
||||
# in the union, even if the data doesn't match exactly, otherwise we'd break code
|
||||
# that relies on the constructed class types, e.g.
|
||||
#
|
||||
# class FooType:
|
||||
# kind: Literal['foo']
|
||||
# value: str
|
||||
#
|
||||
# class BarType:
|
||||
# kind: Literal['bar']
|
||||
# value: int
|
||||
#
|
||||
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
|
||||
# we'd end up constructing `FooType` when it should be `BarType`.
|
||||
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
|
||||
if discriminator and is_mapping(value):
|
||||
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
|
||||
if variant_value and isinstance(variant_value, str):
|
||||
variant_type = discriminator.mapping.get(variant_value)
|
||||
if variant_type:
|
||||
return construct_type(type_=variant_type, value=value)
|
||||
|
||||
# if the data is not valid, use the first variant that doesn't fail while deserializing
|
||||
for variant in args:
|
||||
try:
|
||||
return construct_type(value=value, type_=variant)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
|
||||
if origin == dict:
|
||||
if not is_mapping(value):
|
||||
return value
|
||||
|
||||
_, items_type = get_args(type_) # Dict[_, items_type]
|
||||
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
|
||||
|
||||
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
|
||||
if is_list(value):
|
||||
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
|
||||
|
||||
if is_mapping(value):
|
||||
if issubclass(type_, BaseModel):
|
||||
return type_.construct(**value) # type: ignore[arg-type]
|
||||
|
||||
return cast(Any, type_).construct(**value)
|
||||
|
||||
if origin == list:
|
||||
if not is_list(value):
|
||||
return value
|
||||
|
||||
inner_type = args[0] # List[inner_type]
|
||||
return [construct_type(value=entry, type_=inner_type) for entry in value]
|
||||
|
||||
if origin == float:
|
||||
if isinstance(value, int):
|
||||
coerced = float(value)
|
||||
if coerced != value:
|
||||
return value
|
||||
return coerced
|
||||
|
||||
return value
|
||||
|
||||
if type_ == datetime:
|
||||
try:
|
||||
return parse_datetime(value) # type: ignore
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
if type_ == date:
|
||||
try:
|
||||
return parse_date(value) # type: ignore
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CachedDiscriminatorType(Protocol):
|
||||
__discriminator__: DiscriminatorDetails
|
||||
|
||||
|
||||
class DiscriminatorDetails:
|
||||
field_name: str
|
||||
"""The name of the discriminator field in the variant class, e.g.
|
||||
|
||||
```py
|
||||
class Foo(BaseModel):
|
||||
type: Literal['foo']
|
||||
```
|
||||
|
||||
Will result in field_name='type'
|
||||
"""
|
||||
|
||||
field_alias_from: str | None
|
||||
"""The name of the discriminator field in the API response, e.g.
|
||||
|
||||
```py
|
||||
class Foo(BaseModel):
|
||||
type: Literal['foo'] = Field(alias='type_from_api')
|
||||
```
|
||||
|
||||
Will result in field_alias_from='type_from_api'
|
||||
"""
|
||||
|
||||
mapping: dict[str, type]
|
||||
"""Mapping of discriminator value to variant type, e.g.
|
||||
|
||||
{'foo': FooVariant, 'bar': BarVariant}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mapping: dict[str, type],
|
||||
discriminator_field: str,
|
||||
discriminator_alias: str | None,
|
||||
) -> None:
|
||||
self.mapping = mapping
|
||||
self.field_name = discriminator_field
|
||||
self.field_alias_from = discriminator_alias
|
||||
|
||||
|
||||
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
|
||||
if isinstance(union, CachedDiscriminatorType):
|
||||
return union.__discriminator__
|
||||
|
||||
discriminator_field_name: str | None = None
|
||||
|
||||
for annotation in meta_annotations:
|
||||
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
|
||||
discriminator_field_name = annotation.discriminator
|
||||
break
|
||||
|
||||
if not discriminator_field_name:
|
||||
return None
|
||||
|
||||
mapping: dict[str, type] = {}
|
||||
discriminator_alias: str | None = None
|
||||
|
||||
for variant in get_args(union):
|
||||
variant = strip_annotated_type(variant)
|
||||
if is_basemodel_type(variant):
|
||||
if PYDANTIC_V2:
|
||||
field = _extract_field_schema_pv2(variant, discriminator_field_name)
|
||||
if not field:
|
||||
continue
|
||||
|
||||
# Note: if one variant defines an alias then they all should
|
||||
discriminator_alias = field.get("serialization_alias")
|
||||
|
||||
field_schema = field["schema"]
|
||||
|
||||
if field_schema["type"] == "literal":
|
||||
for entry in cast("LiteralSchema", field_schema)["expected"]:
|
||||
if isinstance(entry, str):
|
||||
mapping[entry] = variant
|
||||
else:
|
||||
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||
if not field_info:
|
||||
continue
|
||||
|
||||
# Note: if one variant defines an alias then they all should
|
||||
discriminator_alias = field_info.alias
|
||||
|
||||
if field_info.annotation and is_literal_type(field_info.annotation):
|
||||
for entry in get_args(field_info.annotation):
|
||||
if isinstance(entry, str):
|
||||
mapping[entry] = variant
|
||||
|
||||
if not mapping:
|
||||
return None
|
||||
|
||||
details = DiscriminatorDetails(
|
||||
mapping=mapping,
|
||||
discriminator_field=discriminator_field_name,
|
||||
discriminator_alias=discriminator_alias,
|
||||
)
|
||||
cast(CachedDiscriminatorType, union).__discriminator__ = details
|
||||
return details
|
||||
|
||||
|
||||
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
|
||||
schema = model.__pydantic_core_schema__
|
||||
if schema["type"] != "model":
|
||||
return None
|
||||
|
||||
fields_schema = schema["schema"]
|
||||
if fields_schema["type"] != "model-fields":
|
||||
return None
|
||||
|
||||
fields_schema = cast("ModelFieldsSchema", fields_schema)
|
||||
|
||||
field = fields_schema["fields"].get(field_name)
|
||||
if not field:
|
||||
return None
|
||||
|
||||
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
|
||||
|
||||
|
||||
def validate_type(*, type_: type[_T], value: object) -> _T:
|
||||
"""Strict validation that the given value matches the expected type"""
|
||||
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
|
||||
return cast(_T, parse_obj(type_, value))
|
||||
|
||||
return cast(_T, _validate_non_model_type(type_=type_, value=value))
|
||||
|
||||
|
||||
# Subclassing here confuses type checkers, so we treat this class as non-inheriting.
|
||||
if TYPE_CHECKING:
|
||||
GenericModel = BaseModel
|
||||
else:
|
||||
|
||||
class GenericModel(BaseGenericModel, BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
|
||||
return TypeAdapter(type_).validate_python(value)
|
||||
|
||||
elif not TYPE_CHECKING:
|
||||
|
||||
class TypeAdapter(Generic[_T]):
|
||||
"""Used as a placeholder to easily convert runtime types to a Pydantic format
|
||||
to provide validation.
|
||||
|
||||
For example:
|
||||
```py
|
||||
validated = RootModel[int](__root__="5").__root__
|
||||
# validated: 5
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, type_: type[_T]):
|
||||
self.type_ = type_
|
||||
|
||||
def validate_python(self, value: Any) -> _T:
|
||||
if not isinstance(value, self.type_):
|
||||
raise ValueError(f"Invalid type: {value} is not of type {self.type_}")
|
||||
return value
|
||||
|
||||
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
|
||||
return TypeAdapter(type_).validate_python(value)
|
|
@ -1,170 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from os import PathLike
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import pydantic
|
||||
from httpx import Response
|
||||
from typing_extensions import Protocol, TypedDict, override, runtime_checkable
|
||||
|
||||
Query = Mapping[str, object]
|
||||
Body = object
|
||||
AnyMapping = Mapping[str, object]
|
||||
PrimitiveData = Union[str, int, float, bool, None]
|
||||
Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"]
|
||||
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
|
||||
_T = TypeVar("_T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
NoneType: type[None]
|
||||
else:
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
# Sentinel class used until PEP 0661 is accepted
|
||||
class NotGiven:
|
||||
"""
|
||||
A sentinel singleton class used to distinguish omitted keyword arguments
|
||||
from those passed in with the value None (which may have different behavior).
|
||||
|
||||
For example:
|
||||
|
||||
```py
|
||||
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
|
||||
|
||||
get(timeout=1) # 1s timeout
|
||||
get(timeout=None) # No timeout
|
||||
get() # Default timeout behavior, which may not be statically known at the method definition.
|
||||
```
|
||||
"""
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return "NOT_GIVEN"
|
||||
|
||||
|
||||
NotGivenOr = Union[_T, NotGiven]
|
||||
NOT_GIVEN = NotGiven()
|
||||
|
||||
|
||||
class Omit:
|
||||
"""In certain situations you need to be able to represent a case where a default value has
|
||||
to be explicitly removed and `None` is not an appropriate substitute, for example:
|
||||
|
||||
```py
|
||||
# as the default `Content-Type` header is `application/json` that will be sent
|
||||
client.post('/upload/files', files={'file': b'my raw file content'})
|
||||
|
||||
# you can't explicitly override the header as it has to be dynamically generated
|
||||
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
|
||||
client.post(..., headers={'Content-Type': 'multipart/form-data'})
|
||||
|
||||
# instead you can remove the default `application/json` header by passing Omit
|
||||
client.post(..., headers={'Content-Type': Omit()})
|
||||
```
|
||||
"""
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ModelBuilderProtocol(Protocol):
|
||||
@classmethod
|
||||
def build(
|
||||
cls: type[_T],
|
||||
*,
|
||||
response: Response,
|
||||
data: object,
|
||||
) -> _T: ...
|
||||
|
||||
|
||||
Headers = Mapping[str, Union[str, Omit]]
|
||||
|
||||
|
||||
class HeadersLikeProtocol(Protocol):
|
||||
def get(self, __key: str) -> str | None: ...
|
||||
|
||||
|
||||
HeadersLike = Union[Headers, HeadersLikeProtocol]
|
||||
|
||||
ResponseT = TypeVar(
|
||||
"ResponseT",
|
||||
bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", # noqa: E501
|
||||
)
|
||||
|
||||
StrBytesIntFloat = Union[str, bytes, int, float]
|
||||
|
||||
# Note: copied from Pydantic
|
||||
# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
|
||||
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
|
||||
|
||||
PostParser = Callable[[Any], Any]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class InheritsGeneric(Protocol):
|
||||
"""Represents a type that has inherited from `Generic`
|
||||
|
||||
The `__orig_bases__` property can be used to determine the resolved
|
||||
type variable for a given base class.
|
||||
"""
|
||||
|
||||
__orig_bases__: tuple[_GenericAlias]
|
||||
|
||||
|
||||
class _GenericAlias(Protocol):
|
||||
__origin__: type[object]
|
||||
|
||||
|
||||
class HttpxSendArgs(TypedDict, total=False):
|
||||
auth: httpx.Auth
|
||||
|
||||
|
||||
# for user input files
|
||||
if TYPE_CHECKING:
|
||||
Base64FileInput = Union[IO[bytes], PathLike[str]]
|
||||
FileContent = Union[IO[bytes], bytes, PathLike[str]]
|
||||
else:
|
||||
Base64FileInput = Union[IO[bytes], PathLike]
|
||||
FileContent = Union[IO[bytes], bytes, PathLike]
|
||||
|
||||
FileTypes = Union[
|
||||
# file (or bytes)
|
||||
FileContent,
|
||||
# (filename, file (or bytes))
|
||||
tuple[Optional[str], FileContent],
|
||||
# (filename, file (or bytes), content_type)
|
||||
tuple[Optional[str], FileContent, Optional[str]],
|
||||
# (filename, file (or bytes), content_type, headers)
|
||||
tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
|
||||
]
|
||||
RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]]
|
||||
|
||||
# duplicate of the above but without our custom file support
|
||||
HttpxFileContent = Union[bytes, IO[bytes]]
|
||||
HttpxFileTypes = Union[
|
||||
# file (or bytes)
|
||||
HttpxFileContent,
|
||||
# (filename, file (or bytes))
|
||||
tuple[Optional[str], HttpxFileContent],
|
||||
# (filename, file (or bytes), content_type)
|
||||
tuple[Optional[str], HttpxFileContent, Optional[str]],
|
||||
# (filename, file (or bytes), content_type, headers)
|
||||
tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
|
||||
]
|
||||
|
||||
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]]
|
|
@ -1,12 +0,0 @@
|
|||
import httpx
|
||||
|
||||
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
|
||||
# 通过 `Timeout` 控制接口`connect` 和 `read` 超时时间,默认为`timeout=300.0, connect=8.0`
|
||||
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
|
||||
# 通过 `retry` 参数控制重试次数,默认为3次
|
||||
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
|
||||
# 通过 `Limits` 控制最大连接数和保持连接数,默认为`max_connections=50, max_keepalive_connections=10`
|
||||
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
|
||||
|
||||
INITIAL_RETRY_DELAY = 0.5
|
||||
MAX_RETRY_DELAY = 8.0
|
|
@ -1,86 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
__all__ = [
|
||||
"ZhipuAIError",
|
||||
"APIStatusError",
|
||||
"APIRequestFailedError",
|
||||
"APIAuthenticationError",
|
||||
"APIReachLimitError",
|
||||
"APIInternalError",
|
||||
"APIServerFlowExceedError",
|
||||
"APIResponseError",
|
||||
"APIResponseValidationError",
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
]
|
||||
|
||||
|
||||
class ZhipuAIError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class APIStatusError(ZhipuAIError):
|
||||
response: httpx.Response
|
||||
status_code: int
|
||||
|
||||
def __init__(self, message: str, *, response: httpx.Response) -> None:
|
||||
super().__init__(message)
|
||||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
|
||||
class APIRequestFailedError(APIStatusError): ...
|
||||
|
||||
|
||||
class APIAuthenticationError(APIStatusError): ...
|
||||
|
||||
|
||||
class APIReachLimitError(APIStatusError): ...
|
||||
|
||||
|
||||
class APIInternalError(APIStatusError): ...
|
||||
|
||||
|
||||
class APIServerFlowExceedError(APIStatusError): ...
|
||||
|
||||
|
||||
class APIResponseError(ZhipuAIError):
|
||||
message: str
|
||||
request: httpx.Request
|
||||
json_data: object
|
||||
|
||||
def __init__(self, message: str, request: httpx.Request, json_data: object):
|
||||
self.message = message
|
||||
self.request = request
|
||||
self.json_data = json_data
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class APIResponseValidationError(APIResponseError):
|
||||
status_code: int
|
||||
response: httpx.Response
|
||||
|
||||
def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None:
|
||||
super().__init__(
|
||||
message=message or "Data returned by API invalid for expected schema.",
|
||||
request=response.request,
|
||||
json_data=json_data,
|
||||
)
|
||||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
|
||||
class APIConnectionError(APIResponseError):
|
||||
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
|
||||
super().__init__(message, request, json_data=None)
|
||||
|
||||
|
||||
class APITimeoutError(APIConnectionError):
|
||||
def __init__(self, request: httpx.Request) -> None:
|
||||
super().__init__(message="Request timed out.", request=request)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user