feat: gemini pro function call (#3406)

Yeuoly 2024-04-12 16:38:02 +08:00 committed by GitHub
parent 0737e930cb
commit a258a90291
4 changed files with 151 additions and 62 deletions

File 1 of 4 (model YAML, 1M-context Gemini model):

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 1048576

File 2 of 4 (model YAML, 30K-context Gemini model):

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 30720
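
The two added feature flags are what the surrounding runtime keys on: tool-call advertises that the model accepts function declarations, and stream-tool-call that tool calls may arrive mid-stream. A minimal sketch of how a caller might gate tool use on these flags (the helper below is an assumption for illustration, not part of this commit):

def supports_tools(features: list[str], streaming: bool) -> bool:
    # Hypothetical gating helper; Dify's actual logic is not shown in this diff.
    if streaming:
        return 'stream-tool-call' in features
    return 'tool-call' in features

assert supports_tools(['agent-thought', 'tool-call', 'stream-tool-call'], streaming=True)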

File 3 of 4 (Google LLM runtime, GoogleLargeLanguageModel):

@@ -1,7 +1,9 @@
+import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union

+import google.ai.generativelanguage as glm
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai.client as client
@@ -13,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
-    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -62,7 +64,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -95,6 +97,32 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

         return text.rstrip()

+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
+
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate model credentials
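
For reference, a sketch of what _convert_tools_to_glm_tool produces for a typical JSON-Schema-style tool (the weather tool below is invented for illustration, and llm is assumed to be a GoogleLargeLanguageModel instance):

# Hypothetical tool definition, for illustration only.
weather_tool = PromptMessageTool(
    name='get_weather',
    description='Get the current weather for a city',
    parameters={
        'type': 'object',
        'properties': {
            'city': {'type': 'string', 'description': 'City name'},
            'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']},
        },
        'required': ['city'],
    },
)

glm_tool = llm._convert_tools_to_glm_tool([weather_tool])
# One FunctionDeclaration named 'get_weather'; property types are upper-cased
# onto the proto's 'type_' key (e.g. 'string' -> 'STRING'), required=['city'].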
@@ -105,7 +133,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         try:
-            ping_message = PromptMessage(content="ping", role="system")
+            ping_message = SystemPromptMessage(content="ping")

             self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
         except Exception as ex:
@@ -114,8 +142,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None
+                  ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -153,7 +182,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             else:
                 history.append(content)

-
         # Create a new ClientManager with tenant's API key
         new_client_manager = client._ClientManager()
         new_client_manager.configure(api_key=credentials["google_api_key"])
@@ -174,7 +202,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 **config_kwargs
             ),
             stream=stream,
-            safety_settings=safety_settings
+            safety_settings=safety_settings,
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
         )

         if stream:
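
With tools now threaded from _invoke through _generate and into generate_content, an end-to-end call looks roughly like this (a sketch: the credential is a placeholder, weather_tool comes from the earlier sketch, and the invoke signature is assumed from the base LargeLanguageModel class):

llm = GoogleLargeLanguageModel()
result = llm.invoke(
    model='gemini-pro',
    credentials={'google_api_key': 'YOUR_API_KEY'},
    prompt_messages=[UserPromptMessage(content="What's the weather in Paris?")],
    model_parameters={'temperature': 0.2},
    tools=[weather_tool],  # converted to a glm.Tool inside _generate
    stream=False,
)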
@@ -228,43 +257,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         index = -1
         for chunk in response:
-            content = chunk.text
-            index += 1
-
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-
-            if not response._done:
-                # transform assistant message to prompt message
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
-                    )
-                )
-            else:
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.candidates[0].finish_reason,
-                        usage=usage
-                    )
-                )
+            for part in chunk.parts:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
+                )
+
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+
+                if not response._done:
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
+                    )
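
Each yielded chunk now carries either text or tool calls in its delta, so a consumer can branch on tool_calls to catch a function request mid-stream. A minimal consumer sketch (same assumed invoke signature and placeholders as above):

chunks = llm.invoke(
    model='gemini-pro',
    credentials={'google_api_key': 'YOUR_API_KEY'},
    prompt_messages=[UserPromptMessage(content="What's the weather in Paris?")],
    model_parameters={'temperature': 0.2},
    tools=[weather_tool],
    stream=True,
)
for chunk in chunks:
    message = chunk.delta.message
    if message.tool_calls:
        call = message.tool_calls[0]
        # `arguments` is the JSON string produced by json.dumps above
        print('tool requested:', call.function.name, call.function.arguments)
    elif message.content:
        print(message.content, end='')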
@@ -288,6 +335,8 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
             message_text = f"{ai_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
             message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
@ -300,26 +349,53 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:param message: one PromptMessage :param message: one PromptMessage
:return: glm Content representation of message :return: glm Content representation of message
""" """
if isinstance(message, UserPromptMessage):
parts = [] glm_content = {
if (isinstance(message.content, str)): "role": "user",
parts.append(to_part(message.content)) "parts": []
}
if (isinstance(message.content, str)):
glm_content['parts'].append(to_part(message.content))
else:
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content['parts'].append(to_part(c.data))
else:
metadata, data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]
blob = {"inline_data":{"mime_type":mime_type,"data":data}}
glm_content['parts'].append(blob)
return glm_content
elif isinstance(message, AssistantPromptMessage):
glm_content = {
"role": "model",
"parts": []
}
if message.content:
glm_content['parts'].append(to_part(message.content))
if message.tool_calls:
glm_content["parts"].append(to_part(glm.FunctionCall(
name=message.tool_calls[0].function.name,
args=json.loads(message.tool_calls[0].function.arguments),
)))
return glm_content
elif isinstance(message, SystemPromptMessage):
return {
"role": "user",
"parts": [to_part(message.content)]
}
elif isinstance(message, ToolPromptMessage):
return {
"role": "function",
"parts": [glm.Part(function_response=glm.FunctionResponse(
name=message.name,
response={
"response": message.content
}
))]
}
else: else:
for c in message.content: raise ValueError(f"Got unknown type {message}")
if c.type == PromptMessageContentType.TEXT:
parts.append(to_part(c.data))
else:
metadata, data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]
blob = {"inline_data":{"mime_type":mime_type,"data":data}}
parts.append(blob)
glm_content = {
"role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
"parts": parts
}
return glm_content
@property @property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
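
To close the loop, a tool's result is sent back to Gemini as a function-role message. A sketch of the ToolPromptMessage branch (values invented; the enclosing converter's name is assumed to be _format_message_to_glm_content, which sits outside this hunk):

tool_result = ToolPromptMessage(
    content='{"temp_c": 18, "condition": "cloudy"}',
    name='get_weather',
    tool_call_id='get_weather',
)
glm_content = llm._format_message_to_glm_content(tool_result)
# glm_content == {
#     "role": "function",
#     "parts": [glm.Part(function_response=glm.FunctionResponse(
#         name='get_weather',
#         response={"response": '{"temp_c": 18, "condition": "cloudy"}'}))]
# }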

File 4 of 4 (Google mock for tests):

@@ -10,6 +10,7 @@ from google.generativeai import GenerativeModel
 from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
+from google.ai.generativelanguage_v1beta.types import content as gag_content

 current_api_key = ''
@@ -43,6 +44,14 @@ class MockGoogleResponseClass(object):
 class MockGoogleResponseCandidateClass(object):
     finish_reason = 'stop'

+    @property
+    def content(self) -> gag_content.Content:
+        return gag_content.Content(
+            parts=[
+                gag_content.Part(text='it\'s google!')
+            ]
+        )
+
 class MockGoogleClass(object):
     @staticmethod
     def generate_content_sync() -> GenerateContentResponse:
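
The new content property matters because the streaming handler above iterates chunk.parts, which the google-generativeai SDK resolves (roughly) from the first candidate's content.parts; without it the mock would raise. A quick check under that simplifying assumption:

candidate = MockGoogleResponseCandidateClass()
assert candidate.content.parts[0].text == "it's google!"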