From a7cdb745c112cc7a467e73bdd8c4614b6e964ea1 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 1 Sep 2023 20:53:32 +0800 Subject: [PATCH] feat: support spark v2 validate (#1086) --- .../providers/spark_provider.py | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py index b55ea77f4d..9a3e3643a0 100644 --- a/api/core/model_providers/providers/spark_provider.py +++ b/api/core/model_providers/providers/spark_provider.py @@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider): if 'api_secret' not in credentials: raise CredentialsValidateFailedError('Spark api_secret must be provided.') - try: - credential_kwargs = { - 'app_id': credentials['app_id'], - 'api_key': credentials['api_key'], - 'api_secret': credentials['api_secret'], - } + credential_kwargs = { + 'app_id': credentials['app_id'], + 'api_key': credentials['api_key'], + 'api_secret': credentials['api_secret'], + } + try: chat_llm = ChatSpark( + model_name='spark-v2', max_tokens=10, temperature=0.01, **credential_kwargs @@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider): chat_llm(messages) except SparkError as ex: - raise CredentialsValidateFailedError(str(ex)) + # try spark v1.5 if v2.1 failed + try: + chat_llm = ChatSpark( + model_name='spark', + max_tokens=10, + temperature=0.01, + **credential_kwargs + ) + + messages = [ + HumanMessage( + content="ping" + ) + ] + + chat_llm(messages) + except SparkError as ex: + raise CredentialsValidateFailedError(str(ex)) + except Exception as ex: + logging.exception('Spark config validation failed') + raise ex except Exception as ex: logging.exception('Spark config validation failed') raise ex