mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat: add xAI model provider (#10272)
This commit is contained in:
parent
4847548779
commit
bf9349c4dc
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>
|
After Width: | Height: | Size: 356 B |
63
api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
Normal file
63
api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
Normal file
|
@ -0,0 +1,63 @@
|
|||
model: grok-beta
|
||||
label:
|
||||
en_US: Grok beta
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
|
||||
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
37
api/core/model_runtime/model_providers/x/llm/llm.py
Normal file
37
api/core/model_runtime/model_providers/x/llm/llm.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
|
||||
credentials["mode"] = LLMMode.CHAT.value
|
||||
credentials["function_calling_type"] = "tool_call"
|
25
api/core/model_runtime/model_providers/x/x.py
Normal file
25
api/core/model_runtime/model_providers/x/x.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class XAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
38
api/core/model_runtime/model_providers/x/x.yaml
Normal file
38
api/core/model_runtime/model_providers/x/x.yaml
Normal file
|
@ -0,0 +1,38 @@
|
|||
provider: x
|
||||
label:
|
||||
en_US: xAI
|
||||
description:
|
||||
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
|
||||
icon_small:
|
||||
en_US: x-ai-logo.svg
|
||||
icon_large:
|
||||
en_US: x-ai-logo.svg
|
||||
help:
|
||||
title:
|
||||
en_US: Get your token from xAI
|
||||
zh_Hans: 从 xAI 获取 token
|
||||
url:
|
||||
en_US: https://x.ai/api
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
default: https://api.x.ai/v1
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
|
@ -95,3 +95,7 @@ GPUSTACK_API_KEY=
|
|||
|
||||
# Gitee AI Credentials
|
||||
GITEE_AI_API_KEY=
|
||||
|
||||
# xAI Credentials
|
||||
XAI_API_KEY=
|
||||
XAI_API_BASE=
|
||||
|
|
204
api/tests/integration_tests/model_runtime/x/test_llm.py
Normal file
204
api/tests/integration_tests/model_runtime/x/test_llm.py
Normal file
|
@ -0,0 +1,204 @@
|
|||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = XAILargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
# model name to gpt-3.5-turbo because of mocking
|
||||
model.validate_credentials(
|
||||
model="gpt-3.5-turbo",
|
||||
credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model_with_tools(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(
|
||||
content="what's the weather today in London?",
|
||||
),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
PromptMessageTool(
|
||||
name="get_stock_price",
|
||||
description="Get the current stock price",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
|
||||
"required": ["symbol"],
|
||||
},
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert isinstance(result.message, AssistantPromptMessage)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 10
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 77
|
Loading…
Reference in New Issue
Block a user