feat: support o1 series models for openrouter (#8358)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions

This commit is contained in:
sino 2024-09-22 10:23:50 +08:00 committed by GitHub
parent 6c2fa8defc
commit 6d56d5c1f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 126 additions and 3 deletions

View File

@ -1,3 +1,5 @@
- openai/o1-preview
- openai/o1-mini
- openai/gpt-4o - openai/gpt-4o
- openai/gpt-4o-mini - openai/gpt-4o-mini
- openai/gpt-4 - openai/gpt-4

View File

@ -1,7 +1,7 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
@ -26,7 +26,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
self._update_credential(model, credentials) self._update_credential(model, credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
self._update_credential(model, credentials) self._update_credential(model, credentials)
@ -46,8 +46,49 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
self._update_credential(model, credentials) self._update_credential(model, credentials)
block_as_stream = False
if model.startswith("openai/o1"):
block_as_stream = True
stop = None
# invoke block as stream
if stream and block_as_stream:
return self._generate_block_as_stream(
model, credentials, prompt_messages, model_parameters, tools, stop, user
)
else:
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _generate_block_as_stream(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
user: Optional[str] = None,
) -> Generator:
resp: LLMResult = super()._generate(
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=resp.message,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=resp.usage.prompt_tokens,
completion_tokens=resp.usage.completion_tokens,
),
finish_reason="stop",
),
)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
self._update_credential(model, credentials) self._update_credential(model, credentials)

View File

@ -0,0 +1,40 @@
model: openai/o1-mini
label:
en_US: o1-mini
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 65536
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: "3.00"
output: "12.00"
unit: "0.000001"
currency: USD

View File

@ -0,0 +1,40 @@
model: openai/o1-preview
label:
en_US: o1-preview
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 32768
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: "15.00"
output: "60.00"
unit: "0.000001"
currency: USD