mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
feat: add xinference llm context size (#2336)
This commit is contained in:
parent
cfbb7bec58
commit
0c330fc020
|
@ -75,6 +75,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||||
if extra_param.support_function_call:
|
if extra_param.support_function_call:
|
||||||
credentials['support_function_call'] = True
|
credentials['support_function_call'] = True
|
||||||
|
|
||||||
|
if extra_param.context_length:
|
||||||
|
credentials['context_length'] = extra_param.context_length
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
|
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
|
@ -296,6 +299,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||||
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
|
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
|
||||||
|
|
||||||
support_function_call = credentials.get('support_function_call', False)
|
support_function_call = credentials.get('support_function_call', False)
|
||||||
|
context_length = credentials.get('context_length', 2048)
|
||||||
|
|
||||||
entity = AIModelEntity(
|
entity = AIModelEntity(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -309,6 +313,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||||
] if support_function_call else [],
|
] if support_function_call else [],
|
||||||
model_properties={
|
model_properties={
|
||||||
ModelPropertyKey.MODE: completion_type,
|
ModelPropertyKey.MODE: completion_type,
|
||||||
|
ModelPropertyKey.CONTEXT_SIZE: context_length
|
||||||
},
|
},
|
||||||
parameter_rules=rules
|
parameter_rules=rules
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,15 +14,17 @@ class XinferenceModelExtraParameter(object):
|
||||||
model_handle_type: str
|
model_handle_type: str
|
||||||
model_ability: List[str]
|
model_ability: List[str]
|
||||||
max_tokens: int = 512
|
max_tokens: int = 512
|
||||||
|
context_length: int = 2048
|
||||||
support_function_call: bool = False
|
support_function_call: bool = False
|
||||||
|
|
||||||
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
|
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
|
||||||
support_function_call: bool, max_tokens: int) -> None:
|
support_function_call: bool, max_tokens: int, context_length: int) -> None:
|
||||||
self.model_format = model_format
|
self.model_format = model_format
|
||||||
self.model_handle_type = model_handle_type
|
self.model_handle_type = model_handle_type
|
||||||
self.model_ability = model_ability
|
self.model_ability = model_ability
|
||||||
self.support_function_call = support_function_call
|
self.support_function_call = support_function_call
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
self.context_length = context_length
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
cache_lock = Lock()
|
cache_lock = Lock()
|
||||||
|
@ -57,7 +59,7 @@ class XinferenceHelper:
|
||||||
|
|
||||||
url = path.join(server_url, 'v1/models', model_uid)
|
url = path.join(server_url, 'v1/models', model_uid)
|
||||||
|
|
||||||
# this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
|
# this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
|
||||||
session = Session()
|
session = Session()
|
||||||
session.mount('http://', HTTPAdapter(max_retries=3))
|
session.mount('http://', HTTPAdapter(max_retries=3))
|
||||||
session.mount('https://', HTTPAdapter(max_retries=3))
|
session.mount('https://', HTTPAdapter(max_retries=3))
|
||||||
|
@ -88,11 +90,14 @@ class XinferenceHelper:
|
||||||
|
|
||||||
support_function_call = 'tools' in model_ability
|
support_function_call = 'tools' in model_ability
|
||||||
max_tokens = response_json.get('max_tokens', 512)
|
max_tokens = response_json.get('max_tokens', 512)
|
||||||
|
|
||||||
|
context_length = response_json.get('context_length', 2048)
|
||||||
|
|
||||||
return XinferenceModelExtraParameter(
|
return XinferenceModelExtraParameter(
|
||||||
model_format=model_format,
|
model_format=model_format,
|
||||||
model_handle_type=model_handle_type,
|
model_handle_type=model_handle_type,
|
||||||
model_ability=model_ability,
|
model_ability=model_ability,
|
||||||
support_function_call=support_function_call,
|
support_function_call=support_function_call,
|
||||||
max_tokens=max_tokens
|
max_tokens=max_tokens,
|
||||||
|
context_length=context_length
|
||||||
)
|
)
|
Loading…
Reference in New Issue
Block a user