mirror of https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00

feat: gemini pro function call (#3406)

This commit is contained in:
parent 0737e930cb
commit a258a90291

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 1048576

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 30720

@@ -1,7 +1,9 @@
+import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union
 
+import google.ai.generativelanguage as glm
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai.client as client

@@ -13,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
-    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (

@@ -62,7 +64,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:

@@ -95,6 +97,32 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         return text.rstrip()
 
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
+
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate model credentials
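Note that the `properties` values above are plain dicts keyed `type_`, `description`, `enum`, which the proto-plus `glm.Schema` constructor accepts in place of nested `glm.Schema` objects. For readers unfamiliar with the glm types, a minimal standalone sketch of the same mapping, assuming a plain dict in place of dify's `PromptMessageTool` (the `get_weather` tool and its schema are illustrative, not from this commit):

import google.ai.generativelanguage as glm

# Hypothetical tool definition in JSON-schema style (not from this commit).
tool = {
    "name": "get_weather",
    "parameters": {
        "properties": {
            "city": {"type": "string", "description": "City name"},
        },
        "required": ["city"],
    },
}

glm_tool = glm.Tool(
    function_declarations=[
        glm.FunctionDeclaration(
            name=tool["name"],
            parameters=glm.Schema(
                type=glm.Type.OBJECT,
                properties={
                    key: glm.Schema(
                        # enum lookup by name, e.g. "STRING" -> glm.Type.STRING
                        type=glm.Type[value.get("type", "string").upper()],
                        description=value.get("description", ""),
                    )
                    for key, value in tool["parameters"]["properties"].items()
                },
                required=tool["parameters"].get("required", []),
            ),
        )
    ]
)
print(glm_tool)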

@@ -105,7 +133,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
 
         try:
-            ping_message = PromptMessage(content="ping", role="system")
+            ping_message = SystemPromptMessage(content="ping")
             self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
 
         except Exception as ex:

@@ -114,8 +142,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None
+                  ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
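Note the reordering: `tools` is inserted before `stop`, which keeps the positional call in `_invoke` above (`..., model_parameters, tools, stop, stream, user`) lined up with the new signature. For illustration, the equivalent keyword-argument call from inside the class (a fragment, assuming the method's context):

# Illustrative keyword form of the new signature; avoids positional mix-ups.
result = self._generate(
    model, credentials, prompt_messages, model_parameters,
    tools=tools, stop=stop, stream=stream, user=user,
)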

@@ -153,7 +182,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             else:
                 history.append(content)
 
-
         # Create a new ClientManager with tenant's API key
         new_client_manager = client._ClientManager()
         new_client_manager.configure(api_key=credentials["google_api_key"])

@@ -174,7 +202,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 **config_kwargs
             ),
             stream=stream,
-            safety_settings=safety_settings
+            safety_settings=safety_settings,
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
         )
 
         if stream:
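For context, the `tools` kwarg is forwarded to the SDK's `GenerativeModel.generate_content`, which accepts a `glm.Tool`. A minimal sketch of the resulting call shape against the `google-generativeai` SDK (the API key and prompt are placeholders; the one-function tool mirrors what `_convert_tools_to_glm_tool` produces):

import google.ai.generativelanguage as glm
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder key

# Minimal single-function tool, shaped like _convert_tools_to_glm_tool output.
glm_tool = glm.Tool(function_declarations=[
    glm.FunctionDeclaration(name="get_weather",
                            parameters=glm.Schema(type=glm.Type.OBJECT))
])

google_model = genai.GenerativeModel(model_name="gemini-pro")
response = google_model.generate_content(
    contents=[{"role": "user", "parts": [{"text": "Weather in Paris?"}]}],
    stream=False,
    tools=glm_tool,  # the diff passes a single glm.Tool here
)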

@@ -228,43 +257,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         index = -1
         for chunk in response:
-            content = chunk.text
-            index += 1
-
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-
-            if not response._done:
-
-                # transform assistant message to prompt message
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
-                    )
-                )
-            else:
-
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.candidates[0].finish_reason,
-                        usage=usage
-                    )
-                )
+            for part in chunk.parts:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
+                )
+
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+
+                if not response._done:
+
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
+                    )
 
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
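Two behavioral notes on the rewritten loop: iteration is now per `part` rather than per chunk (the chunk-level `.text` accessor can raise when a part carries a function call, so parts are inspected individually), and a function-call part is surfaced as a single tool call whose proto args are re-serialized to a JSON string. A trimmed, standalone sketch of just that extraction step, with plain dicts standing in for dify's `AssistantPromptMessage.ToolCall` types:

import json

def extract_part(part) -> dict:
    """Map one streamed glm part to a text/tool-call dict (illustrative)."""
    out = {"text": "", "tool_call": None}
    if part.text:
        out["text"] = part.text
    if part.function_call:
        out["tool_call"] = {
            "id": part.function_call.name,   # the diff reuses the name as the id
            "name": part.function_call.name,
            # args is a proto map; convert to a plain dict before json.dumps
            "arguments": json.dumps({k: v for k, v in part.function_call.args.items()}),
        }
    return out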

@@ -288,6 +335,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             message_text = f"{ai_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
             message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
 

@@ -300,26 +349,53 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :param message: one PromptMessage
         :return: glm Content representation of message
         """
-        parts = []
-        if (isinstance(message.content, str)):
-            parts.append(to_part(message.content))
-        else:
-            for c in message.content:
-                if c.type == PromptMessageContentType.TEXT:
-                    parts.append(to_part(c.data))
-                else:
-                    metadata, data = c.data.split(',', 1)
-                    mime_type = metadata.split(';', 1)[0].split(':')[1]
-                    blob = {"inline_data":{"mime_type":mime_type,"data":data}}
-                    parts.append(blob)
-
-        glm_content = {
-            "role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
-            "parts": parts
-        }
-
-        return glm_content
+        if isinstance(message, UserPromptMessage):
+            glm_content = {
+                "role": "user",
+                "parts": []
+            }
+            if (isinstance(message.content, str)):
+                glm_content['parts'].append(to_part(message.content))
+            else:
+                for c in message.content:
+                    if c.type == PromptMessageContentType.TEXT:
+                        glm_content['parts'].append(to_part(c.data))
+                    else:
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                        glm_content['parts'].append(blob)
+            return glm_content
+        elif isinstance(message, AssistantPromptMessage):
+            glm_content = {
+                "role": "model",
+                "parts": []
+            }
+            if message.content:
+                glm_content['parts'].append(to_part(message.content))
+            if message.tool_calls:
+                glm_content["parts"].append(to_part(glm.FunctionCall(
+                    name=message.tool_calls[0].function.name,
+                    args=json.loads(message.tool_calls[0].function.arguments),
+                )))
+            return glm_content
+        elif isinstance(message, SystemPromptMessage):
+            return {
+                "role": "user",
+                "parts": [to_part(message.content)]
+            }
+        elif isinstance(message, ToolPromptMessage):
+            return {
+                "role": "function",
+                "parts": [glm.Part(function_response=glm.FunctionResponse(
+                    name=message.name,
+                    response={
+                        "response": message.content
+                    }
+                ))]
+            }
+        else:
+            raise ValueError(f"Got unknown type {message}")
 
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
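The rewrite replaces the old role heuristic (user/system vs. everything else) with an explicit per-type mapping: user and system messages become role "user", assistant messages become role "model" (text and/or a `glm.FunctionCall` part), and tool results become role "function" wrapped in a `glm.FunctionResponse`. A sketch of the tool-result shape, where the tool name and payload are illustrative only:

import google.ai.generativelanguage as glm

# Hypothetical tool result being fed back to the model; "get_weather" and the
# response payload are illustrative, not from this commit.
tool_result_content = {
    "role": "function",
    "parts": [glm.Part(function_response=glm.FunctionResponse(
        name="get_weather",
        response={"response": '{"temp_c": 21}'},
    ))],
}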

@@ -10,6 +10,7 @@ from google.generativeai import GenerativeModel
 from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
+from google.ai.generativelanguage_v1beta.types import content as gag_content
 
 current_api_key = ''
 

@@ -43,6 +44,14 @@ class MockGoogleResponseClass(object):
 class MockGoogleResponseCandidateClass(object):
     finish_reason = 'stop'
 
+    @property
+    def content(self) -> gag_content.Content:
+        return gag_content.Content(
+            parts=[
+                gag_content.Part(text='it\'s google!')
+            ]
+        )
+
 class MockGoogleClass(object):
     @staticmethod
     def generate_content_sync() -> GenerateContentResponse:
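The mock candidate now exposes a `content` property built from the `google.ai.generativelanguage` types, so tests that walk `chunk.parts` keep working offline. A hedged sketch of how such a mock is typically wired up with pytest's monkeypatch (this fixture is illustrative, not necessarily the repo's actual fixture):

import pytest
from google.generativeai import GenerativeModel

@pytest.fixture
def mock_google(monkeypatch):
    # Illustrative: route the SDK's network call to the mock defined above.
    def _fake_generate_content(self, *args, **kwargs):
        return MockGoogleClass.generate_content_sync()
    monkeypatch.setattr(GenerativeModel, "generate_content", _fake_generate_content)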