mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
enhance: speed up xinference embedding & rerank (#3587)
This commit is contained in:
parent
b4d2d635f7
commit
4365843c20
|
@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel):
|
|||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
|
||||
|
||||
response = xinference_client.rerank(
|
||||
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
|
||||
response = handle.rerank(
|
||||
documents=docs,
|
||||
query=query,
|
||||
top_n=top_n,
|
||||
|
@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel):
|
|||
try:
|
||||
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||
raise InvokeBadRequestError(
|
||||
'please check model type, the model you want to invoke is not a rerank model')
|
||||
|
||||
self.invoke(
|
||||
model=model,
|
||||
|
@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel):
|
|||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
return entity
|
||||
|
|
|
@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
raise InvokeAuthorizationError(e)
|
||||
|
||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
||||
|
||||
try:
|
||||
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
||||
embeddings = handle.create_embedding(input=texts)
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(e)
|
||||
|
@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|||
|
||||
if extra_args.max_tokens:
|
||||
credentials['max_tokens'] = extra_args.max_tokens
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
raise InvokeAuthorizationError(e)
|
||||
|
||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
||||
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
except InvokeAuthorizationError as e:
|
||||
|
@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
return entity
|
||||
|
|
Loading…
Reference in New Issue
Block a user