Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 03:32:23 +08:00)
Add OCI (Oracle Cloud Infrastructure) Generative AI Service as a Model Provider (#7775)
Co-authored-by: Walter Jin <jinshuhaicc@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: walter from vm <walter.jin@oracle.com>
Parent: e0d3cd91c6
Commit: 89aede80cc
@@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
(new SVG icon, 874 B)

@@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
(new SVG icon, 874 B)

@@ -0,0 +1,52 @@
model: cohere.command-r-16k
label:
  en_US: cohere.command-r-16k v1.2
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 1.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: 0
    max: 1
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: 0
    max: 1
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 4000
pricing:
  input: '0.004'
  output: '0.004'
  unit: '0.0001'
  currency: USD

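For orientation, the pricing block above is consumed by Dify's shared pricing logic rather than by this provider. To the best of my reading, the quoted input/output price is combined with unit (here '0.0001', i.e. the price is quoted per 10,000 tokens). A minimal sketch of that arithmetic, under that assumption (the helper below is illustrative only and is not part of this change):

from decimal import Decimal

def estimate_cost(tokens: int, price: str, unit: str) -> Decimal:
    # price is quoted per (1 / unit) tokens; unit '0.0001' means per 10,000 tokens
    return Decimal(tokens) * Decimal(price) * Decimal(unit)

# 1,000 input tokens of cohere.command-r-16k at input '0.004', unit '0.0001'
print(estimate_cost(1_000, "0.004", "0.0001"))  # roughly 0.0004 USD
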
@@ -0,0 +1,52 @@
model: cohere.command-r-plus
label:
  en_US: cohere.command-r-plus v1.2
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 1.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: 0
    max: 1
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: 0
    max: 1
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 4000
pricing:
  input: '0.0219'
  output: '0.0219'
  unit: '0.0001'
  currency: USD

api/core/model_runtime/model_providers/oci/llm/llm.py (new file, 461 lines)
@@ -0,0 +1,461 @@
import base64
import copy
import json
import logging
from collections.abc import Generator
from typing import Optional, Union

import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

request_template = {
    "compartmentId": "",
    "servingMode": {
        "modelId": "cohere.command-r-plus",
        "servingType": "ON_DEMAND"
    },
    "chatRequest": {
        "apiFormat": "COHERE",
        # "preambleOverride": "You are a helpful assistant.",
        # "message": "Hello!",
        # "chatHistory": [],
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75
    }
}
oci_config_template = {
    "user": "",
    "fingerprint": "",
    "tenancy": "",
    "region": "",
    "compartment_id": "",
    "key_content": ""
}


class OCILargeLanguageModel(LargeLanguageModel):
    # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
    _supported_models = {
        "meta.llama-3-70b-instruct": {
            "system": True,
            "multimodal": False,
            "tool_call": False,
            "stream_tool_call": False,
        },
        "cohere.command-r-16k": {
            "system": True,
            "multimodal": False,
            "tool_call": True,
            "stream_tool_call": False,
        },
        "cohere.command-r-plus": {
            "system": True,
            "multimodal": False,
            "tool_call": True,
            "stream_tool_call": False,
        },
    }

    def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["stream_tool_call"] if stream else feature["tool_call"]

    def _is_multimodal_supported(self, model_id: str) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["multimodal"]

    def _is_system_prompt_supported(self, model_id: str) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["system"]

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens (approximated with the GPT-2 tokenizer)
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return self._get_num_tokens_by_gpt2(prompt)

    def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                           tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of characters for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of characters
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return len(prompt)

    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
        """
        :param messages: List of PromptMessage to combine.
        :return: Combined string with necessary human_prompt and ai_prompt tags.
        """
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message)
            for message in messages
        )

        return text.rstrip()

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # send a minimal "ping" request to verify that the credentials work
        try:
            ping_message = SystemPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"maxTokens": 5})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None
                  ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: credentials kwargs
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # initialize client
        # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
        oci_config = copy.deepcopy(oci_config_template)
        if "oci_config_content" in credentials:
            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
            config_items = oci_config_content.split("/")
            if len(config_items) != 5:
                raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
            oci_config["user"] = config_items[0]
            oci_config["fingerprint"] = config_items[1]
            oci_config["tenancy"] = config_items[2]
            oci_config["region"] = config_items[3]
            oci_config["compartment_id"] = config_items[4]
        else:
            raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
        if "oci_key_content" in credentials:
            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
        else:
            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")

        # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
        compartment_id = oci_config["compartment_id"]
        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)

        # build the chat request
        request_args = copy.deepcopy(request_template)
        request_args["compartmentId"] = compartment_id
        request_args["servingMode"]["modelId"] = model

        chathistory = []
        system_prompts = []
        request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
        request_args["chatRequest"].update(model_parameters)
        frequency_penalty = model_parameters.get("frequencyPenalty", 0)
        presence_penalty = model_parameters.get("presencePenalty", 0)
        if frequency_penalty > 0 and presence_penalty > 0:
            raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty")

        # tool calling is not implemented yet
        valid_value = self._is_tool_call_supported(model, stream)
        if tools is not None and len(tools) > 0:
            if not valid_value:
                raise InvokeBadRequestError("Does not support function calling")

        if model.startswith("cohere"):
            for message in prompt_messages[:-1]:
                text = ""
                if isinstance(message.content, str):
                    text = message.content
                if isinstance(message, UserPromptMessage):
                    chathistory.append({"role": "USER", "message": text})
                else:
                    chathistory.append({"role": "CHATBOT", "message": text})
                if isinstance(message, SystemPromptMessage):
                    if isinstance(message.content, str):
                        system_prompts.append(message.content)
            args = {"apiFormat": "COHERE",
                    "preambleOverride": ' '.join(system_prompts),
                    "message": prompt_messages[-1].content,
                    "chatHistory": chathistory}
            request_args["chatRequest"].update(args)
        elif model.startswith("meta"):
            meta_messages = []
            for message in prompt_messages:
                text = message.content
                meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
            args = {"apiFormat": "GENERIC",
                    "messages": meta_messages,
                    "numGenerations": 1,
                    "topK": -1}
            request_args["chatRequest"].update(args)

        if stream:
            request_args["chatRequest"]["isStream"] = True
        response = client.chat(request_args)

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=response.data.chat_response.text
        )

        # calculate num tokens
        prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
        completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        index = -1
        events = response.data.events()
        for stream in events:
            chunk = json.loads(stream.data)
            # e.g. chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}

            if "finishReason" not in chunk:
                assistant_prompt_message = AssistantPromptMessage(
                    content=''
                )
                if model.startswith("cohere"):
                    if chunk["text"]:
                        assistant_prompt_message.content += chunk["text"]
                elif model.startswith("meta"):
                    assistant_prompt_message.content += chunk["message"]["content"][0]["text"]
                index += 1
                # transform assistant message to prompt message
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message
                    )
                )
            else:
                # calculate num tokens
                prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
                completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message,
                        finish_reason=str(chunk["finishReason"]),
                        usage=usage
                    )
                )

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        """
        human_prompt = "\n\nuser:"
        ai_prompt = "\n\nmodel:"

        content = message.content
        if isinstance(content, list):
            content = "".join(
                c.data for c in content if c.type != PromptMessageContentType.IMAGE
            )

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, ToolPromptMessage):
            message_text = f"{human_prompt} {content}"
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [],
            InvokeServerUnavailableError: [],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [],
            InvokeBadRequestError: []
        }

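To make the request construction in _generate above easier to follow: Cohere models are sent with apiFormat COHERE (a message, a preambleOverride joined from system prompts, and earlier turns in chatHistory), while meta.llama models use apiFormat GENERIC with a messages array. As a rough sketch, a single user message "Hello!" against cohere.command-r-plus with default parameters would produce a chat payload along these lines (the compartment OCID is a placeholder):

example_request = {
    "compartmentId": "ocid1.compartment.oc1..example",  # placeholder
    "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
    "chatRequest": {
        "apiFormat": "COHERE",
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75,
        "preambleOverride": "",   # empty because no system prompt is present in this example
        "message": "Hello!",
        "chatHistory": [],
    },
}
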
@@ -0,0 +1,51 @@
model: meta.llama-3-70b-instruct
label:
  zh_Hans: meta.llama-3-70b-instruct
  en_US: meta.llama-3-70b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 2.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 8000
pricing:
  input: '0.015'
  output: '0.015'
  unit: '0.0001'
  currency: USD

api/core/model_runtime/model_providers/oci/oci.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class OCIGENAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        if validation fails, raise an exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use the `cohere.command-r-plus` model for validation.
            model_instance.validate_credentials(
                model='cohere.command-r-plus',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex

api/core/model_runtime/model_providers/oci/oci.yaml (new file, 42 lines)
@@ -0,0 +1,42 @@
provider: oci
label:
  en_US: OCIGenerativeAI
description:
  en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
  zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from OCI
    zh_Hans: 从 OCI 获取 API Key
  url:
    en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm
supported_model_types:
  - llm
  - text-embedding
  #- rerank
configurate_methods:
  - predefined-model
  #- customizable-model
provider_credential_schema:
  credential_form_schemas:
    - variable: oci_config_content
      label:
        en_US: OCI API key config file's content
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')))
        en_US: Enter your OCI API key config file's content (base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')))
    - variable: oci_key_content
      label:
        en_US: OCI API key file's content
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
        en_US: Enter your OCI API key file's content (base64.b64encode("pem file content".encode('utf-8')))

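Both credential fields above are base64-encoded strings rather than paths to a config file. A minimal sketch of how a user could produce them from an existing OCI API key setup, matching the format the provider expects (the helper name and arguments are illustrative and not part of the provider):

import base64

def encode_oci_credentials(user_ocid: str, fingerprint: str, tenancy_ocid: str,
                           region: str, compartment_ocid: str, pem_path: str) -> tuple[str, str]:
    # oci_config_content: base64 of "user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid"
    config = "/".join([user_ocid, fingerprint, tenancy_ocid, region, compartment_ocid])
    oci_config_content = base64.b64encode(config.encode("utf-8")).decode("utf-8")
    # oci_key_content: base64 of the private key PEM file's content
    with open(pem_path, encoding="utf-8") as f:
        oci_key_content = base64.b64encode(f.read().encode("utf-8")).decode("utf-8")
    return oci_config_content, oci_key_content
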
@@ -0,0 +1,5 @@
- cohere.embed-english-light-v2.0
- cohere.embed-english-light-v3.0
- cohere.embed-english-v3.0
- cohere.embed-multilingual-light-v3.0
- cohere.embed-multilingual-v3.0

@@ -0,0 +1,9 @@
model: cohere.embed-english-light-v2.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD

@@ -0,0 +1,9 @@
model: cohere.embed-english-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD

@@ -0,0 +1,9 @@
model: cohere.embed-english-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD

@@ -0,0 +1,9 @@
model: cohere.embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD

@@ -0,0 +1,9 @@
model: cohere.embed-multilingual-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD

@@ -0,0 +1,242 @@
import base64
import copy
import time
from typing import Optional

import numpy as np
import oci

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

request_template = {
    "compartmentId": "",
    "servingMode": {
        "modelId": "cohere.embed-english-light-v3.0",
        "servingType": "ON_DEMAND"
    },
    "truncate": "NONE",
    "inputs": [""]
}
oci_config_template = {
    "user": "",
    "fingerprint": "",
    "tenancy": "",
    "region": "",
    "compartment_id": "",
    "key_content": ""
}


class OCITextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Cohere text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        for i, text in enumerate(texts):
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                # if num tokens is larger than context length, only use the start of the text
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                inputs.append(text[0: cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        for i in _iter:
            # call embedding model
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=inputs[i: i + max_chunks]
            )

            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: number of tokens (approximated with the GPT-2 tokenizer)
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of characters for given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: number of characters
        """
        characters = 0
        for text in texts:
            characters += len(text)
        return characters

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # call embedding model
            self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=['ping']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: embeddings and used character count
        """
        # initialize the OCI client from the provided credentials
        oci_config = copy.deepcopy(oci_config_template)
        if "oci_config_content" in credentials:
            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
            config_items = oci_config_content.split("/")
            if len(config_items) != 5:
                raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
            oci_config["user"] = config_items[0]
            oci_config["fingerprint"] = config_items[1]
            oci_config["tenancy"] = config_items[2]
            oci_config["region"] = config_items[3]
            oci_config["compartment_id"] = config_items[4]
        else:
            raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
        if "oci_key_content" in credentials:
            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
        else:
            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")
        # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
        compartment_id = oci_config["compartment_id"]
        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)

        # call embedding model
        request_args = copy.deepcopy(request_template)
        request_args["compartmentId"] = compartment_id
        request_args["servingMode"]["modelId"] = model
        request_args["inputs"] = texts
        response = client.embed_text(request_args)
        return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }
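
For reference, the embed_text call in _embedding_invoke sends a payload assembled from request_template; for a two-text batch against cohere.embed-multilingual-v3.0 it would look roughly like this (the compartment OCID is a placeholder). Note that the provider reports character count, not a true token count, as usage.

example_embed_request = {
    "compartmentId": "ocid1.compartment.oc1..example",  # placeholder
    "servingMode": {"modelId": "cohere.embed-multilingual-v3.0", "servingType": "ON_DEMAND"},
    "truncate": "NONE",
    "inputs": ["hello", "world"],
}
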
api/poetry.lock (generated, 915 lines changed; diff suppressed because it is too large)
@@ -190,6 +190,7 @@ zhipuai = "1.0.7"
azure-ai-ml = "^1.19.0"
azure-ai-inference = "^1.0.0b3"
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
oci = "^2.133.0"

[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"

api/tests/integration_tests/model_runtime/oci/test_llm.py (new file, 130 lines)
@@ -0,0 +1,130 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel


def test_validate_credentials():
    model = OCILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="cohere.command-r-plus",
            credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
        )

    model.validate_credentials(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
    )


def test_invoke_model():
    model = OCILargeLanguageModel()

    response = model.invoke(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 10},
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_invoke_stream_model():
    model = OCILargeLanguageModel()

    response = model.invoke(
        model="meta.llama-3-70b-instruct",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_model_with_function():
    model = OCILargeLanguageModel()

    response = model.invoke(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
        stream=False,
        user="abc-123",
        tools=[
            PromptMessageTool(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            )
        ],
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_get_num_tokens():
    model = OCILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 18

@@ -0,0 +1,20 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.oci import OCIGENAIProvider


def test_validate_provider_credentials():
    provider = OCIGENAIProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        }
    )

@@ -0,0 +1,58 @@
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.text_embedding.text_embedding import OCITextEmbeddingModel


def test_validate_credentials():
    model = OCITextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="cohere.embed-multilingual-v3.0",
            credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
        )

    model.validate_credentials(
        model="cohere.embed-multilingual-v3.0",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
    )


def test_invoke_model():
    model = OCITextEmbeddingModel()

    result = model.invoke(
        model="cohere.embed-multilingual-v3.0",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 4
    # assert result.usage.total_tokens == 811


def test_get_num_tokens():
    model = OCITextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="cohere.embed-multilingual-v3.0",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        texts=["hello", "world"],
    )

    assert num_tokens == 2