mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat: support o1 series models for openrouter (#8358)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
This commit is contained in:
parent
6c2fa8defc
commit
6d56d5c1f6
|
@ -1,3 +1,5 @@
|
|||
- openai/o1-preview
|
||||
- openai/o1-mini
|
||||
- openai/gpt-4o
|
||||
- openai/gpt-4o-mini
|
||||
- openai/gpt-4
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
@ -26,7 +26,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
|||
) -> Union[LLMResult, Generator]:
|
||||
self._update_credential(model, credentials)
|
||||
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._update_credential(model, credentials)
|
||||
|
@ -46,7 +46,48 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
|||
) -> Union[LLMResult, Generator]:
|
||||
self._update_credential(model, credentials)
|
||||
|
||||
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
block_as_stream = False
|
||||
if model.startswith("openai/o1"):
|
||||
block_as_stream = True
|
||||
stop = None
|
||||
|
||||
# invoke block as stream
|
||||
if stream and block_as_stream:
|
||||
return self._generate_block_as_stream(
|
||||
model, credentials, prompt_messages, model_parameters, tools, stop, user
|
||||
)
|
||||
else:
|
||||
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _generate_block_as_stream(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Generator:
|
||||
resp: LLMResult = super()._generate(
|
||||
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=resp.message,
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=resp.usage.prompt_tokens,
|
||||
completion_tokens=resp.usage.completion_tokens,
|
||||
),
|
||||
finish_reason="stop",
|
||||
),
|
||||
)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
self._update_credential(model, credentials)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
model: openai/o1-mini
|
||||
label:
|
||||
en_US: o1-mini
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 65536
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "3.00"
|
||||
output: "12.00"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
|
@ -0,0 +1,40 @@
|
|||
model: openai/o1-preview
|
||||
label:
|
||||
en_US: o1-preview
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 32768
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "15.00"
|
||||
output: "60.00"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
Loading…
Reference in New Issue
Block a user