From 70c2ec8ed501eefc42b0044853f709f7ea6785ac Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 12 Nov 2024 21:51:09 +0800 Subject: [PATCH] feat(variable-handling): enhance variable and segment conversion (#10483) --- api/core/variables/__init__.py | 2 + api/core/variables/variables.py | 13 ++- api/core/workflow/entities/variable_pool.py | 9 ++- api/factories/variable_factory.py | 80 ++++++++++++++++--- .../core/app/segments/test_segment.py | 5 +- 5 files changed, 91 insertions(+), 18 deletions(-) diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 87f9e3ed45..144c1b899f 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -17,6 +17,7 @@ from .segments import ( from .types import SegmentType from .variables import ( ArrayAnyVariable, + ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -58,4 +59,5 @@ __all__ = [ "ArrayStringSegment", "FileSegment", "FileVariable", + "ArrayFileVariable", ] diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index ddc6914192..c902303eef 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,9 +1,13 @@ +from collections.abc import Sequence +from uuid import uuid4 + from pydantic import Field from core.helper import encrypter from .segments import ( ArrayAnySegment, + ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, @@ -24,11 +28,12 @@ class Variable(Segment): """ id: str = Field( - default="", - description="Unique identity for variable. It's only used by environment variables now.", + default=lambda _: str(uuid4()), + description="Unique identity for variable.", ) name: str description: str = Field(default="", description="Description of the variable.") + selector: Sequence[str] = Field(default_factory=list) class StringVariable(StringSegment, Variable): @@ -78,3 +83,7 @@ class NoneVariable(NoneSegment, Variable): class FileVariable(FileSegment, Variable): pass + + +class ArrayFileVariable(ArrayFileSegment, Variable): + pass diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 3dc3395da1..844b46f352 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -95,13 +95,16 @@ class VariablePool(BaseModel): if len(selector) < 2: raise ValueError("Invalid selector") + if isinstance(value, Variable): + variable = value if isinstance(value, Segment): - v = value + variable = variable_factory.segment_to_variable(segment=value, selector=selector) else: - v = variable_factory.build_segment(value) + segment = variable_factory.build_segment(value) + variable = variable_factory.segment_to_variable(segment=segment, selector=selector) hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]][hash_key] = v + self.variable_dictionary[selector[0]][hash_key] = variable def get(self, selector: Sequence[str], /) -> Segment | None: """ diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 0191102b90..5b004405b4 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,34 +1,65 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any +from uuid import uuid4 from configs import dify_config from core.file import File -from core.variables import ( +from core.variables.exc import VariableError +from core.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, - ArrayNumberVariable, ArrayObjectSegment, - ArrayObjectVariable, ArraySegment, ArrayStringSegment, - ArrayStringVariable, FileSegment, FloatSegment, - FloatVariable, IntegerSegment, - IntegerVariable, NoneSegment, ObjectSegment, + Segment, + StringSegment, +) +from core.variables.types import SegmentType +from core.variables.variables import ( + ArrayAnyVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneVariable, ObjectVariable, SecretVariable, - Segment, - SegmentType, - StringSegment, StringVariable, Variable, ) -from core.variables.exc import VariableError + + +class InvalidSelectorError(ValueError): + pass + + +class UnsupportedSegmentTypeError(Exception): + pass + + +# Define the constant +SEGMENT_TO_VARIABLE_MAP = { + StringSegment: StringVariable, + IntegerSegment: IntegerVariable, + FloatSegment: FloatVariable, + ObjectSegment: ObjectVariable, + FileSegment: FileVariable, + ArrayStringSegment: ArrayStringVariable, + ArrayNumberSegment: ArrayNumberVariable, + ArrayObjectSegment: ArrayObjectVariable, + ArrayFileSegment: ArrayFileVariable, + ArrayAnySegment: ArrayAnyVariable, + NoneSegment: NoneVariable, +} def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: @@ -96,3 +127,30 @@ def build_segment(value: Any, /) -> Segment: case _: raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") + + +def segment_to_variable( + *, + segment: Segment, + selector: Sequence[str], + id: str | None = None, + name: str | None = None, + description: str = "", +) -> Variable: + if isinstance(segment, Variable): + return segment + name = name or selector[-1] + id = id or str(uuid4()) + + segment_type = type(segment) + if segment_type not in SEGMENT_TO_VARIABLE_MAP: + raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") + + variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] + return variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=selector, + ) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 3b1715ab45..1b035d01a7 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,5 +1,5 @@ from core.helper import encrypter -from core.variables import SecretVariable, StringSegment +from core.variables import SecretVariable, StringVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -54,4 +54,5 @@ def test_convert_variable_to_segment_group(): segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" assert segments_group.log == "fake-user-id" - assert segments_group.value == [StringSegment(value="fake-user-id")] + assert isinstance(segments_group.value[0], StringVariable) + assert segments_group.value[0].value == "fake-user-id"