feat: add xinference llm context size (#2336)

Author: Yeuoly, 2024-02-01 17:10:45 +08:00 (committed via GitHub)
parent cfbb7bec58
commit 0c330fc020
2 changed files with 13 additions and 3 deletions


@@ -75,6 +75,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if extra_param.support_function_call:
                 credentials['support_function_call'] = True
 
+            if extra_param.context_length:
+                credentials['context_length'] = extra_param.context_length
+
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
         except KeyError as e:
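Note the truthiness check: a missing or zero context_length is simply not written into the credentials, so the schema code further down falls back to its 2048 default. A minimal sketch of that behavior, assuming an extra_param shaped like the XinferenceModelExtraParameter class changed below (apply_extra_param is a hypothetical wrapper, not code from this commit):

    def apply_extra_param(credentials: dict, extra_param) -> None:
        # Only record values the server actually reported; 0/None is skipped,
        # so credentials.get('context_length', 2048) later picks the default.
        if extra_param.support_function_call:
            credentials['support_function_call'] = True
        if extra_param.context_length:
            credentials['context_length'] = extra_param.context_length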
@@ -296,6 +299,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
         support_function_call = credentials.get('support_function_call', False)
+        context_length = credentials.get('context_length', 2048)
 
         entity = AIModelEntity(
             model=model,
@@ -309,6 +313,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ] if support_function_call else [],
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
+                ModelPropertyKey.CONTEXT_SIZE: context_length
             },
             parameter_rules=rules
         )
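With CONTEXT_SIZE now populated, downstream code can recover the context window from the entity. A minimal sketch, assuming model_properties is the plain dict built above and that max_tokens is known (this consumer code is illustrative, not part of the commit):

    # Hypothetical consumer: read the context window back off the schema.
    context_size = entity.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, 2048)
    prompt_budget = context_size - max_tokens  # tokens left for the prompt itself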


@@ -14,15 +14,17 @@ class XinferenceModelExtraParameter(object):
     model_handle_type: str
     model_ability: List[str]
     max_tokens: int = 512
+    context_length: int = 2048
     support_function_call: bool = False
 
     def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
-                 support_function_call: bool, max_tokens: int) -> None:
+                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
         self.support_function_call = support_function_call
         self.max_tokens = max_tokens
+        self.context_length = context_length
 
 cache = {}
 cache_lock = Lock()
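The module-level cache and cache_lock above imply a lock-guarded lookup around the HTTP fetch. The actual lookup body is not shown in this diff; a generic sketch of the pattern, with _fetch_extra_parameter as a hypothetical stand-in for the real fetch:

    from threading import Lock

    cache = {}
    cache_lock = Lock()

    def get_cached_extra_parameter(server_url: str, model_uid: str):
        key = f'{server_url}/{model_uid}'
        with cache_lock:
            if key not in cache:
                # _fetch_extra_parameter: hypothetical stand-in for the HTTP fetch
                cache[key] = _fetch_extra_parameter(server_url, model_uid)
            return cache[key]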
@@ -57,7 +59,7 @@
         url = path.join(server_url, 'v1/models', model_uid)
 
-        # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
+        # this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
         session = Session()
         session.mount('http://', HTTPAdapter(max_retries=3))
         session.mount('https://', HTTPAdapter(max_retries=3))
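The retry setup is stock requests usage: mounting an HTTPAdapter with max_retries bounds connection retries instead of letting a request inside the lock hang indefinitely. A self-contained sketch (the endpoint and timeout are illustrative assumptions, not from this diff):

    from requests import Session
    from requests.adapters import HTTPAdapter

    session = Session()
    # Retry transient connection failures up to 3 times.
    session.mount('http://', HTTPAdapter(max_retries=3))
    session.mount('https://', HTTPAdapter(max_retries=3))

    url = 'http://127.0.0.1:9997/v1/models/my-model-uid'  # example endpoint
    response = session.get(url, timeout=10)               # timeout added for illustration
    response.raise_for_status()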
@@ -88,11 +90,14 @@ class XinferenceHelper:
         support_function_call = 'tools' in model_ability
         max_tokens = response_json.get('max_tokens', 512)
+
+        context_length = response_json.get('context_length', 2048)
 
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
             model_ability=model_ability,
             support_function_call=support_function_call,
-            max_tokens=max_tokens
+            max_tokens=max_tokens,
+            context_length=context_length
         )
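End to end, the helper maps a few fields of the /v1/models/<uid> response onto the parameter object. A sketch with an assumed, simplified payload (the field values and model_handle_type are illustrative; only the keys read above are grounded in the diff):

    # Assumed, simplified model-info payload from the Xinference server.
    response_json = {
        'model_format': 'pytorch',
        'model_ability': ['generate', 'chat', 'tools'],
        'max_tokens': 512,
        'context_length': 4096,
    }

    model_ability = response_json.get('model_ability', [])
    extra = XinferenceModelExtraParameter(
        model_format=response_json.get('model_format', ''),
        model_handle_type='chat',  # how this is derived is not shown in the diff
        model_ability=model_ability,
        support_function_call='tools' in model_ability,
        max_tokens=response_json.get('max_tokens', 512),
        context_length=response_json.get('context_length', 2048),
    )
    assert extra.context_length == 4096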