diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
index dd25037d34..17b85862c9 100644
--- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
@@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel):
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
 
-        # initialize client
-        client = Client(
-            base_url=credentials['server_url']
-        )
-
-        xinference_client = client.get_model(model_uid=credentials['model_uid'])
-
-        if not isinstance(xinference_client, RESTfulRerankModelHandle):
-            raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
-
-        response = xinference_client.rerank(
+        handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
+        response = handle.rerank(
             documents=docs,
             query=query,
             top_n=top_n,
@@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel):
         try:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+            if credentials['server_url'].endswith('/'):
+                credentials['server_url'] = credentials['server_url'][:-1]
+
+            # initialize client
+            client = Client(
+                base_url=credentials['server_url']
+            )
+
+            xinference_client = client.get_model(model_uid=credentials['model_uid'])
+
+            if not isinstance(xinference_client, RESTfulRerankModelHandle):
+                raise InvokeBadRequestError(
+                    'please check model type, the model you want to invoke is not a rerank model')
 
             self.invoke(
                 model=model,
@@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel):
             parameter_rules=[]
         )
 
-        return entity
\ No newline at end of file
+        return entity
diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
index 32d2b1516d..e8429cecd4 100644
--- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
@@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         if server_url.endswith('/'):
             server_url = server_url[:-1]
 
-        client = Client(base_url=server_url)
-
-        try:
-            handle = client.get_model(model_uid=model_uid)
-        except RuntimeError as e:
-            raise InvokeAuthorizationError(e)
-
-        if not isinstance(handle, RESTfulEmbeddingModelHandle):
-            raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
-
         try:
+            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
             embeddings = handle.create_embedding(input=texts)
         except RuntimeError as e:
             raise InvokeServerUnavailableError(e)
@@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
 
             if extra_args.max_tokens:
                 credentials['max_tokens'] = extra_args.max_tokens
+            if server_url.endswith('/'):
+                server_url = server_url[:-1]
+
+            client = Client(base_url=server_url)
+
+            try:
+                handle = client.get_model(model_uid=model_uid)
+            except RuntimeError as e:
+                raise InvokeAuthorizationError(e)
+
+            if not isinstance(handle, RESTfulEmbeddingModelHandle):
+                raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
 
             self._invoke(model=model, credentials=credentials, texts=['ping'])
         except InvokeAuthorizationError as e:
@@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             parameter_rules=[]
        )
 
-        return entity
\ No newline at end of file
+        return entity
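
Net effect of both files: the hot path (_invoke) now constructs the RESTful handle directly from the stored server_url/model_uid instead of resolving it through Client.get_model() on every request, while the get_model() round-trip and the isinstance type check move into validate_credentials(), where they run only once. Below is a minimal sketch of that pattern for the rerank side, using only the constructor and call signatures visible in the hunks above; the import path and the two helper function names (rerank_fast, check_rerank_credentials) are illustrative and not part of the patch.

from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle


def rerank_fast(credentials: dict, query: str, docs: list, top_n: int) -> dict:
    # Hot path (mirrors the new _invoke): build the handle directly from the
    # credentials, skipping the extra HTTP round-trip of Client.get_model().
    server_url = credentials['server_url']
    if server_url.endswith('/'):
        server_url = server_url[:-1]
    handle = RESTfulRerankModelHandle(credentials['model_uid'], server_url, auth_headers={})
    return handle.rerank(documents=docs, query=query, top_n=top_n)


def check_rerank_credentials(credentials: dict) -> None:
    # Validation path (mirrors the new validate_credentials): the slower
    # get_model() lookup and the type check happen only when credentials are saved.
    client = Client(base_url=credentials['server_url'])
    model = client.get_model(model_uid=credentials['model_uid'])
    if not isinstance(model, RESTfulRerankModelHandle):
        raise ValueError('please check model type, the model you want to invoke is not a rerank model')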