mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
refactor(parameter_extractor): implement custom error classes (#10260)
This commit is contained in:
parent
65a04ee0be
commit
c5422af400
50
api/core/workflow/nodes/parameter_extractor/exc.py
Normal file
50
api/core/workflow/nodes/parameter_extractor/exc.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
"""Exception hierarchy for the parameter-extractor workflow node.

All errors derive from :class:`ParameterExtractorNodeError`, which itself
subclasses ``ValueError`` so existing callers that catch ``ValueError``
keep working after the refactor.
"""


class ParameterExtractorNodeError(ValueError):
    """Common base class for every ParameterExtractorNode failure."""


class InvalidModelTypeError(ParameterExtractorNodeError):
    """The configured model is not a Large Language Model."""


class ModelSchemaNotFoundError(ParameterExtractorNodeError):
    """No schema could be resolved for the configured model."""


class InvalidInvokeResultError(ParameterExtractorNodeError):
    """The model invocation returned a result of an unexpected kind."""


class InvalidTextContentTypeError(ParameterExtractorNodeError):
    """The message content was not a string as required."""


class InvalidNumberOfParametersError(ParameterExtractorNodeError):
    """The extracted result does not contain the expected parameter count."""


class RequiredParameterMissingError(ParameterExtractorNodeError):
    """A parameter marked as required is absent from the result."""


class InvalidSelectValueError(ParameterExtractorNodeError):
    """A ``select`` parameter received a value outside its options."""


class InvalidNumberValueError(ParameterExtractorNodeError):
    """A ``number`` parameter received a non-numeric value."""


class InvalidBoolValueError(ParameterExtractorNodeError):
    """A ``bool`` parameter received a non-boolean value."""


class InvalidStringValueError(ParameterExtractorNodeError):
    """A ``string`` parameter received a non-string value."""


class InvalidArrayValueError(ParameterExtractorNodeError):
    """An ``array`` parameter (or one of its items) has the wrong type."""


class InvalidModelModeError(ParameterExtractorNodeError):
    """The model mode is unknown or unsupported by this node."""
|
|
@ -32,6 +32,21 @@ from extensions.ext_database import db
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
from .entities import ParameterExtractorNodeData
|
from .entities import ParameterExtractorNodeData
|
||||||
|
from .exc import (
|
||||||
|
InvalidArrayValueError,
|
||||||
|
InvalidBoolValueError,
|
||||||
|
InvalidInvokeResultError,
|
||||||
|
InvalidModelModeError,
|
||||||
|
InvalidModelTypeError,
|
||||||
|
InvalidNumberOfParametersError,
|
||||||
|
InvalidNumberValueError,
|
||||||
|
InvalidSelectValueError,
|
||||||
|
InvalidStringValueError,
|
||||||
|
InvalidTextContentTypeError,
|
||||||
|
ModelSchemaNotFoundError,
|
||||||
|
ParameterExtractorNodeError,
|
||||||
|
RequiredParameterMissingError,
|
||||||
|
)
|
||||||
from .prompts import (
|
from .prompts import (
|
||||||
CHAT_EXAMPLE,
|
CHAT_EXAMPLE,
|
||||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||||
|
@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode):
|
||||||
|
|
||||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||||
raise ValueError("Model is not a Large Language Model")
|
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||||
|
|
||||||
llm_model = model_instance.model_type_instance
|
llm_model = model_instance.model_type_instance
|
||||||
model_schema = llm_model.get_model_schema(
|
model_schema = llm_model.get_model_schema(
|
||||||
|
@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode):
|
||||||
credentials=model_config.credentials,
|
credentials=model_config.credentials,
|
||||||
)
|
)
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ValueError("Model schema not found")
|
raise ModelSchemaNotFoundError("Model schema not found")
|
||||||
|
|
||||||
# fetch memory
|
# fetch memory
|
||||||
memory = self._fetch_memory(
|
memory = self._fetch_memory(
|
||||||
|
@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode):
|
||||||
process_data["usage"] = jsonable_encoder(usage)
|
process_data["usage"] = jsonable_encoder(usage)
|
||||||
process_data["tool_call"] = jsonable_encoder(tool_call)
|
process_data["tool_call"] = jsonable_encoder(tool_call)
|
||||||
process_data["llm_text"] = text
|
process_data["llm_text"] = text
|
||||||
except Exception as e:
|
except ParameterExtractorNodeError as e:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self._validate_result(data=node_data, result=result or {})
|
result = self._validate_result(data=node_data, result=result or {})
|
||||||
except Exception as e:
|
except ParameterExtractorNodeError as e:
|
||||||
error = str(e)
|
error = str(e)
|
||||||
|
|
||||||
# transform result into standard format
|
# transform result into standard format
|
||||||
|
@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode):
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
if not isinstance(invoke_result, LLMResult):
|
if not isinstance(invoke_result, LLMResult):
|
||||||
raise ValueError(f"Invalid invoke result: {invoke_result}")
|
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
|
||||||
|
|
||||||
text = invoke_result.message.content
|
text = invoke_result.message.content
|
||||||
if not isinstance(text, str):
|
if not isinstance(text, str):
|
||||||
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
|
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||||
|
|
||||||
usage = invoke_result.usage
|
usage = invoke_result.usage
|
||||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||||
|
@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode):
|
||||||
files=files,
|
files=files,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid model mode: {model_mode}")
|
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
|
||||||
|
|
||||||
def _generate_prompt_engineering_completion_prompt(
|
def _generate_prompt_engineering_completion_prompt(
|
||||||
self,
|
self,
|
||||||
|
@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode):
|
||||||
Validate result.
|
Validate result.
|
||||||
"""
|
"""
|
||||||
if len(data.parameters) != len(result):
|
if len(data.parameters) != len(result):
|
||||||
raise ValueError("Invalid number of parameters")
|
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||||
|
|
||||||
for parameter in data.parameters:
|
for parameter in data.parameters:
|
||||||
if parameter.required and parameter.name not in result:
|
if parameter.required and parameter.name not in result:
|
||||||
raise ValueError(f"Parameter {parameter.name} is required")
|
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||||
|
|
||||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||||
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
|
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||||
|
|
||||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||||
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
|
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||||
|
|
||||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||||
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||||
|
|
||||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||||
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
|
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||||
|
|
||||||
if parameter.type.startswith("array"):
|
if parameter.type.startswith("array"):
|
||||||
parameters = result.get(parameter.name)
|
parameters = result.get(parameter.name)
|
||||||
if not isinstance(parameters, list):
|
if not isinstance(parameters, list):
|
||||||
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
|
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||||
nested_type = parameter.type[6:-1]
|
nested_type = parameter.type[6:-1]
|
||||||
for item in parameters:
|
for item in parameters:
|
||||||
if nested_type == "number" and not isinstance(item, int | float):
|
if nested_type == "number" and not isinstance(item, int | float):
|
||||||
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||||
if nested_type == "string" and not isinstance(item, str):
|
if nested_type == "string" and not isinstance(item, str):
|
||||||
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||||
if nested_type == "object" and not isinstance(item, dict):
|
if nested_type == "object" and not isinstance(item, dict):
|
||||||
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||||
|
@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode):
|
||||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||||
return [system_prompt_messages, user_prompt_message]
|
return [system_prompt_messages, user_prompt_message]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Model mode {model_mode} not support.")
|
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||||
|
|
||||||
def _get_prompt_engineering_prompt_template(
|
def _get_prompt_engineering_prompt_template(
|
||||||
self,
|
self,
|
||||||
|
@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode):
|
||||||
.replace("}γγγ", "")
|
.replace("}γγγ", "")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Model mode {model_mode} not support.")
|
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||||
|
|
||||||
def _calculate_rest_token(
|
def _calculate_rest_token(
|
||||||
self,
|
self,
|
||||||
|
@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode):
|
||||||
|
|
||||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||||
raise ValueError("Model is not a Large Language Model")
|
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||||
|
|
||||||
llm_model = model_instance.model_type_instance
|
llm_model = model_instance.model_type_instance
|
||||||
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ValueError("Model schema not found")
|
raise ModelSchemaNotFoundError("Model schema not found")
|
||||||
|
|
||||||
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
|
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
|
||||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
|
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user