refactor(parameter_extractor): implement custom error classes (#10260)

This commit is contained in:
-LAN- 2024-11-05 09:27:51 +08:00 committed by Joel
parent 65a04ee0be
commit c5422af400
2 changed files with 86 additions and 21 deletions

View File

@ -0,0 +1,50 @@
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
class InvalidModelTypeError(ParameterExtractorNodeError):
"""Raised when the model is not a Large Language Model."""
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
"""Raised when the model schema is not found."""
class InvalidInvokeResultError(ParameterExtractorNodeError):
"""Raised when the invoke result is invalid."""
class InvalidTextContentTypeError(ParameterExtractorNodeError):
"""Raised when the text content type is invalid."""
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
"""Raised when the number of parameters is invalid."""
class RequiredParameterMissingError(ParameterExtractorNodeError):
"""Raised when a required parameter is missing."""
class InvalidSelectValueError(ParameterExtractorNodeError):
"""Raised when a select value is invalid."""
class InvalidNumberValueError(ParameterExtractorNodeError):
"""Raised when a number value is invalid."""
class InvalidBoolValueError(ParameterExtractorNodeError):
"""Raised when a bool value is invalid."""
class InvalidStringValueError(ParameterExtractorNodeError):
"""Raised when a string value is invalid."""
class InvalidArrayValueError(ParameterExtractorNodeError):
"""Raised when an array value is invalid."""
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""

View File

@ -32,6 +32,21 @@ from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .entities import ParameterExtractorNodeData from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
)
from .prompts import ( from .prompts import (
CHAT_EXAMPLE, CHAT_EXAMPLE,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model) model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel): if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model") raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema( model_schema = llm_model.get_model_schema(
@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode):
credentials=model_config.credentials, credentials=model_config.credentials,
) )
if not model_schema: if not model_schema:
raise ValueError("Model schema not found") raise ModelSchemaNotFoundError("Model schema not found")
# fetch memory # fetch memory
memory = self._fetch_memory( memory = self._fetch_memory(
@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode):
process_data["usage"] = jsonable_encoder(usage) process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call) process_data["tool_call"] = jsonable_encoder(tool_call)
process_data["llm_text"] = text process_data["llm_text"] = text
except Exception as e: except ParameterExtractorNodeError as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs, inputs=inputs,
@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode):
try: try:
result = self._validate_result(data=node_data, result=result or {}) result = self._validate_result(data=node_data, result=result or {})
except Exception as e: except ParameterExtractorNodeError as e:
error = str(e) error = str(e)
# transform result into standard format # transform result into standard format
@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode):
# handle invoke result # handle invoke result
if not isinstance(invoke_result, LLMResult): if not isinstance(invoke_result, LLMResult):
raise ValueError(f"Invalid invoke result: {invoke_result}") raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content text = invoke_result.message.content
if not isinstance(text, str): if not isinstance(text, str):
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.") raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode):
files=files, files=files,
) )
else: else:
raise ValueError(f"Invalid model mode: {model_mode}") raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
def _generate_prompt_engineering_completion_prompt( def _generate_prompt_engineering_completion_prompt(
self, self,
@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode):
Validate result. Validate result.
""" """
if len(data.parameters) != len(result): if len(data.parameters) != len(result):
raise ValueError("Invalid number of parameters") raise InvalidNumberOfParametersError("Invalid number of parameters")
for parameter in data.parameters: for parameter in data.parameters:
if parameter.required and parameter.name not in result: if parameter.required and parameter.name not in result:
raise ValueError(f"Parameter {parameter.name} is required") raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise ValueError(f"Invalid `select` value for parameter {parameter.name}") raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise ValueError(f"Invalid `number` value for parameter {parameter.name}") raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str): if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise ValueError(f"Invalid `string` value for parameter {parameter.name}") raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"): if parameter.type.startswith("array"):
parameters = result.get(parameter.name) parameters = result.get(parameter.name)
if not isinstance(parameters, list): if not isinstance(parameters, list):
raise ValueError(f"Invalid `array` value for parameter {parameter.name}") raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1] nested_type = parameter.type[6:-1]
for item in parameters: for item in parameters:
if nested_type == "number" and not isinstance(item, int | float): if nested_type == "number" and not isinstance(item, int | float):
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str): if nested_type == "string" and not isinstance(item, str):
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict): if nested_type == "object" and not isinstance(item, dict):
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
return result return result
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode):
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message] return [system_prompt_messages, user_prompt_message]
else: else:
raise ValueError(f"Model mode {model_mode} not support.") raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _get_prompt_engineering_prompt_template( def _get_prompt_engineering_prompt_template(
self, self,
@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode):
.replace("}γγγ", "") .replace("}γγγ", "")
) )
else: else:
raise ValueError(f"Model mode {model_mode} not support.") raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _calculate_rest_token( def _calculate_rest_token(
self, self,
@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model) model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel): if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model") raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
if not model_schema: if not model_schema:
raise ValueError("Model schema not found") raise ModelSchemaNotFoundError("Model schema not found")
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)