feat: support spark v2 validate (#1086)

This commit is contained in:
takatost 2023-09-01 20:53:32 +08:00 committed by GitHub
parent 73c86ee6a0
commit a7cdb745c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider):
if 'api_secret' not in credentials: if 'api_secret' not in credentials:
raise CredentialsValidateFailedError('Spark api_secret must be provided.') raise CredentialsValidateFailedError('Spark api_secret must be provided.')
try: credential_kwargs = {
credential_kwargs = { 'app_id': credentials['app_id'],
'app_id': credentials['app_id'], 'api_key': credentials['api_key'],
'api_key': credentials['api_key'], 'api_secret': credentials['api_secret'],
'api_secret': credentials['api_secret'], }
}
try:
chat_llm = ChatSpark( chat_llm = ChatSpark(
model_name='spark-v2',
max_tokens=10, max_tokens=10,
temperature=0.01, temperature=0.01,
**credential_kwargs **credential_kwargs
@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages) chat_llm(messages)
except SparkError as ex: except SparkError as ex:
raise CredentialsValidateFailedError(str(ex)) # try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex: except Exception as ex:
logging.exception('Spark config validation failed') logging.exception('Spark config validation failed')
raise ex raise ex