fix(model_providers/vertex_ai): Vertex AI Anthropic models authentication failed (#4971)

This commit is contained in:
Pan, Wen-Ming 2024-06-14 01:34:31 +08:00 committed by GitHub
parent f976740b57
commit f13af5a811
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,7 @@ from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
import google.api_core.exceptions as exceptions import google.api_core.exceptions as exceptions
import google.auth.transport.requests
import vertexai.generative_models as glm import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream from anthropic import AnthropicVertex, Stream
from anthropic.types import ( from anthropic.types import (
@ -44,15 +45,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class VertexAiLargeLanguageModel(LargeLanguageModel): class VertexAiLargeLanguageModel(LargeLanguageModel):
@ -95,17 +87,37 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
""" """
# use Anthropic official SDK references # use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python # - https://github.com/anthropics/anthropic-sdk-python
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
project_id = credentials["vertex_project_id"] project_id = credentials["vertex_project_id"]
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
token = ''
# get access token from service account credential
if service_account_info:
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
token = credentials.token
# Vertex AI Anthropic Claude3 Opus model avaiable in us-east5 region, Sonnet and Haiku avaiable in us-central1 region
if 'opus' in model: if 'opus' in model:
location = 'us-east5' location = 'us-east5'
else: else:
location = 'us-central1' location = 'us-central1'
client = AnthropicVertex( # use access token to authenticate
region=location, if token:
project_id=project_id client = AnthropicVertex(
) region=location,
project_id=project_id,
access_token=token
)
# When access token is empty, try to use the Google Cloud VM's built-in service account or the GOOGLE_APPLICATION_CREDENTIALS environment variable
else:
client = AnthropicVertex(
region=location,
project_id=project_id,
)
extra_model_kwargs = {} extra_model_kwargs = {}
if stop: if stop:
@ -462,7 +474,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
aiplatform.init(project=project_id, location=location) aiplatform.init(project=project_id, location=location)
history = [] history = []
system_instruction = GEMINI_BLOCK_MODE_PROMPT system_instruction = ""
# hack for gemini-pro-vision, which currently does not support multi-turn chat # hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-1.0-pro-vision-001": if model == "gemini-1.0-pro-vision-001":
last_msg = prompt_messages[-1] last_msg = prompt_messages[-1]