enhance:speedup xinference embedding & rerank (#3587)

This commit is contained in:
呆萌闷油瓶 2024-04-18 16:54:00 +08:00 committed by GitHub
parent b4d2d635f7
commit 4365843c20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 23 deletions

View File

@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel):
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulRerankModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
response = xinference_client.rerank(
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
response = handle.rerank(
documents=docs,
query=query,
top_n=top_n,
@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel):
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulRerankModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a rerank model')
self.invoke(
model=model,
@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel):
parameter_rules=[]
)
return entity
return entity

View File

@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url)
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
raise InvokeAuthorizationError(e)
if not isinstance(handle, RESTfulEmbeddingModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
try:
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
embeddings = handle.create_embedding(input=texts)
except RuntimeError as e:
raise InvokeServerUnavailableError(e)
@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url)
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
raise InvokeAuthorizationError(e)
if not isinstance(handle, RESTfulEmbeddingModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError as e:
@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
parameter_rules=[]
)
return entity
return entity