Merge branch 'main' into feat/new-login

* main: (35 commits)
  fix https://github.com/langgenius/dify/issues/9409 (#9433)
  update dataset clean rule (#9426)
  add clean 7 days datasets (#9424)
  fix: resolve overlap issue with API Extension selector and modal (#9407)
  refactor: update the default values of top-k parameter in vdb to be consistent (#9367)
  fix: incorrect webapp image displayed (#9401)
  Fix/economical knowledge retrieval (#9396)
  feat: add timezone conversion for time tool (#9393)
  fix: Deprecated gemma2-9b model in Fireworks AI Provider (#9373)
  feat: storybook (#9324)
  fix: use gpt-4o-mini for validating credentials (#9387)
  feat: Enable baiduvector integration test (#9369)
  fix: remove the stream option of zhipu and gemini (#9319)
  fix: add missing vikingdb param in docker .env.example (#9334)
  feat: add minimax abab6.5t support (#9365)
  fix: (#9336 followup) skip poetry preparation in style workflow when no change in api folder (#9362)
  feat: add glm-4-flashx, deprecated chatglm_turbo (#9357)
  fix: Azure OpenAI o1 max_completion_token and get_num_token_from_messages error (#9326)
  fix: In the output, the order of 'ta' is sometimes reversed as 'at'. #8015 (#8791)
  refactor: Add an enumeration type and use the factory pattern to obtain the corresponding class (#9356)
  ...
Joe 2024-10-17 11:33:33 +08:00
commit abe6a1bb99
250 changed files with 7602 additions and 9658 deletions

View File

@@ -27,18 +27,17 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
cache-dependency-path: |
api/pyproject.toml
api/poetry.lock
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Check Poetry lockfile
run: |
poetry check -C api --lock

View File

@@ -23,18 +23,17 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
cache-dependency-path: |
api/pyproject.toml
api/poetry.lock
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Install dependencies
run: poetry install -C api

View File

@@ -24,15 +24,16 @@ jobs:
with:
files: api/**
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python
uses: actions/setup-python@v5
if: steps.changed-files.outputs.any_changed == 'true'
with:
python-version: '3.10'
- name: Install Poetry
if: steps.changed-files.outputs.any_changed == 'true'
uses: abatilo/actions-poetry@v3
- name: Python dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry install -C api --only lint

View File

@@ -85,3 +85,4 @@
cd ../
poetry run -C api bash dev/pytest/pytest_all_tests.sh
```

View File

@@ -531,11 +531,16 @@ class DataSetConfig(BaseSettings):
Configuration for dataset management
"""
CLEAN_DAY_SETTING: PositiveInt = Field(
description="Interval in days for dataset cleanup operations",
PLAN_SANDBOX_CLEAN_DAY_SETTING: PositiveInt = Field(
description="Interval in days for dataset cleanup operations - plan: sandbox",
default=30,
)
PLAN_PRO_CLEAN_DAY_SETTING: PositiveInt = Field(
description="Interval in days for dataset cleanup operations - plan: pro and team",
default=7,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description="Enable or disable dataset operator functionality",
default=False,
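
The hunk above replaces the single CLEAN_DAY_SETTING with plan-specific retention windows. A minimal sketch of how a cleanup job might select the window per plan (the `dify_config` import path and the plan names are assumptions, not taken from this diff):

```python
from configs import dify_config  # assumed import path

def clean_days_for_plan(plan: str) -> int:
    # Sandbox datasets default to a 30-day window; pro/team default to 7 days.
    if plan == "sandbox":
        return dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING
    return dify_config.PLAN_PRO_CLEAN_DAY_SETTING
```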

View File

@@ -14,7 +14,7 @@ class OracleConfig(BaseSettings):
default=None,
)
ORACLE_PORT: Optional[PositiveInt] = Field(
ORACLE_PORT: PositiveInt = Field(
description="Port number on which the Oracle database server is listening (default is 1521)",
default=1521,
)
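
This hunk, and the two PGVector hunks below, drop `Optional[...]` from port fields that carry a concrete default: the value can never resolve to None, so the annotation was misleading. A self-contained sketch of the pattern (the `ExampleDBConfig` class is hypothetical):

```python
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings

class ExampleDBConfig(BaseSettings):  # hypothetical, for illustration only
    PORT: PositiveInt = Field(
        description="Port number the server listens on",
        default=1521,
    )

print(ExampleDBConfig().PORT)           # -> 1521 (never None)
print(ExampleDBConfig(PORT=1522).PORT)  # -> 1522; PORT=0 would fail validation
```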

View File

@@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings):
default=None,
)
PGVECTOR_PORT: Optional[PositiveInt] = Field(
PGVECTOR_PORT: PositiveInt = Field(
description="Port number on which the PostgreSQL server is listening (default is 5433)",
default=5433,
)

View File

@@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings):
default=None,
)
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
PGVECTO_RS_PORT: PositiveInt = Field(
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
default=5431,
)

View File

@@ -11,27 +11,39 @@ class VikingDBConfig(BaseModel):
"""
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
description="The Access Key provided by Volcengine VikingDB for API authentication."
"Refer to the following documentation for details on obtaining credentials:"
"https://www.volcengine.com/docs/6291/65568",
default=None,
)
VIKINGDB_SECRET_KEY: Optional[str] = Field(
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
description="The Secret Key provided by Volcengine VikingDB for API authentication.",
default=None,
)
VIKINGDB_REGION: Optional[str] = Field(
default="cn-shanghai",
VIKINGDB_REGION: str = Field(
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
default="cn-shanghai",
)
VIKINGDB_HOST: Optional[str] = Field(
default="api-vikingdb.mlp.cn-shanghai.volces.com",
VIKINGDB_HOST: str = Field(
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
'api-vikingdb.mlp.cn-shanghai.volces.com')",
default="api-vikingdb.mlp.cn-shanghai.volces.com",
)
VIKINGDB_SCHEME: Optional[str] = Field(
default="http",
VIKINGDB_SCHEME: str = Field(
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
default="http",
)
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
default=30, description="The connection timeout of the Volcengine VikingDB service."
VIKINGDB_CONNECTION_TIMEOUT: int = Field(
description="The connection timeout of the Volcengine VikingDB service.",
default=30,
)
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
default=30, description="The socket timeout of the Volcengine VikingDB service."
VIKINGDB_SOCKET_TIMEOUT: int = Field(
description="The socket timeout of the Volcengine VikingDB service.",
default=30,
)

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.9.1",
default="0.9.2",
)
COMMIT_SHA: str = Field(

View File

@@ -1,88 +1,24 @@
import logging
from flask_restful import Resource
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
class HitTestingApi(Resource):
class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
self.hit_testing_args_check(args)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
return self.perform_hit_testing(dataset, args)
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

View File

@@ -0,0 +1,85 @@
import logging
from flask_login import current_user
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return dataset
@staticmethod
def hit_testing_args_check(args):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
return parser.parse_args()
@staticmethod
def perform_hit_testing(dataset, args):
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))

View File

@@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)
from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, segment
from .dataset import dataset, document, hit_testing, segment

View File

@@ -4,7 +4,6 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from constants import UUID_NIL
from controllers.service_api import api
from controllers.service_api.app.error import (
AppUnavailableError,
@@ -108,7 +107,6 @@ class ChatApi(Resource):
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json")
args = parser.parse_args()

View File

@@ -0,0 +1,17 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

View File

@@ -62,6 +62,8 @@ class CotAgentOutputParser:
thought_str = "thought:"
thought_idx = 0
last_character = ""
for response in llm_response:
if response.delta.usage:
usage_dict["usage"] = response.delta.usage
@@ -74,35 +76,38 @@
while index < len(response):
steps = 1
delta = response[index : index + steps]
last_character = response[index - 1] if index > 0 else ""
yield_delta = False
if delta == "`":
last_character = delta
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
last_character = delta
yield code_block_cache
code_block_cache = ""
else:
last_character = delta
code_block_cache += delta
code_block_delimiter_count = 0
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in {"\n", " ", ""}:
yield_delta = True
else:
last_character = delta
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_idx = 0
index += steps
yield delta
continue
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_idx = 0
index += steps
continue
elif delta.lower() == action_str[action_idx] and action_idx > 0:
last_character = delta
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
@@ -112,24 +117,25 @@
continue
else:
if action_cache:
last_character = delta
yield action_cache
action_cache = ""
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in {"\n", " ", ""}:
yield_delta = True
else:
last_character = delta
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_idx = 0
index += steps
yield delta
continue
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_idx = 0
index += steps
continue
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
last_character = delta
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
@@ -139,12 +145,20 @@
continue
else:
if thought_cache:
last_character = delta
yield thought_cache
thought_cache = ""
thought_idx = 0
if yield_delta:
index += steps
last_character = delta
yield delta
continue
if code_block_delimiter_count == 3:
if in_code_block:
last_character = delta
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ""
@@ -156,8 +170,10 @@
if delta == "{":
json_quote_count += 1
in_json = True
last_character = delta
json_cache += delta
elif delta == "}":
last_character = delta
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
@@ -168,16 +184,19 @@
continue
else:
if in_json:
last_character = delta
json_cache += delta
if got_json:
got_json = False
last_character = delta
yield parse_action(json_cache)
json_cache = ""
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
last_character = delta
yield delta.replace("`", "")
index += steps
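
The hunks above implement the fix for #8015 from the changelog ("the order of 'ta' is sometimes reversed as 'at'"): the parser now tracks `last_character`, and a delta that matches the first character of "action:"/"thought:" only starts a keyword match after a newline, a space, or the start of the stream; otherwise `yield_delta` emits it immediately instead of caching it out of order. A hedged one-function sketch of the new guard, reusing the state names from the diff:

```python
def starts_keyword_match(delta: str, keyword: str, last_character: str) -> bool:
    # True only when the delta can begin a fresh "action:"/"thought:" token;
    # mid-word matches are yielded straight through, preserving character order.
    return delta.lower() == keyword[0] and last_character in {"\n", " ", ""}
```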

View File

@@ -10,6 +10,7 @@ from flask import Flask, current_app
from pydantic import ValidationError
import contexts
from constants import UUID_NIL
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -122,7 +123,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View File

@@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
@@ -127,7 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View File

@@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@@ -128,7 +129,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View File

@@ -2,8 +2,9 @@ from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileVar
@@ -116,13 +117,36 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
model_config = ConfigDict(protected_namespaces=())
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
class ConversationAppGenerateEntity(AppGenerateEntity):
"""
Base entity for conversation-based app generation.
"""
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
"For service API, we need to ensure its forward compatibility, "
"so passing in the parent_message_id as request arg is not supported for now. "
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
),
)
@field_validator("parent_message_id")
@classmethod
def validate_parent_message_id(cls, v, info: ValidationInfo):
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
raise ValueError("parent_message_id should be UUID_NIL for service API")
return v
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
"""
Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None
pass
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
@@ -133,16 +157,15 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
pass
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
"""
Agent Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None
pass
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
Advanced Chat Application Generate Entity.
"""
@@ -150,8 +173,6 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
# app config
app_config: WorkflowUIBasedAppConfig
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None
workflow_run_id: Optional[str] = None
query: str
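
The refactor above hoists `conversation_id`/`parent_message_id` into a shared `ConversationAppGenerateEntity` and enforces the service-API rule with a `field_validator`. A self-contained demo of the validator's behavior (the `invoke_from` field is simplified to a string here, and the `UUID_NIL` value is assumed to be the nil UUID):

```python
from typing import Optional
from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator

UUID_NIL = "00000000-0000-0000-0000-000000000000"  # assumed value of constants.UUID_NIL

class Demo(BaseModel):
    invoke_from: str  # declared first so info.data already contains it
    parent_message_id: Optional[str] = None

    @field_validator("parent_message_id")
    @classmethod
    def validate_parent_message_id(cls, v, info: ValidationInfo):
        if info.data.get("invoke_from") == "service-api" and v != UUID_NIL:
            raise ValueError("parent_message_id should be UUID_NIL for service API")
        return v

Demo(invoke_from="web-app", parent_message_id="some-real-uuid")  # accepted
Demo(invoke_from="service-api", parent_message_id=UUID_NIL)      # accepted
try:
    Demo(invoke_from="service-api", parent_message_id="some-real-uuid")
except ValidationError:
    pass  # rejected, which is exactly what the validator is for
```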

View File

@@ -1098,6 +1098,14 @@ LLM_BASE_MODELS = [
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
@@ -1135,6 +1143,14 @@ LLM_BASE_MODELS = [
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),

View File

@@ -119,7 +119,15 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if model.startswith("o1"):
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=1,
max_completion_tokens=20,
stream=False,
)
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
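
The hunk above is truncated in this view. A hedged reconstruction of the branch it adds: o1-series chat models reject `max_tokens` and non-default temperature, so credential validation pings them with `max_completion_tokens` and `temperature=1` (the `else` body mirroring the pre-existing chat path is an assumption):

```python
from openai import AzureOpenAI

def ping_for_validation(client: AzureOpenAI, model: str) -> None:
    if model.startswith("o1"):
        # o1 models accept max_completion_tokens only, with temperature fixed at 1
        client.chat.completions.create(
            messages=[{"role": "user", "content": "ping"}],
            model=model,
            temperature=1,
            max_completion_tokens=20,
            stream=False,
        )
    else:
        # regular chat models keep the original max_tokens-based ping (assumed)
        client.chat.completions.create(
            messages=[{"role": "user", "content": "ping"}],
            model=model,
            max_tokens=20,
        )
```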

View File

@@ -52,6 +52,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.00025'
output: '0.00125'

View File

@@ -51,6 +51,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'

View File

@@ -51,6 +51,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'

View File

@@ -52,6 +52,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.00025'
output: '0.00125'

View File

@@ -52,6 +52,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.015'
output: '0.075'

View File

@@ -51,6 +51,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'

View File

@@ -51,6 +51,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'

View File

@@ -18,6 +18,7 @@ supported_model_types:
- text-embedding
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: fireworks_api_key
@@ -28,3 +29,75 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model URL
zh_Hans: 模型URL
placeholder:
en_US: Enter your Model URL
zh_Hans: 输入模型URL
credential_form_schemas:
- variable: model_label_zh_Hanns
label:
zh_Hans: 模型中文名称
en_US: The zh_Hans of Model
required: true
type: text-input
placeholder:
zh_Hans: 在此输入您的模型中文名称
en_US: Enter your zh_Hans of Model
- variable: model_label_en_US
label:
zh_Hans: 模型英文名称
en_US: The en_US of Model
required: true
type: text-input
placeholder:
zh_Hans: 在此输入您的模型英文名称
en_US: Enter your en_US of Model
- variable: fireworks_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '4096'
type: text-input
show_on:
- variable: __model_type
value: llm
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- value: function_call
label:
en_US: Support
zh_Hans: 支持
show_on:
- variable: __model_type
value: llm

View File

@@ -43,3 +43,4 @@ pricing:
output: '0.2'
unit: '0.000001'
currency: USD
deprecated: true

View File

@@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
from openai.types.chat.chat_completion_message import FunctionCall
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
@@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
@@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
return AIModelEntity(
model=model,
label=I18nObject(
en_US=credentials.get("model_label_en_US", model),
zh_Hans=credentials.get("model_label_zh_Hanns", model),
),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get("function_calling_type") == "function_call"
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name="temperature",
use_template="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="max_tokens",
use_template="max_tokens",
default=512,
min=1,
max=int(credentials.get("max_tokens", 4096)),
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
type=ParameterType.INT,
),
ParameterRule(
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="top_k",
use_template="top_k",
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
type=ParameterType.FLOAT,
),
],
)
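
The new `get_customizable_model_schema` is driven entirely by the credential form defined in the provider YAML above. A hedged example of the credential dict it expects (key names copied from the diff — note `model_label_zh_Hanns` is spelled this way in both the YAML and the Python — with illustrative values):

```python
credentials = {
    "model_label_en_US": "My Fireworks Model",   # display label (en_US)
    "model_label_zh_Hanns": "自定义模型",          # display label (zh_Hans)
    "context_size": "32768",                     # parsed with int(...)
    "max_tokens": "4096",                        # caps the max_tokens rule
    "function_calling_type": "function_call",    # enables the TOOL_CALL features
}
# schema = FireworksLargeLanguageModel().get_customizable_model_schema(
#     model="accounts/fireworks/models/my-model", credentials=credentials)
```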

View File

@@ -0,0 +1,46 @@
model: accounts/fireworks/models/qwen2p5-72b-instruct
label:
zh_Hans: Qwen2.5 72B Instruct
en_US: Qwen2.5 72B Instruct
model_type: llm
features:
- agent-thought
- tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
- name: max_tokens
use_template: max_tokens
- name: context_length_exceeded_behavior
default: None
label:
zh_Hans: 上下文长度超出行为
en_US: Context Length Exceeded Behavior
help:
zh_Hans: 上下文长度超出行为
en_US: Context Length Exceeded Behavior
type: string
options:
- None
- truncate
- error
- name: response_format
use_template: response_format
pricing:
input: '0.9'
output: '0.9'
unit: '0.000001'
currency: USD

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -32,15 +32,6 @@ parameter_rules:
max: 8192
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -27,15 +27,6 @@ parameter_rules:
default: 4096
min: 1
max: 4096
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -31,15 +31,6 @@ parameter_rules:
max: 2048
- name: response_format
use_template: response_format
- name: stream
label:
zh_Hans: 流式输出
en_US: Stream
type: boolean
help:
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
default: false
pricing:
input: '0.00'
output: '0.00'

View File

@@ -0,0 +1,44 @@
model: abab6.5t-chat
label:
en_US: Abab6.5t-Chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.01
max: 1
default: 0.9
- name: top_p
use_template: top_p
min: 0.01
max: 1
default: 0.95
- name: max_tokens
use_template: max_tokens
required: true
default: 3072
min: 1
max: 8192
- name: mask_sensitive_info
type: boolean
default: true
label:
zh_Hans: 隐私保护
en_US: Moderate
help:
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码目前包括但不限于邮箱、域名、链接、证件号、家庭住址等默认true即开启打码
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '0.005'
output: '0.005'
unit: '0.001'
currency: RMB

View File

@@ -61,7 +61,8 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
url = f"{self.api_base}?GroupId={group_id}"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
data = {"model": "embo-01", "texts": texts, "type": "db"}
embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query"
data = {"model": "embo-01", "texts": texts, "type": embedding_type}
try:
response = post(url, headers=headers, data=dumps(data))
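
The fix above stops hard-coding `type: "db"` for every embedding request and derives it from the input type instead, so queries are embedded as queries. A minimal self-contained sketch (`EmbeddingInputType` here is a simplified stand-in for the real enum):

```python
from enum import Enum

class EmbeddingInputType(Enum):  # simplified stand-in
    DOCUMENT = "document"
    QUERY = "query"

def build_payload(texts: list[str], input_type: EmbeddingInputType) -> dict:
    # embo-01 distinguishes stored-document ("db") embeddings from query embeddings
    embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query"
    return {"model": "embo-01", "texts": texts, "type": embedding_type}

assert build_payload(["hello"], EmbeddingInputType.QUERY)["type"] == "query"
```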

View File

@@ -19,9 +19,9 @@ class OpenAIProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `gpt-3.5-turbo` model for validate,
# Use `gpt-4o-mini` model for validate,
# no matter what model you pass in, text completion model or chat model
model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials)
model_instance.validate_credentials(model="gpt-4o-mini", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

View File

@@ -28,15 +28,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: return_type
label:
zh_Hans: 回复类型
@@ -49,3 +40,4 @@ parameter_rules:
options:
- text
- json_string
deprecated: true

View File

@@ -32,15 +32,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@@ -32,15 +32,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@@ -32,15 +32,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@@ -32,15 +32,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@@ -0,0 +1,53 @@
model: glm-4-flashx
label:
en_US: glm-4-flashx
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: top_p
use_template: top_p
default: 0.7
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: do_sample
label:
zh_Hans: 采样策略
en_US: Sampling strategy
type: boolean
help:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 4095
- name: web_search
type: boolean
label:
zh_Hans: 联网搜索
en_US: Web Search
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
pricing:
input: '0'
output: '0'
unit: '0.001'
currency: RMB

View File

@@ -32,15 +32,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@@ -32,15 +32,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@@ -35,15 +35,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@@ -32,15 +32,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条data[DONE]消息。注意在模型流式输出生成内容的过程中我们会分批对模型生成内容进行检测当检测到违法及不良信息时API会返回错误码1301。开发者识别到错误码1301应及时采取清屏、重启对话等措施删除生成内容并确保不将含有违法及不良信息的内容传递给模型继续生成避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data[DONE] message will be sent at the end of the Event Stream.NoteDuring the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@ -30,15 +30,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条 data: [DONE] 消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时API 会返回错误码 1301。开发者识别到错误码 1301 时,应及时采取清屏、重启对话等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data: [DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@ -30,15 +30,6 @@ parameter_rules:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: stream
label:
zh_Hans: 流处理
en_US: Event Stream
type: boolean
help:
zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true模型将通过标准 Event Stream 逐块返回模型生成内容。Event Stream 结束时会返回一条 data: [DONE] 消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时API 会返回错误码 1301。开发者识别到错误码 1301 时,应及时采取清屏、重启对话等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data: [DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
default: false
- name: max_tokens
use_template: max_tokens
default: 1024

View File

@ -1,6 +1,10 @@
from collections.abc import Generator
from typing import Optional, Union
from zhipuai import ZhipuAI
from zhipuai.types.chat.chat_completion import Completion
from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -16,9 +20,6 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
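
This hunk swaps the vendored SDK imports for the pip-installed `zhipuai` package, which exposes the same top-level names, so call sites stay unchanged. A sketch with placeholder credentials:

```python
from zhipuai import ZhipuAI
from zhipuai.types.chat.chat_completion import Completion

client = ZhipuAI(api_key="your-api-key")  # placeholder key
result = client.chat.completions.create(
    model="glm-4",  # illustrative model name
    messages=[{"role": "user", "content": "Hi"}],
    stream=False,
)
assert isinstance(result, Completion)  # non-streaming calls return a Completion
```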

View File

@ -1,13 +1,14 @@
import time
from typing import Optional
from zhipuai import ZhipuAI
from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
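
The embedding model class now relies on the published `ZhipuAI` client as well. A minimal embeddings call under that assumption; the model name and response field access are illustrative:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
result = client.embeddings.create(model="embedding-2", input="hello world")
# assumes the response exposes a data list of embedding objects
print(len(result.data[0].embedding))
```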

View File

@ -1,15 +0,0 @@
from .__version__ import __version__
from ._client import ZhipuAI
from .core import (
APIAuthenticationError,
APIConnectionError,
APIInternalError,
APIReachLimitError,
APIRequestFailedError,
APIResponseError,
APIResponseValidationError,
APIServerFlowExceedError,
APIStatusError,
APITimeoutError,
ZhipuAIError,
)

View File

@ -1 +0,0 @@
__version__ = "v2.1.0"

View File

@ -1,82 +0,0 @@
from __future__ import annotations
import os
from collections.abc import Mapping
from typing import Union
import httpx
from httpx import Timeout
from typing_extensions import override
from . import api_resource
from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token
class ZhipuAI(HttpClient):
chat: api_resource.chat.Chat
api_key: str
_disable_token_cache: bool = True
def __init__(
self,
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
disable_token_cache: bool = True,
_strict_response_validation: bool = False,
) -> None:
if api_key is None:
api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
raise ZhipuAIError("No api_key provided; pass it as an argument or set the ZHIPUAI_API_KEY environment variable")
self.api_key = api_key
self._disable_token_cache = disable_token_cache
if base_url is None:
base_url = os.environ.get("ZHIPUAI_BASE_URL")
if base_url is None:
base_url = "https://open.bigmodel.cn/api/paas/v4"
from .__version__ import __version__
super().__init__(
version=__version__,
base_url=base_url,
max_retries=max_retries,
timeout=timeout,
custom_httpx_client=http_client,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self.chat = api_resource.chat.Chat(self)
self.images = api_resource.images.Images(self)
self.embeddings = api_resource.embeddings.Embeddings(self)
self.files = api_resource.files.Files(self)
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
self.batches = api_resource.Batches(self)
self.knowledge = api_resource.Knowledge(self)
self.tools = api_resource.Tools(self)
self.videos = api_resource.Videos(self)
self.assistant = api_resource.Assistant(self)
@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if self._disable_token_cache:
return {"Authorization": f"Bearer {api_key}"}
else:
return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}
def __del__(self) -> None:
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):
# if the '__init__' method raised an error, self would not have client attr
return
if self._has_custom_http_client:
return
self.close()
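
The deleted client above resolved its configuration from arguments first and environment variables second, defaulting the base URL to the open.bigmodel.cn endpoint. A sketch of constructing the published client the same way; reading the environment explicitly here is illustrative, since the client already falls back to it when the arguments are omitted:

```python
import os
from zhipuai import ZhipuAI

client = ZhipuAI(
    api_key=os.environ.get("ZHIPUAI_API_KEY"),
    base_url=os.environ.get("ZHIPUAI_BASE_URL", "https://open.bigmodel.cn/api/paas/v4"),
)
```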

View File

@ -1,34 +0,0 @@
from .assistant import (
Assistant,
)
from .batches import Batches
from .chat import (
AsyncCompletions,
Chat,
Completions,
)
from .embeddings import Embeddings
from .files import Files, FilesWithRawResponse
from .fine_tuning import FineTuning
from .images import Images
from .knowledge import Knowledge
from .tools import Tools
from .videos import (
Videos,
)
__all__ = [
"Videos",
"AsyncCompletions",
"Chat",
"Completions",
"Images",
"Embeddings",
"Files",
"FilesWithRawResponse",
"FineTuning",
"Batches",
"Knowledge",
"Tools",
"Assistant",
]

View File

@ -1,3 +0,0 @@
from .assistant import Assistant
__all__ = ["Assistant"]

View File

@ -1,122 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.assistant import AssistantCompletion
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
from ...types.assistant.assistant_support_resp import AssistantSupportResp
if TYPE_CHECKING:
from ..._client import ZhipuAI
from ...types.assistant import assistant_conversation_params, assistant_create_params
__all__ = ["Assistant"]
class Assistant(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def conversation(
self,
assistant_id: str,
model: str,
messages: list[assistant_create_params.ConversationMessage],
*,
stream: bool = True,
conversation_id: Optional[str] = None,
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
metadata: dict | None = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> StreamResponse[AssistantCompletion]:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"model": model,
"messages": messages,
"stream": stream,
"conversation_id": conversation_id,
"attachments": attachments,
"metadata": metadata,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant",
body=maybe_transform(body, assistant_create_params.AssistantParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantCompletion,
stream=stream or True,
stream_cls=StreamResponse[AssistantCompletion],
)
def query_support(
self,
*,
assistant_id_list: Optional[list[str]] = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AssistantSupportResp:
body = deepcopy_minimal(
{
"assistant_id_list": assistant_id_list,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/list",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantSupportResp,
)
def query_conversation_usage(
self,
assistant_id: str,
page: int = 1,
page_size: int = 10,
*,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ConversationUsageListResp:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"page": page,
"page_size": page_size,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/conversation/list",
body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=ConversationUsageListResp,
)

View File

@ -1,146 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Optional
import httpx
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
from ..core.pagination import SyncCursorPage
from ..types import batch_create_params, batch_list_params
from ..types.batch import Batch
if TYPE_CHECKING:
from .._client import ZhipuAI
class Batches(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
completion_window: str | None = None,
endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
input_file_id: str,
metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
auto_delete_input_file: bool = True,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
return self._post(
"/batches",
body=maybe_transform(
{
"completion_window": completion_window,
"endpoint": endpoint,
"input_file_id": input_file_id,
"metadata": metadata,
"auto_delete_input_file": auto_delete_input_file,
},
batch_create_params.BatchCreateParams,
),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
def retrieve(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Retrieves a batch.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._get(
f"/batches/{batch_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> SyncCursorPage[Batch]:
"""List your organization's batches.
Args:
after: A cursor for use in pagination.
`after` is an object ID that defines your place
in the list. For instance, if you make a list request and receive 100 objects,
ending with obj_foo, your subsequent call can include after=obj_foo in order to
fetch the next page of the list.
limit: A limit on the number of objects to be returned. Limit can range between 1 and
100, and the default is 20.
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get_api_list(
"/batches",
page=SyncCursorPage[Batch],
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"limit": limit,
},
batch_list_params.BatchListParams,
),
),
model=Batch,
)
def cancel(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Cancels an in-progress batch.
Args:
batch_id: The ID of the batch to cancel.
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._post(
f"/batches/{batch_id}/cancel",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
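
The `list()` docstring above describes cursor pagination: pass the last object ID of one page as `after` to fetch the next. A sketch under the assumption that the cursor page exposes a `data` list:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
page = client.batches.list(limit=20)
for batch in page.data:
    print(batch.id)
# fetch the next page by passing the last id as the cursor
if page.data:
    next_page = client.batches.list(limit=20, after=page.data[-1].id)
```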

View File

@ -1,5 +0,0 @@
from .async_completions import AsyncCompletions
from .chat import Chat
from .completions import Completions
__all__ = ["AsyncCompletions", "Chat", "Completions"]

View File

@ -1,115 +0,0 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
drop_prefix_image_data,
make_request_options,
maybe_transform,
)
from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
class AsyncCompletions(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], list[list[int]], None],
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncTaskStatus:
_cast_type = AsyncTaskStatus
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if temperature is not None and temperature != NOT_GIVEN:
if temperature <= 0:
do_sample = False
temperature = 0.01
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间do_sample重写为:false参数top_p temperature不生效") # noqa: E501
if temperature >= 1:
temperature = 0.99
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
if top_p is not None and top_p != NOT_GIVEN:
if top_p >= 1:
top_p = 0.99
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
if top_p <= 0:
top_p = 0.01
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if isinstance(messages, list):
for item in messages:
if item.get("content"):
item["content"] = drop_prefix_image_data(item["content"])
body = {
"model": model,
"request_id": request_id,
"user_id": user_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"tools": tools,
"tool_choice": tool_choice,
"meta": meta,
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
}
return self._post(
"/async/chat/completions",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
stream=False,
)
def retrieve_completion_result(
self,
id: str,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Union[AsyncCompletion, AsyncTaskStatus]:
_cast_type = Union[AsyncCompletion, AsyncTaskStatus]
return self._get(
path=f"/async-result/{id}",
cast_type=_cast_type,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
)
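
The async flow above is two-step: `create()` posts to /async/chat/completions and returns an `AsyncTaskStatus`, and `retrieve_completion_result()` polls /async-result/{id} until the `AsyncCompletion` is ready. A polling sketch; the `id` and `task_status` field names are assumptions about the response types:

```python
import time
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
task = client.chat.asyncCompletions.create(
    model="glm-4",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
)
result = client.chat.asyncCompletions.retrieve_completion_result(id=task.id)
while getattr(result, "task_status", "") == "PROCESSING":  # assumed status value
    time.sleep(1)
    result = client.chat.asyncCompletions.retrieve_completion_result(id=task.id)
```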

View File

@ -1,18 +0,0 @@
from typing import TYPE_CHECKING
from ...core import BaseAPI, cached_property
from .async_completions import AsyncCompletions
from .completions import Completions
if TYPE_CHECKING:
pass
class Chat(BaseAPI):
@cached_property
def completions(self) -> Completions:
return Completions(self._client)
@cached_property
def asyncCompletions(self) -> AsyncCompletions: # noqa: N802
return AsyncCompletions(self._client)

View File

@ -1,108 +0,0 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
drop_prefix_image_data,
make_request_options,
maybe_transform,
)
from ...types.chat.chat_completion import Completion
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
class Completions(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], object, None],
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | StreamResponse[ChatCompletionChunk]:
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if temperature is not None and temperature != NOT_GIVEN:
if temperature <= 0:
do_sample = False
temperature = 0.01
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间do_sample重写为:false参数top_p temperature不生效") # noqa: E501
if temperature >= 1:
temperature = 0.99
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
if top_p is not None and top_p != NOT_GIVEN:
if top_p >= 1:
top_p = 0.99
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
if top_p <= 0:
top_p = 0.01
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
if isinstance(messages, list):
for item in messages:
if item.get("content"):
item["content"] = drop_prefix_image_data(item["content"])
body = deepcopy_minimal(
{
"model": model,
"request_id": request_id,
"user_id": user_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"stream": stream,
"tools": tools,
"tool_choice": tool_choice,
"meta": meta,
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
}
)
return self._post(
"/chat/completions",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Completion,
stream=stream or False,
stream_cls=StreamResponse[ChatCompletionChunk],
)

View File

@ -1,50 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union
import httpx
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
from ..types.embeddings import EmbeddingsResponded
if TYPE_CHECKING:
from .._client import ZhipuAI
class Embeddings(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
input: Union[str, list[str], list[int], list[list[int]]],
model: Union[str],
dimensions: Union[int] | NotGiven = NOT_GIVEN,
encoding_format: str | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> EmbeddingsResponded:
_cast_type = EmbeddingsResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/embeddings",
body={
"input": input,
"model": model,
"dimensions": dimensions,
"encoding_format": encoding_format,
"user": user,
"request_id": request_id,
"sensitive_word_check": sensitive_word_check,
},
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
stream=False,
)

View File

@ -1,194 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, Optional, cast
import httpx
from ..core import (
NOT_GIVEN,
BaseAPI,
Body,
FileTypes,
Headers,
NotGiven,
_legacy_binary_response,
_legacy_response,
deepcopy_minimal,
extract_files,
make_request_options,
maybe_transform,
)
from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params
if TYPE_CHECKING:
from .._client import ZhipuAI
__all__ = ["Files", "FilesWithRawResponse"]
class Files(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
file: Optional[FileTypes] = None,
upload_detail: Optional[list[UploadDetail]] = None,
purpose: Literal["fine-tune", "retrieval", "batch"],
knowledge_id: Optional[str] = None,
sentence_size: Optional[int] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileObject:
if not file and not upload_detail:
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
body = deepcopy_minimal(
{
"file": file,
"upload_detail": upload_detail,
"purpose": purpose,
"knowledge_id": knowledge_id,
"sentence_size": sentence_size,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
if files:
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body=maybe_transform(body, file_create_params.FileCreateParams),
files=files,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FileObject,
)
# def retrieve(
# self,
# file_id: str,
# *,
# extra_headers: Headers | None = None,
# extra_body: Body | None = None,
# timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
# ) -> FileObject:
# """
# Returns information about a specific file.
#
# Args:
# file_id: The ID of the file to retrieve information about
# extra_headers: Send extra headers
#
# extra_body: Add additional JSON properties to the request
#
# timeout: Override the client-level default timeout for this request, in seconds
# """
# if not file_id:
# raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
# return self._get(
# f"/files/{file_id}",
# options=make_request_options(
# extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
# ),
# cast_type=FileObject,
# )
def list(
self,
*,
purpose: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
after: str | NotGiven = NOT_GIVEN,
order: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFileObject:
return self._get(
"/files",
cast_type=ListOfFileObject,
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"purpose": purpose,
"limit": limit,
"after": after,
"order": order,
},
),
)
def delete(
self,
file_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileDeleted:
"""
Delete a file.
Args:
file_id: The ID of the file to delete
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not file_id:
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
return self._delete(
f"/files/{file_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FileDeleted,
)
def content(
self,
file_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> _legacy_response.HttpxBinaryResponseContent:
"""
Returns the contents of the specified file.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not file_id:
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
return self._get(
f"/files/{file_id}/content",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
)
class FilesWithRawResponse:
def __init__(self, files: Files) -> None:
self._files = files
self.create = _legacy_response.to_raw_response_wrapper(
files.create,
)
self.list = _legacy_response.to_raw_response_wrapper(
files.list,
)
self.content = _legacy_response.to_raw_response_wrapper(
files.content,
)
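
`Files.create()` above accepts either a raw file or an `upload_detail` list and switches to multipart/form-data when a file is present. An upload sketch; the path and purpose value are illustrative:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
with open("batch_input.jsonl", "rb") as f:  # illustrative path
    file_object = client.files.create(file=f, purpose="batch")
print(file_object.id)  # assumes FileObject carries the server-side id
```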

View File

@ -1,5 +0,0 @@
from .fine_tuning import FineTuning
from .jobs import Jobs
from .models import FineTunedModels
__all__ = ["Jobs", "FineTunedModels", "FineTuning"]

View File

@ -1,18 +0,0 @@
from typing import TYPE_CHECKING
from ...core import BaseAPI, cached_property
from .jobs import Jobs
from .models import FineTunedModels
if TYPE_CHECKING:
pass
class FineTuning(BaseAPI):
@cached_property
def jobs(self) -> Jobs:
return Jobs(self._client)
@cached_property
def models(self) -> FineTunedModels:
return FineTunedModels(self._client)

View File

@ -1,3 +0,0 @@
from .jobs import Jobs
__all__ = ["Jobs"]

View File

@ -1,152 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
make_request_options,
)
from ....types.fine_tuning import (
FineTuningJob,
FineTuningJobEvent,
ListOfFineTuningJob,
job_create_params,
)
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["Jobs"]
class Jobs(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
model: str,
training_file: str,
hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
suffix: Optional[str] | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._post(
"/fine_tuning/jobs",
body={
"model": model,
"training_file": training_file,
"hyperparameters": hyperparameters,
"suffix": suffix,
"validation_file": validation_file,
"request_id": request_id,
},
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
def retrieve(
self,
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFineTuningJob:
return self._get(
"/fine_tuning/jobs",
cast_type=ListOfFineTuningJob,
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
def cancel(
self,
fine_tuning_job_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
if not fine_tuning_job_id:
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
return self._post(
f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
def list_events(
self,
fine_tuning_job_id: str,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJobEvent:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
cast_type=FineTuningJobEvent,
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
def delete(
self,
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
if not fine_tuning_job_id:
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
return self._delete(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTuningJob,
)
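
A sketch of driving the job lifecycle defined above (create, retrieve, events); the base model name and training file ID are placeholders:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
job = client.fine_tuning.jobs.create(
    model="chatglm3-6b",  # illustrative base model
    training_file="file-abc123",  # placeholder file id
)
status = client.fine_tuning.jobs.retrieve(job.id)
events = client.fine_tuning.jobs.list_events(job.id, limit=10)
```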

View File

@ -1,3 +0,0 @@
from .fine_tuned_models import FineTunedModels
__all__ = ["FineTunedModels"]

View File

@ -1,41 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
make_request_options,
)
from ....types.fine_tuning.models import FineTunedModelsStatus
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["FineTunedModels"]
class FineTunedModels(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def delete(
self,
fine_tuned_model: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTunedModelsStatus:
if not fine_tuned_model:
raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
return self._delete(
f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=FineTunedModelsStatus,
)

View File

@ -1,59 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
from ..types.image import ImagesResponded
from ..types.sensitive_word_check import SensitiveWordCheckRequest
if TYPE_CHECKING:
from .._client import ZhipuAI
class Images(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def generations(
self,
*,
prompt: str,
model: str | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
quality: Optional[str] | NotGiven = NOT_GIVEN,
response_format: Optional[str] | NotGiven = NOT_GIVEN,
size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
user_id: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ImagesResponded:
_cast_type = ImagesResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/images/generations",
body={
"prompt": prompt,
"model": model,
"n": n,
"quality": quality,
"response_format": response_format,
"sensitive_word_check": sensitive_word_check,
"size": size,
"style": style,
"user": user,
"user_id": user_id,
"request_id": request_id,
},
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=_cast_type,
stream=False,
)
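
An image generation sketch against the endpoint above; the model, prompt, and `data[0].url` access are illustrative assumptions:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
result = client.images.generations(
    model="cogview-3",  # illustrative model name
    prompt="a lighthouse at dusk",
)
print(result.data[0].url)  # assumes an OpenAI-style images response
```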

View File

@ -1,3 +0,0 @@
from .knowledge import Knowledge
__all__ = ["Knowledge"]

View File

@ -1,3 +0,0 @@
from .document import Document
__all__ = ["Document"]

View File

@ -1,217 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, Optional, cast
import httpx
from ....core import (
NOT_GIVEN,
BaseAPI,
Body,
FileTypes,
Headers,
NotGiven,
deepcopy_minimal,
extract_files,
make_request_options,
maybe_transform,
)
from ....types.files import UploadDetail, file_create_params
from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params
from ....types.knowledge.document.document_list_resp import DocumentPage
if TYPE_CHECKING:
from ...._client import ZhipuAI
__all__ = ["Document"]
class Document(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def create(
self,
*,
file: Optional[FileTypes] = None,
custom_separator: Optional[list[str]] = None,
upload_detail: Optional[list[UploadDetail]] = None,
purpose: Literal["retrieval"],
knowledge_id: Optional[str] = None,
sentence_size: Optional[int] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentObject:
if not file and not upload_detail:
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
body = deepcopy_minimal(
{
"file": file,
"upload_detail": upload_detail,
"purpose": purpose,
"custom_separator": custom_separator,
"knowledge_id": knowledge_id,
"sentence_size": sentence_size,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
if files:
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body=maybe_transform(body, file_create_params.FileCreateParams),
files=files,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=DocumentObject,
)
def edit(
self,
document_id: str,
knowledge_type: str,
*,
custom_separator: Optional[list[str]] = None,
sentence_size: Optional[int] = None,
callback_url: Optional[str] = None,
callback_header: Optional[dict[str, str]] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Args:
document_id: 知识id
knowledge_type: 知识类型:
1:文章知识: 支持pdf,url,docx
2.问答知识-文档: 支持pdf,url,docx
3.问答知识-表格: 支持xlsx
4.商品库-表格: 支持xlsx
5.自定义: 支持pdf,url,docx
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
:param knowledge_type:
:param document_id:
:param timeout:
:param extra_body:
:param callback_header:
:param sentence_size:
:param extra_headers:
:param callback_url:
:param custom_separator:
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
body = deepcopy_minimal(
{
"id": document_id,
"knowledge_type": knowledge_type,
"custom_separator": custom_separator,
"sentence_size": sentence_size,
"callback_url": callback_url,
"callback_header": callback_header,
}
)
return self._put(
f"/document/{document_id}",
body=maybe_transform(body, document_edit_params.DocumentEditParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def list(
self,
knowledge_id: str,
*,
purpose: str | NotGiven = NOT_GIVEN,
page: str | NotGiven = NOT_GIVEN,
limit: str | NotGiven = NOT_GIVEN,
order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentPage:
return self._get(
"/files",
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"knowledge_id": knowledge_id,
"purpose": purpose,
"page": page,
"limit": limit,
"order": order,
},
document_list_params.DocumentListParams,
),
),
cast_type=DocumentPage,
)
def delete(
self,
document_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Delete a knowledge document.
Args:
document_id: knowledge document ID
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
return self._delete(
f"/document/{document_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def retrieve(
self,
document_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> DocumentData:
"""
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not document_id:
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
return self._get(
f"/document/{document_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=DocumentData,
)

View File

@ -1,173 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
cached_property,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params
from ...types.knowledge.knowledge_list_resp import KnowledgePage
from .document import Document
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Knowledge"]
class Knowledge(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
@cached_property
def document(self) -> Document:
return Document(self._client)
def create(
self,
embedding_id: int,
name: str,
*,
customer_identifier: Optional[str] = None,
description: Optional[str] = None,
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
bucket_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgeInfo:
body = deepcopy_minimal(
{
"embedding_id": embedding_id,
"name": name,
"customer_identifier": customer_identifier,
"description": description,
"background": background,
"icon": icon,
"bucket_id": bucket_id,
}
)
return self._post(
"/knowledge",
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=KnowledgeInfo,
)
def modify(
self,
knowledge_id: str,
embedding_id: int,
*,
name: str,
description: Optional[str] = None,
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
body = deepcopy_minimal(
{
"id": knowledge_id,
"embedding_id": embedding_id,
"name": name,
"description": description,
"background": background,
"icon": icon,
}
)
return self._put(
f"/knowledge/{knowledge_id}",
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def query(
self,
*,
page: int | NotGiven = 1,
size: int | NotGiven = 10,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgePage:
return self._get(
"/knowledge",
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"page": page,
"size": size,
},
knowledge_list_params.KnowledgeListParams,
),
),
cast_type=KnowledgePage,
)
def delete(
self,
knowledge_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> httpx.Response:
"""
Delete a knowledge base.
Args:
knowledge_id: knowledge base ID
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not knowledge_id:
raise ValueError("Expected a non-empty value for `knowledge_id`")
return self._delete(
f"/knowledge/{knowledge_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=httpx.Response,
)
def used(
self,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> KnowledgeUsed:
"""
Returns the used capacity of the knowledge base.
Args:
extra_headers: Send extra headers
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get(
"/knowledge/capacity",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=KnowledgeUsed,
)
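
A sketch combining the `Knowledge` and nested `Document` resources above: create a knowledge base, then attach a document to it. The embedding ID, names, and `kb.id` access are illustrative assumptions:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
kb = client.knowledge.create(embedding_id=3, name="product-docs")  # illustrative values
with open("manual.pdf", "rb") as f:  # illustrative path
    doc = client.knowledge.document.create(
        file=f,
        purpose="retrieval",
        knowledge_id=kb.id,  # assumes KnowledgeInfo exposes the new id
    )
```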

View File

@ -1,3 +0,0 @@
from .tools import Tools
__all__ = ["Tools"]

View File

@ -1,65 +0,0 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Literal, Optional, Union
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Tools"]
class Tools(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def web_search(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
messages: Union[str, list[str], list[int], object, None],
scope: Optional[str] | NotGiven = NOT_GIVEN,
location: Optional[str] | NotGiven = NOT_GIVEN,
recent_days: Optional[int] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> WebSearch | StreamResponse[WebSearchChunk]:
body = deepcopy_minimal(
{
"model": model,
"request_id": request_id,
"messages": messages,
"stream": stream,
"scope": scope,
"location": location,
"recent_days": recent_days,
}
)
return self._post(
"/tools",
body=maybe_transform(body, tools_web_search_params.WebSearchParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=WebSearch,
stream=stream or False,
stream_cls=StreamResponse[WebSearchChunk],
)
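
A non-streaming invocation sketch for `web_search()` above; the model name is an illustrative assumption:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
search = client.tools.web_search(
    model="web-search-pro",  # illustrative model name
    messages=[{"role": "user", "content": "latest Python release"}],
    stream=False,
)
```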

View File

@ -1,7 +0,0 @@
from .videos import (
Videos,
)
__all__ = [
"Videos",
]

View File

@ -1,77 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import httpx
from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.sensitive_word_check import SensitiveWordCheckRequest
from ...types.video import VideoObject, video_create_params
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Videos"]
class Videos(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)
def generations(
self,
model: str,
*,
prompt: Optional[str] = None,
image_url: Optional[str] = None,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> VideoObject:
if not model and not prompt:
raise ValueError("At least one of `model` and `prompt` must be provided.")
body = deepcopy_minimal(
{
"model": model,
"prompt": prompt,
"image_url": image_url,
"sensitive_word_check": sensitive_word_check,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/videos/generations",
body=maybe_transform(body, video_create_params.VideoCreateParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=VideoObject,
)
def retrieve_videos_result(
self,
id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> VideoObject:
if not id:
raise ValueError("At least one of `id` must be provided.")
return self._get(
f"/async-result/{id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=VideoObject,
)
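
Video generation above follows the same submit-then-poll pattern as async chat: `generations()` returns a task-like `VideoObject`, and `retrieve_videos_result()` reads /async-result/{id}. A sketch; the model name and `id` access are illustrative assumptions:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key
task = client.videos.generations(
    model="cogvideox",  # illustrative model name
    prompt="waves at sunrise",
)
video = client.videos.retrieve_videos_result(id=task.id)  # assumes VideoObject carries an id
```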

View File

@ -1,108 +0,0 @@
from ._base_api import BaseAPI
from ._base_compat import (
PYDANTIC_V2,
ConfigDict,
GenericModel,
cached_property,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._base_models import BaseModel, construct_type
from ._base_type import (
NOT_GIVEN,
Body,
FileTypes,
Headers,
IncEx,
ModelT,
NotGiven,
Query,
)
from ._constants import (
ZHIPUAI_DEFAULT_LIMITS,
ZHIPUAI_DEFAULT_MAX_RETRIES,
ZHIPUAI_DEFAULT_TIMEOUT,
)
from ._errors import (
APIAuthenticationError,
APIConnectionError,
APIInternalError,
APIReachLimitError,
APIRequestFailedError,
APIResponseError,
APIResponseValidationError,
APIServerFlowExceedError,
APIStatusError,
APITimeoutError,
ZhipuAIError,
)
from ._files import is_file_content
from ._http_client import HttpClient, make_request_options
from ._sse_client import StreamResponse
from ._utils import (
deepcopy_minimal,
drop_prefix_image_data,
extract_files,
is_given,
is_list,
is_mapping,
maybe_transform,
parse_date,
parse_datetime,
)
__all__ = [
"BaseModel",
"construct_type",
"BaseAPI",
"NOT_GIVEN",
"Headers",
"NotGiven",
"Body",
"IncEx",
"ModelT",
"Query",
"FileTypes",
"PYDANTIC_V2",
"ConfigDict",
"GenericModel",
"get_args",
"is_union",
"parse_obj",
"get_origin",
"is_literal_type",
"get_model_config",
"get_model_fields",
"field_get_default",
"is_file_content",
"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
"APIAuthenticationError",
"APIReachLimitError",
"APIInternalError",
"APIServerFlowExceedError",
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
"make_request_options",
"HttpClient",
"ZHIPUAI_DEFAULT_TIMEOUT",
"ZHIPUAI_DEFAULT_MAX_RETRIES",
"ZHIPUAI_DEFAULT_LIMITS",
"is_list",
"is_mapping",
"parse_date",
"parse_datetime",
"is_given",
"maybe_transform",
"deepcopy_minimal",
"extract_files",
"StreamResponse",
]

View File

@ -1,19 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .._client import ZhipuAI
class BaseAPI:
_client: ZhipuAI
def __init__(self, client: ZhipuAI) -> None:
self._client = client
self._delete = client.delete
self._get = client.get
self._post = client.post
self._put = client.put
self._patch = client.patch
self._get_api_list = client.get_api_list

View File

@ -1,209 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload
import pydantic
from pydantic.fields import FieldInfo
from typing_extensions import Self
from ._base_type import StrBytesIntFloat
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
# --------------- Pydantic v2 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
# v1 re-exports
if TYPE_CHECKING:
def parse_date(value: date | StrBytesIntFloat) -> date: ...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...
def get_args(t: type[Any]) -> tuple[Any, ...]: ...
def is_union(tp: type[Any] | None) -> bool: ...
def get_origin(t: type[Any]) -> type[Any] | None: ...
def is_literal_type(type_: type[Any]) -> bool: ...
def is_typeddict(type_: type[Any]) -> bool: ...
else:
if PYDANTIC_V2:
from pydantic.v1.typing import ( # noqa: I001
get_args as get_args, # noqa: PLC0414
is_union as is_union, # noqa: PLC0414
get_origin as get_origin, # noqa: PLC0414
is_typeddict as is_typeddict, # noqa: PLC0414
is_literal_type as is_literal_type, # noqa: PLC0414
)
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
else:
from pydantic.typing import ( # noqa: I001
get_args as get_args, # noqa: PLC0414
is_union as is_union, # noqa: PLC0414
get_origin as get_origin, # noqa: PLC0414
is_typeddict as is_typeddict, # noqa: PLC0414
is_literal_type as is_literal_type, # noqa: PLC0414
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict
else:
if PYDANTIC_V2:
from pydantic import ConfigDict
else:
# TODO: provide an error message here?
ConfigDict = None
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(value)
else:
# pyright: ignore[reportDeprecated, reportUnnecessaryCast]
return cast(_ModelT, model.parse_obj(value))
def field_is_required(field: FieldInfo) -> bool:
if PYDANTIC_V2:
return field.is_required()
return field.required # type: ignore
def field_get_default(field: FieldInfo) -> Any:
value = field.get_default()
if PYDANTIC_V2:
from pydantic_core import PydanticUndefined
if value == PydanticUndefined:
return None
return value
return value
def field_outer_type(field: FieldInfo) -> Any:
if PYDANTIC_V2:
return field.annotation
return field.outer_type_ # type: ignore
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
if PYDANTIC_V2:
return model.model_config
return model.__config__ # type: ignore
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V2:
return model.model_fields
return model.__fields__ # type: ignore
def model_copy(model: _ModelT) -> _ModelT:
if PYDANTIC_V2:
return model.model_copy()
return model.copy() # type: ignore
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
if PYDANTIC_V2:
return model.model_dump_json(indent=indent)
return model.json(indent=indent) # type: ignore
def model_dump(
model: pydantic.BaseModel,
*,
exclude_unset: bool = False,
exclude_defaults: bool = False,
) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_dump(
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
)
return cast(
"dict[str, Any]",
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
),
)
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(data)
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
# generic models
if TYPE_CHECKING:
class GenericModel(pydantic.BaseModel): ...
else:
if PYDANTIC_V2:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel): ...
else:
import pydantic.generics
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
# cached properties
if TYPE_CHECKING:
cached_property = property
# we define a separate type (copied from typeshed)
# that represents that `cached_property` is `set`able
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.
class typed_cached_property(Generic[_T]): # noqa: N801
func: Callable[[Any], _T]
attrname: str | None
def __init__(self, func: Callable[[Any], _T]) -> None: ...
@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None: ...
# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None: ...
else:
try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
typed_cached_property = cached_property
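# Illustrative only (not part of the original file): a minimal sketch of how
# the shims above keep calling code pydantic-version-agnostic. `User` is a
# hypothetical model introduced purely for this example.
if __name__ == "__main__":
    import pydantic

    class User(pydantic.BaseModel):
        name: str
        age: int

    # parse_obj / model_dump / model_json dispatch to the v1 or v2 API
    user = parse_obj(User, {"name": "Ada", "age": 36})
    assert model_dump(user) == {"name": "Ada", "age": 36}
    print(model_json(user, indent=2))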

View File

@ -1,670 +0,0 @@
from __future__ import annotations
import inspect
import os
from collections.abc import Callable
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast
import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from typing_extensions import (
ParamSpec,
Protocol,
override,
runtime_checkable,
)
from ._base_compat import (
PYDANTIC_V2,
ConfigDict,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._base_compat import (
GenericModel as BaseGenericModel,
)
from ._base_type import (
IncEx,
ModelT,
)
from ._utils import (
PropertyInfo,
coerce_boolean,
extract_type_arg,
is_annotated_type,
is_list,
is_mapping,
parse_date,
parse_datetime,
strip_annotated_type,
)
if TYPE_CHECKING:
    from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
_T = TypeVar("_T")
P = ParamSpec("P")
@runtime_checkable
class _ConfigProtocol(Protocol):
allow_population_by_field_name: bool
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
else:
@property
@override
def model_fields_set(self) -> set[str]:
# a forwards-compat shim for pydantic v2
return self.__fields_set__ # type: ignore
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
                If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-03-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
""" # noqa: E501
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
""" # noqa: E501
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
# Override the 'construct' method in a way that supports recursive parsing without validation.
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
@classmethod
@override
def construct(
cls: type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
) -> ModelT:
m = cls.__new__(cls)
fields_values: dict[str, object] = {}
config = get_model_config(cls)
populate_by_name = (
config.allow_population_by_field_name
if isinstance(config, _ConfigProtocol)
else config.get("populate_by_name")
)
if _fields_set is None:
_fields_set = set()
model_fields = get_model_fields(cls)
for name, field in model_fields.items():
key = field.alias
if key is None or (key not in values and populate_by_name):
key = name
if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)
_extra = {}
for key, value in values.items():
if key not in model_fields:
if PYDANTIC_V2:
_extra[key] = value
else:
_fields_set.add(key)
fields_values[key] = value
object.__setattr__(m, "__dict__", fields_values) # noqa: PLC2801
if PYDANTIC_V2:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None) # noqa: PLC2801
object.__setattr__(m, "__pydantic_extra__", _extra) # noqa: PLC2801
object.__setattr__(m, "__pydantic_fields_set__", _fields_set) # noqa: PLC2801
else:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set) # noqa: PLC2801
return m
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
# because the type signatures are technically different
# although not in practice
model_construct = construct
if not PYDANTIC_V2:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
# pydantic version they are currently using
@override
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
Args:
mode: The mode in which `to_python` should run.
If mode is 'json', the dictionary will only contain JSON serializable types.
If mode is 'python', the dictionary may contain any Python objects.
include: A list of fields to include in the output.
exclude: A list of fields to exclude from the output.
by_alias: Whether to use the field's alias in the dictionary key if defined.
exclude_unset: Whether to exclude fields that are unset or None from the output.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
round_trip: Whether to enable serialization and deserialization round-trip support.
warnings: Whether to log warnings when invalid fields are encountered.
Returns:
A dictionary representation of the model.
"""
        if mode != "python":
            raise ValueError("mode is only supported in Pydantic v2")
        if round_trip is not False:
            raise ValueError("round_trip is only supported in Pydantic v2")
        if warnings is not True:
            raise ValueError("warnings is only supported in Pydantic v2")
        if context is not None:
            raise ValueError("context is only supported in Pydantic v2")
        if serialize_as_any is not False:
            raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
@override
def model_dump_json(
self,
*,
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
Generates a JSON representation of the model using Pydantic's `to_json` method.
Args:
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
by_alias: Whether to serialize using field aliases.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to use serialization/deserialization between JSON and class instance.
warnings: Whether to show any warnings that occurred during serialization.
Returns:
A JSON string representation of the model.
"""
        if round_trip is not False:
            raise ValueError("round_trip is only supported in Pydantic v2")
        if warnings is not True:
            raise ValueError("warnings is only supported in Pydantic v2")
        if context is not None:
            raise ValueError("context is only supported in Pydantic v2")
        if serialize_as_any is not False:
            raise ValueError("serialize_as_any is only supported in Pydantic v2")
        return super().json(  # pyright: ignore[reportDeprecated]
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
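# Illustrative only: a sketch of the validation-free `construct()` override
# above. `Item` is a hypothetical model; note that mismatched types are kept
# as-is, since construction performs no validation by design.
def _demo_construct() -> None:
    class Item(BaseModel):
        name: str
        count: int

    item = Item.construct(name="widget", count="not-an-int")
    assert item.count == "not-an-int"  # no coercion, no validation error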
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
if PYDANTIC_V2:
type_ = field.annotation
else:
type_ = cast(type, field.outer_type_) # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
return construct_type(value=value, type_=type_)
def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
if is_union(type_):
return any(is_basemodel(variant) for variant in get_args(type_))
return is_basemodel_type(type_)
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
def build(
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
) -> _BaseModelT:
"""Construct a BaseModel class without validation.
    This is useful for cases where you need to instantiate a `BaseModel`
    from an API response, as it provides type-safe params, which helpers
    like `construct_type()` do not.
```py
build(MyModel, my_field_a="foo", my_field_b=123)
```
"""
if args:
raise TypeError(
"Received positional arguments which are not supported; Keyword arguments must be used instead",
)
return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
"""Loose coercion to the expected type with construction of nested values.
Note: the returned value from this function is not guaranteed to match the
given type.
"""
return cast(_T, construct_type(value=value, type_=type_))
def construct_type(*, value: object, type_: type) -> object:
"""Loose coercion to the expected type with construction of nested values.
If the given value does not match the expected type then it is returned as-is.
"""
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
meta: tuple[Any, ...] = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = ()
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
origin = get_origin(type_) or type_
args = get_args(type_)
if is_union(origin):
try:
return validate_type(type_=cast("type[object]", type_), value=value)
except Exception:
pass
# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)
# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
return construct_type(value=value, type_=variant)
except Exception:
continue
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
if origin == dict:
if not is_mapping(value):
return value
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
if is_mapping(value):
if issubclass(type_, BaseModel):
return type_.construct(**value) # type: ignore[arg-type]
return cast(Any, type_).construct(**value)
if origin == list:
if not is_list(value):
return value
inner_type = args[0] # List[inner_type]
return [construct_type(value=entry, type_=inner_type) for entry in value]
if origin == float:
if isinstance(value, int):
coerced = float(value)
if coerced != value:
return value
return coerced
return value
if type_ == datetime:
try:
return parse_datetime(value) # type: ignore
except Exception:
return value
if type_ == date:
try:
return parse_date(value) # type: ignore
except Exception:
return value
return value
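# Illustrative only: a sketch of the discriminated-union behavior described in
# the comments inside `construct_type` above. `FooType` / `BarType` are
# hypothetical variants; the discriminator is attached with `PropertyInfo`.
def _demo_discriminated_union() -> None:
    from typing import Annotated, Literal, Union

    class FooType(BaseModel):
        kind: Literal["foo"]
        value: str

    class BarType(BaseModel):
        kind: Literal["bar"]
        value: int

    FooOrBar = Annotated[Union[FooType, BarType], PropertyInfo(discriminator="kind")]

    # 'kind' selects the variant even though 'value' fails strict validation
    obj = construct_type(value={"kind": "bar", "value": "oops"}, type_=FooOrBar)
    assert isinstance(obj, BarType)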
@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
```py
class Foo(BaseModel):
type: Literal['foo']
```
Will result in field_name='type'
"""
field_alias_from: str | None
"""The name of the discriminator field in the API response, e.g.
```py
class Foo(BaseModel):
type: Literal['foo'] = Field(alias='type_from_api')
```
Will result in field_alias_from='type_from_api'
"""
mapping: dict[str, type]
"""Mapping of discriminator value to variant type, e.g.
{'foo': FooVariant, 'bar': BarVariant}
"""
def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__
discriminator_field_name: str | None = None
for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
discriminator_field_name = annotation.discriminator
break
if not discriminator_field_name:
return None
mapping: dict[str, type] = {}
discriminator_alias: str | None = None
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V2:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field.get("serialization_alias")
field_schema = field["schema"]
if field_schema["type"] == "literal":
for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if field_info.annotation and is_literal_type(field_info.annotation):
for entry in get_args(field_info.annotation):
if isinstance(entry, str):
mapping[entry] = variant
if not mapping:
return None
details = DiscriminatorDetails(
mapping=mapping,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
cast(CachedDiscriminatorType, union).__discriminator__ = details
return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None
fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None
fields_schema = cast("ModelFieldsSchema", fields_schema)
field = fields_schema["fields"].get(field_name)
if not field:
return None
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
return cast(_T, parse_obj(type_, value))
return cast(_T, _validate_non_model_type(type_=type_, value=value))
# Subclassing here confuses type checkers, so we treat this class as non-inheriting.
if TYPE_CHECKING:
GenericModel = BaseModel
else:
class GenericModel(BaseGenericModel, BaseModel):
pass
if PYDANTIC_V2:
from pydantic import TypeAdapter
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)
elif not TYPE_CHECKING:
class TypeAdapter(Generic[_T]):
"""Used as a placeholder to easily convert runtime types to a Pydantic format
to provide validation.
For example:
```py
        validated = TypeAdapter(int).validate_python(5)
        # validated: 5
```
"""
def __init__(self, type_: type[_T]):
self.type_ = type_
def validate_python(self, value: Any) -> _T:
if not isinstance(value, self.type_):
raise ValueError(f"Invalid type: {value} is not of type {self.type_}")
return value
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)
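# Illustrative only: `validate_type` routes pydantic models through
# `parse_obj()` and everything else through `TypeAdapter` (the real one on
# v2, the placeholder above on v1). `Point` is a hypothetical model.
def _demo_validate_type() -> None:
    class Point(BaseModel):
        x: int
        y: int

    point = validate_type(type_=Point, value={"x": 1, "y": 2})  # model path
    assert (point.x, point.y) == (1, 2)
    assert validate_type(type_=int, value=5) == 5  # non-model path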

View File

@ -1,170 +0,0 @@
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Any,
Literal,
Optional,
TypeAlias,
TypeVar,
Union,
)
import httpx
import pydantic
from httpx import Response
from typing_extensions import Protocol, TypedDict, override, runtime_checkable
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
PrimitiveData = Union[str, int, float, bool, None]
Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")
if TYPE_CHECKING:
NoneType: type[None]
else:
NoneType = type(None)
# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()
class Omit:
"""In certain situations you need to be able to represent a case where a default value has
to be explicitly removed and `None` is not an appropriate substitute, for example:
```py
    # by default, the `Content-Type` header `application/json` will be sent
client.post('/upload/files', files={'file': b'my raw file content'})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={'Content-Type': 'multipart/form-data'})
# instead you can remove the default `application/json` header by passing Omit
client.post(..., headers={'Content-Type': Omit()})
```
"""
def __bool__(self) -> Literal[False]:
return False
@runtime_checkable
class ModelBuilderProtocol(Protocol):
@classmethod
def build(
cls: type[_T],
*,
response: Response,
data: object,
) -> _T: ...
Headers = Mapping[str, Union[str, Omit]]
class HeadersLikeProtocol(Protocol):
def get(self, __key: str) -> str | None: ...
HeadersLike = Union[Headers, HeadersLikeProtocol]
ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", # noqa: E501
)
StrBytesIntFloat = Union[str, bytes, int, float]
# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
PostParser = Callable[[Any], Any]
@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""
__orig_bases__: tuple[_GenericAlias]
class _GenericAlias(Protocol):
__origin__: type[object]
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
# for user input files
if TYPE_CHECKING:
Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]]
# duplicate of the above but without our custom file support
HttpxFileContent = Union[bytes, IO[bytes]]
HttpxFileTypes = Union[
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]]
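# Illustrative only: the shapes accepted by `RequestFiles` above, with
# hypothetical filenames and content.
_files_as_mapping: RequestFiles = {"file": b"raw bytes"}
_files_as_sequence: RequestFiles = [
    ("file", ("report.pdf", b"raw bytes", "application/pdf")),
]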

View File

@ -1,12 +0,0 @@
import httpx
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
# `Timeout` controls the `connect` and `read` timeouts for API calls; defaults to `timeout=300.0, connect=8.0`
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
# The `retry` setting controls the number of retries; defaults to 3
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
# `Limits` controls the maximum number of connections and keep-alive connections; defaults to `max_connections=50, max_keepalive_connections=10`
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
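# Illustrative only: how these defaults might be wired into an httpx client.
# This wiring is an assumption for the sketch; the original file only
# defines the constants.
def _demo_client() -> httpx.Client:
    return httpx.Client(
        timeout=ZHIPUAI_DEFAULT_TIMEOUT,
        limits=ZHIPUAI_DEFAULT_LIMITS,
    )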

View File

@ -1,86 +0,0 @@
from __future__ import annotations
import httpx
__all__ = [
"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
"APIAuthenticationError",
"APIReachLimitError",
"APIInternalError",
"APIServerFlowExceedError",
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
"APIConnectionError",
]
class ZhipuAIError(Exception):
def __init__(
self,
message: str,
) -> None:
super().__init__(message)
class APIStatusError(ZhipuAIError):
response: httpx.Response
status_code: int
def __init__(self, message: str, *, response: httpx.Response) -> None:
super().__init__(message)
self.response = response
self.status_code = response.status_code
class APIRequestFailedError(APIStatusError): ...
class APIAuthenticationError(APIStatusError): ...
class APIReachLimitError(APIStatusError): ...
class APIInternalError(APIStatusError): ...
class APIServerFlowExceedError(APIStatusError): ...
class APIResponseError(ZhipuAIError):
message: str
request: httpx.Request
json_data: object
def __init__(self, message: str, request: httpx.Request, json_data: object):
self.message = message
self.request = request
self.json_data = json_data
super().__init__(message)
class APIResponseValidationError(APIResponseError):
status_code: int
response: httpx.Response
def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None:
super().__init__(
message=message or "Data returned by API invalid for expected schema.",
request=response.request,
json_data=json_data,
)
self.response = response
self.status_code = response.status_code
class APIConnectionError(APIResponseError):
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
super().__init__(message, request, json_data=None)
class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)
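# Illustrative only: a sketch of mapping HTTP status codes onto the hierarchy
# above. The exact mapping is an assumption for this example; the SDK's real
# dispatch lives in its client code.
def _status_error_from_response(response: httpx.Response) -> APIStatusError:
    if response.status_code == 401:
        return APIAuthenticationError("Unauthorized", response=response)
    if response.status_code == 429:
        return APIReachLimitError("Rate limit exceeded", response=response)
    if response.status_code >= 500:
        return APIInternalError("Internal server error", response=response)
    return APIRequestFailedError("Request failed", response=response)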

Some files were not shown because too many files have changed in this diff.