mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)
parent 0b4902bdc2
commit e0da0744b5
@@ -8,7 +8,12 @@ from urllib.parse import urljoin
 
 import requests
 
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+    LLMMode,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -40,7 +45,9 @@ from core.model_runtime.errors.invoke import (
     InvokeServerUnavailableError,
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.__base.large_language_model import (
+    LargeLanguageModel,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -50,11 +57,17 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
     Model class for Ollama large language model.
     """
 
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                stream: bool = True, user: Optional[str] = None) \
-            -> Union[LLMResult, Generator]:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
@@ -75,11 +88,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             model_parameters=model_parameters,
             stop=stop,
             stream=stream,
-            user=user
+            user=user,
         )
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> int:
         """
         Get number of tokens for given prompt messages
 
@@ -100,10 +118,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             if isinstance(first_prompt_message.content, str):
                 text = first_prompt_message.content
             else:
-                text = ''
+                text = ""
                 for message_content in first_prompt_message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
-                        message_content = cast(TextPromptMessageContent, message_content)
+                        message_content = cast(
+                            TextPromptMessageContent, message_content
+                        )
                         text = message_content.data
                         break
             return self._get_num_tokens_by_gpt2(text)
@@ -121,19 +141,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                 model=model,
                 credentials=credentials,
                 prompt_messages=[UserPromptMessage(content="ping")],
-                model_parameters={
-                    'num_predict': 5
-                },
-                stream=False
+                model_parameters={"num_predict": 5},
+                stream=False,
             )
         except InvokeError as ex:
-            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
+            raise CredentialsValidateFailedError(
+                f"An error occurred during credentials validation: {ex.description}"
+            )
         except Exception as ex:
-            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+            raise CredentialsValidateFailedError(
+                f"An error occurred during credentials validation: {str(ex)}"
+            )
 
-    def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
-                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
 
@@ -146,76 +175,93 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        headers = {
-            'Content-Type': 'application/json'
-        }
+        headers = {"Content-Type": "application/json"}
 
-        endpoint_url = credentials['base_url']
-        if not endpoint_url.endswith('/'):
-            endpoint_url += '/'
+        endpoint_url = credentials["base_url"]
+        if not endpoint_url.endswith("/"):
+            endpoint_url += "/"
 
         # prepare the payload for a simple ping to the model
-        data = {
-            'model': model,
-            'stream': stream
-        }
+        data = {"model": model, "stream": stream}
 
-        if 'format' in model_parameters:
-            data['format'] = model_parameters['format']
-            del model_parameters['format']
+        if "format" in model_parameters:
+            data["format"] = model_parameters["format"]
+            del model_parameters["format"]
 
-        data['options'] = model_parameters or {}
+        if "keep_alive" in model_parameters:
+            data["keep_alive"] = model_parameters["keep_alive"]
+            del model_parameters["keep_alive"]
+
+        data["options"] = model_parameters or {}
 
         if stop:
-            data['stop'] = "\n".join(stop)
+            data["stop"] = "\n".join(stop)
 
-        completion_type = LLMMode.value_of(credentials['mode'])
+        completion_type = LLMMode.value_of(credentials["mode"])
 
         if completion_type is LLMMode.CHAT:
-            endpoint_url = urljoin(endpoint_url, 'api/chat')
-            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            endpoint_url = urljoin(endpoint_url, "api/chat")
+            data["messages"] = [
+                self._convert_prompt_message_to_dict(m) for m in prompt_messages
+            ]
         else:
-            endpoint_url = urljoin(endpoint_url, 'api/generate')
+            endpoint_url = urljoin(endpoint_url, "api/generate")
             first_prompt_message = prompt_messages[0]
             if isinstance(first_prompt_message, UserPromptMessage):
                 first_prompt_message = cast(UserPromptMessage, first_prompt_message)
                 if isinstance(first_prompt_message.content, str):
-                    data['prompt'] = first_prompt_message.content
+                    data["prompt"] = first_prompt_message.content
                 else:
-                    text = ''
+                    text = ""
                     images = []
                     for message_content in first_prompt_message.content:
                         if message_content.type == PromptMessageContentType.TEXT:
-                            message_content = cast(TextPromptMessageContent, message_content)
+                            message_content = cast(
+                                TextPromptMessageContent, message_content
+                            )
                             text = message_content.data
                         elif message_content.type == PromptMessageContentType.IMAGE:
-                            message_content = cast(ImagePromptMessageContent, message_content)
-                            image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+                            message_content = cast(
+                                ImagePromptMessageContent, message_content
+                            )
+                            image_data = re.sub(
+                                r"^data:image\/[a-zA-Z]+;base64,",
+                                "",
+                                message_content.data,
+                            )
                             images.append(image_data)
 
-                    data['prompt'] = text
-                    data['images'] = images
+                    data["prompt"] = text
+                    data["images"] = images
 
         # send a post request to validate the credentials
         response = requests.post(
-            endpoint_url,
-            headers=headers,
-            json=data,
-            timeout=(10, 300),
-            stream=stream
+            endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream
         )
 
         response.encoding = "utf-8"
         if response.status_code != 200:
-            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
+            raise InvokeError(
+                f"API request failed with status code {response.status_code}: {response.text}"
+            )
 
         if stream:
-            return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
+            return self._handle_generate_stream_response(
+                model, credentials, completion_type, response, prompt_messages
+            )
 
-        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+        return self._handle_generate_response(
+            model, credentials, completion_type, response, prompt_messages
+        )
 
-    def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
-                                  response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+        self,
+        model: str,
+        credentials: dict,
+        completion_type: LLMMode,
+        response: requests.Response,
+        prompt_messages: list[PromptMessage],
+    ) -> LLMResult:
         """
         Handle llm completion response
 
@@ -229,14 +275,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         response_json = response.json()
 
         if completion_type is LLMMode.CHAT:
-            message = response_json.get('message', {})
-            response_content = message.get('content', '')
+            message = response_json.get("message", {})
+            response_content = message.get("content", "")
         else:
-            response_content = response_json['response']
+            response_content = response_json["response"]
 
         assistant_message = AssistantPromptMessage(content=response_content)
 
-        if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
+        if "prompt_eval_count" in response_json and "eval_count" in response_json:
             # transform usage
             prompt_tokens = response_json["prompt_eval_count"]
             completion_tokens = response_json["eval_count"]
@@ -246,7 +292,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
 
         # transform usage
-        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+        usage = self._calc_response_usage(
+            model, credentials, prompt_tokens, completion_tokens
+        )
 
         # transform response
         result = LLMResult(
@@ -258,8 +306,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
-                                         response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        completion_type: LLMMode,
+        response: requests.Response,
+        prompt_messages: list[PromptMessage],
+    ) -> Generator:
         """
         Handle llm completion stream response
 
@@ -270,17 +324,20 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator result
         """
-        full_text = ''
+        full_text = ""
         chunk_index = 0
 
-        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-                -> LLMResultChunk:
+        def create_final_llm_result_chunk(
+            index: int, message: AssistantPromptMessage, finish_reason: str
+        ) -> LLMResultChunk:
             # calculate num tokens
             prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
             completion_tokens = self._get_num_tokens_by_gpt2(full_text)
 
             # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+            usage = self._calc_response_usage(
+                model, credentials, prompt_tokens, completion_tokens
+            )
 
             return LLMResultChunk(
                 model=model,
@@ -289,11 +346,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                     index=index,
                     message=message,
                     finish_reason=finish_reason,
-                    usage=usage
-                )
+                    usage=usage,
+                ),
             )
 
-        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
+        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"):
             if not chunk:
                 continue
 
@@ -304,7 +361,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                 yield create_final_llm_result_chunk(
                     index=chunk_index,
                     message=AssistantPromptMessage(content=""),
-                    finish_reason="Non-JSON encountered."
+                    finish_reason="Non-JSON encountered.",
                 )
 
                 chunk_index += 1
@@ -314,55 +371,57 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                 if not chunk_json:
                     continue
 
-                if 'message' not in chunk_json:
-                    text = ''
+                if "message" not in chunk_json:
+                    text = ""
                 else:
-                    text = chunk_json.get('message').get('content', '')
+                    text = chunk_json.get("message").get("content", "")
             else:
                 if not chunk_json:
                     continue
 
                 # transform assistant message to prompt message
-                text = chunk_json['response']
+                text = chunk_json["response"]
 
-            assistant_prompt_message = AssistantPromptMessage(
-                content=text
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=text)
 
             full_text += text
 
-            if chunk_json['done']:
+            if chunk_json["done"]:
                 # calculate num tokens
-                if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
+                if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json:
                     # transform usage
                     prompt_tokens = chunk_json["prompt_eval_count"]
                     completion_tokens = chunk_json["eval_count"]
                 else:
                     # calculate num tokens
-                    prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+                    prompt_tokens = self._get_num_tokens_by_gpt2(
+                        prompt_messages[0].content
+                    )
                     completion_tokens = self._get_num_tokens_by_gpt2(full_text)
 
                 # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(
+                    model, credentials, prompt_tokens, completion_tokens
+                )
 
                 yield LLMResultChunk(
-                    model=chunk_json['model'],
+                    model=chunk_json["model"],
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=chunk_index,
                         message=assistant_prompt_message,
-                        finish_reason='stop',
-                        usage=usage
-                    )
+                        finish_reason="stop",
+                        usage=usage,
+                    ),
                 )
             else:
                 yield LLMResultChunk(
-                    model=chunk_json['model'],
+                    model=chunk_json["model"],
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=chunk_index,
                         message=assistant_prompt_message,
-                    )
+                    ),
                 )
 
             chunk_index += 1
@@ -376,15 +435,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
-                text = ''
+                text = ""
                 images = []
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
-                        message_content = cast(TextPromptMessageContent, message_content)
+                        message_content = cast(
+                            TextPromptMessageContent, message_content
+                        )
                         text = message_content.data
                     elif message_content.type == PromptMessageContentType.IMAGE:
-                        message_content = cast(ImagePromptMessageContent, message_content)
-                        image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+                        message_content = cast(
+                            ImagePromptMessageContent, message_content
+                        )
+                        image_data = re.sub(
+                            r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
+                        )
                         images.append(image_data)
 
                 message_dict = {"role": "user", "content": text, "images": images}
@@ -414,7 +479,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+    def get_customizable_model_schema(
+        self, model: str, credentials: dict
+    ) -> AIModelEntity:
         """
         Get customizable model schema.
 
@@ -425,20 +492,19 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         """
         extras = {}
 
-        if 'vision_support' in credentials and credentials['vision_support'] == 'true':
-            extras['features'] = [ModelFeature.VISION]
+        if "vision_support" in credentials and credentials["vision_support"] == "true":
+            extras["features"] = [ModelFeature.VISION]
 
         entity = AIModelEntity(
             model=model,
-            label=I18nObject(
-                zh_Hans=model,
-                en_US=model
-            ),
+            label=I18nObject(zh_Hans=model, en_US=model),
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.MODE: credentials.get('mode'),
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
+                ModelPropertyKey.MODE: credentials.get("mode"),
+                ModelPropertyKey.CONTEXT_SIZE: int(
+                    credentials.get("context_size", 4096)
+                ),
             },
             parameter_rules=[
                 ParameterRule(
@@ -446,91 +512,111 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                     use_template=DefaultParameterName.TEMPERATURE.value,
                     label=I18nObject(en_US="Temperature"),
                     type=ParameterType.FLOAT,
-                    help=I18nObject(en_US="The temperature of the model. "
-                                          "Increasing the temperature will make the model answer "
-                                          "more creatively. (Default: 0.8)"),
+                    help=I18nObject(
+                        en_US="The temperature of the model. "
+                        "Increasing the temperature will make the model answer "
+                        "more creatively. (Default: 0.8)"
+                    ),
                     default=0.1,
                     min=0,
-                    max=1
+                    max=1,
                 ),
                 ParameterRule(
                     name=DefaultParameterName.TOP_P.value,
                     use_template=DefaultParameterName.TOP_P.value,
                     label=I18nObject(en_US="Top P"),
                     type=ParameterType.FLOAT,
-                    help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
-                                          "more diverse text, while a lower value (e.g., 0.5) will generate more "
-                                          "focused and conservative text. (Default: 0.9)"),
+                    help=I18nObject(
+                        en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
+                        "more diverse text, while a lower value (e.g., 0.5) will generate more "
+                        "focused and conservative text. (Default: 0.9)"
+                    ),
                     default=0.9,
                     min=0,
-                    max=1
+                    max=1,
                 ),
                 ParameterRule(
                     name="top_k",
                     label=I18nObject(en_US="Top K"),
                     type=ParameterType.INT,
-                    help=I18nObject(en_US="Reduces the probability of generating nonsense. "
-                                          "A higher value (e.g. 100) will give more diverse answers, "
-                                          "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
+                    help=I18nObject(
+                        en_US="Reduces the probability of generating nonsense. "
+                        "A higher value (e.g. 100) will give more diverse answers, "
+                        "while a lower value (e.g. 10) will be more conservative. (Default: 40)"
+                    ),
                     min=1,
-                    max=100
+                    max=100,
                 ),
                 ParameterRule(
-                    name='repeat_penalty',
+                    name="repeat_penalty",
                     label=I18nObject(en_US="Repeat Penalty"),
                     type=ParameterType.FLOAT,
-                    help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
-                                          "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
-                                          "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
+                    help=I18nObject(
+                        en_US="Sets how strongly to penalize repetitions. "
+                        "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
+                        "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
+                    ),
                     min=-2,
-                    max=2
+                    max=2,
                 ),
                 ParameterRule(
-                    name='num_predict',
-                    use_template='max_tokens',
+                    name="num_predict",
+                    use_template="max_tokens",
                     label=I18nObject(en_US="Num Predict"),
                     type=ParameterType.INT,
-                    help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
-                                          "(Default: 128, -1 = infinite generation, -2 = fill context)"),
-                    default=512 if int(credentials.get('max_tokens', 4096)) >= 768 else 128,
+                    help=I18nObject(
+                        en_US="Maximum number of tokens to predict when generating text. "
+                        "(Default: 128, -1 = infinite generation, -2 = fill context)"
+                    ),
+                    default=(
+                        512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128
+                    ),
                     min=-2,
-                    max=int(credentials.get('max_tokens', 4096)),
+                    max=int(credentials.get("max_tokens", 4096)),
                 ),
                 ParameterRule(
-                    name='mirostat',
+                    name="mirostat",
                     label=I18nObject(en_US="Mirostat sampling"),
                     type=ParameterType.INT,
-                    help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
-                                          "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
+                    help=I18nObject(
+                        en_US="Enable Mirostat sampling for controlling perplexity. "
+                        "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
+                    ),
                     min=0,
-                    max=2
+                    max=2,
                 ),
                 ParameterRule(
-                    name='mirostat_eta',
+                    name="mirostat_eta",
                     label=I18nObject(en_US="Mirostat Eta"),
                     type=ParameterType.FLOAT,
-                    help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
-                                          "the generated text. A lower learning rate will result in slower adjustments, "
-                                          "while a higher learning rate will make the algorithm more responsive. "
-                                          "(Default: 0.1)"),
-                    precision=1
+                    help=I18nObject(
+                        en_US="Influences how quickly the algorithm responds to feedback from "
+                        "the generated text. A lower learning rate will result in slower adjustments, "
+                        "while a higher learning rate will make the algorithm more responsive. "
+                        "(Default: 0.1)"
+                    ),
+                    precision=1,
                 ),
                 ParameterRule(
-                    name='mirostat_tau',
+                    name="mirostat_tau",
                     label=I18nObject(en_US="Mirostat Tau"),
                     type=ParameterType.FLOAT,
-                    help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
-                                          "A lower value will result in more focused and coherent text. (Default: 5.0)"),
-                    precision=1
+                    help=I18nObject(
+                        en_US="Controls the balance between coherence and diversity of the output. "
+                        "A lower value will result in more focused and coherent text. (Default: 5.0)"
+                    ),
+                    precision=1,
                 ),
                 ParameterRule(
-                    name='num_ctx',
+                    name="num_ctx",
                     label=I18nObject(en_US="Size of context window"),
                     type=ParameterType.INT,
-                    help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
-                                          "(Default: 2048)"),
+                    help=I18nObject(
+                        en_US="Sets the size of the context window used to generate the next token. "
+                        "(Default: 2048)"
+                    ),
                     default=2048,
-                    min=1
+                    min=1,
                 ),
                 ParameterRule(
                     name='num_gpu',
@@ -544,56 +630,77 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                     default=1
                 ),
                 ParameterRule(
-                    name='num_thread',
+                    name="num_thread",
                     label=I18nObject(en_US="Num Thread"),
                     type=ParameterType.INT,
-                    help=I18nObject(en_US="Sets the number of threads to use during computation. "
-                                          "By default, Ollama will detect this for optimal performance. "
-                                          "It is recommended to set this value to the number of physical CPU cores "
-                                          "your system has (as opposed to the logical number of cores)."),
+                    help=I18nObject(
+                        en_US="Sets the number of threads to use during computation. "
+                        "By default, Ollama will detect this for optimal performance. "
+                        "It is recommended to set this value to the number of physical CPU cores "
+                        "your system has (as opposed to the logical number of cores)."
+                    ),
                     min=1,
                 ),
                 ParameterRule(
-                    name='repeat_last_n',
+                    name="repeat_last_n",
                     label=I18nObject(en_US="Repeat last N"),
                     type=ParameterType.INT,
-                    help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
-                                          "(Default: 64, 0 = disabled, -1 = num_ctx)"),
-                    min=-1
+                    help=I18nObject(
+                        en_US="Sets how far back for the model to look back to prevent repetition. "
+                        "(Default: 64, 0 = disabled, -1 = num_ctx)"
+                    ),
+                    min=-1,
                 ),
                 ParameterRule(
-                    name='tfs_z',
+                    name="tfs_z",
                     label=I18nObject(en_US="TFS Z"),
                     type=ParameterType.FLOAT,
-                    help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
-                                          "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
-                                          "while a value of 1.0 disables this setting. (default: 1)"),
-                    precision=1
+                    help=I18nObject(
+                        en_US="Tail free sampling is used to reduce the impact of less probable tokens "
+                        "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
+                        "while a value of 1.0 disables this setting. (default: 1)"
+                    ),
+                    precision=1,
                 ),
                 ParameterRule(
-                    name='seed',
+                    name="seed",
                     label=I18nObject(en_US="Seed"),
                     type=ParameterType.INT,
-                    help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
-                                          "a specific number will make the model generate the same text for "
-                                          "the same prompt. (Default: 0)"),
+                    help=I18nObject(
+                        en_US="Sets the random number seed to use for generation. Setting this to "
+                        "a specific number will make the model generate the same text for "
+                        "the same prompt. (Default: 0)"
+                    ),
                 ),
                 ParameterRule(
-                    name='format',
+                    name="keep_alive",
+                    label=I18nObject(en_US="Keep Alive"),
+                    type=ParameterType.STRING,
+                    help=I18nObject(
+                        en_US="Sets how long the model is kept in memory after generating a response. "
+                        "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). "
+                        "A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. "
+                        "Valid time units are 's','m','h'. (Default: 5m)"
+                    ),
+                ),
+                ParameterRule(
+                    name="format",
                     label=I18nObject(en_US="Format"),
                     type=ParameterType.STRING,
-                    help=I18nObject(en_US="the format to return a response in."
-                                          " Currently the only accepted value is json."),
-                    options=['json'],
-                )
+                    help=I18nObject(
+                        en_US="the format to return a response in."
+                        " Currently the only accepted value is json."
+                    ),
+                    options=["json"],
+                ),
             ],
             pricing=PriceConfig(
-                input=Decimal(credentials.get('input_price', 0)),
-                output=Decimal(credentials.get('output_price', 0)),
-                unit=Decimal(credentials.get('unit', 0)),
-                currency=credentials.get('currency', "USD")
+                input=Decimal(credentials.get("input_price", 0)),
+                output=Decimal(credentials.get("output_price", 0)),
+                unit=Decimal(credentials.get("unit", 0)),
+                currency=credentials.get("currency", "USD"),
             ),
-            **extras
+            **extras,
         )
 
         return entity
@@ -621,10 +728,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             ],
             InvokeServerUnavailableError: [
                 requests.exceptions.ConnectionError,  # Engine Overloaded
-                requests.exceptions.HTTPError  # Server Error
+                requests.exceptions.HTTPError,  # Server Error
             ],
             InvokeConnectionError: [
                 requests.exceptions.ConnectTimeout,  # Timeout
-                requests.exceptions.ReadTimeout  # Timeout
-            ]
+                requests.exceptions.ReadTimeout,  # Timeout
+            ],
         }
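
Note on the behavioral part of this diff (as opposed to the quote and line-wrapping reformat): `keep_alive` is now lifted out of `model_parameters` and sent as a top-level field of the Ollama request body, alongside `format`, while the remaining parameters still go into `options`. Below is a minimal sketch of that payload routing; `build_payload` is a hypothetical helper for illustration, not part of the dify codebase, which does this inline in `_generate`.

# Sketch only: mirrors the payload handling in the _generate hunk above.
def build_payload(model: str, model_parameters: dict, stream: bool = True) -> dict:
    params = dict(model_parameters)  # work on a copy instead of mutating the caller's dict
    data = {"model": model, "stream": stream}

    # "format" and "keep_alive" are top-level Ollama request fields, not sampling options
    if "format" in params:
        data["format"] = params.pop("format")
    if "keep_alive" in params:
        data["keep_alive"] = params.pop("keep_alive")

    # everything left (temperature, top_p, num_ctx, ...) is passed through as "options"
    data["options"] = params or {}
    return data


# Example: keep the model loaded for 10 minutes after the request completes.
print(build_payload("llama2", {"temperature": 0.8, "keep_alive": "10m"}))
# {'model': 'llama2', 'stream': True, 'keep_alive': '10m', 'options': {'temperature': 0.8}}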