feat(variable-handling): enhance variable and segment conversion (#10483)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions

This commit is contained in:
-LAN- 2024-11-12 21:51:09 +08:00 committed by GitHub
parent 9c7edb9242
commit 70c2ec8ed5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 91 additions and 18 deletions

View File

@ -17,6 +17,7 @@ from .segments import (
from .types import SegmentType from .types import SegmentType
from .variables import ( from .variables import (
ArrayAnyVariable, ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable, ArrayNumberVariable,
ArrayObjectVariable, ArrayObjectVariable,
ArrayStringVariable, ArrayStringVariable,
@ -58,4 +59,5 @@ __all__ = [
"ArrayStringSegment", "ArrayStringSegment",
"FileSegment", "FileSegment",
"FileVariable", "FileVariable",
"ArrayFileVariable",
] ]

View File

@ -1,9 +1,13 @@
from collections.abc import Sequence
from uuid import uuid4
from pydantic import Field from pydantic import Field
from core.helper import encrypter from core.helper import encrypter
from .segments import ( from .segments import (
ArrayAnySegment, ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment, ArrayNumberSegment,
ArrayObjectSegment, ArrayObjectSegment,
ArrayStringSegment, ArrayStringSegment,
@ -24,11 +28,12 @@ class Variable(Segment):
""" """
id: str = Field( id: str = Field(
default="", default=lambda _: str(uuid4()),
description="Unique identity for variable. It's only used by environment variables now.", description="Unique identity for variable.",
) )
name: str name: str
description: str = Field(default="", description="Description of the variable.") description: str = Field(default="", description="Description of the variable.")
selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable): class StringVariable(StringSegment, Variable):
@ -78,3 +83,7 @@ class NoneVariable(NoneSegment, Variable):
class FileVariable(FileSegment, Variable): class FileVariable(FileSegment, Variable):
pass pass
class ArrayFileVariable(ArrayFileSegment, Variable):
pass

View File

@ -95,13 +95,16 @@ class VariablePool(BaseModel):
if len(selector) < 2: if len(selector) < 2:
raise ValueError("Invalid selector") raise ValueError("Invalid selector")
if isinstance(value, Variable):
variable = value
if isinstance(value, Segment): if isinstance(value, Segment):
v = value variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else: else:
v = variable_factory.build_segment(value) segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v self.variable_dictionary[selector[0]][hash_key] = variable
def get(self, selector: Sequence[str], /) -> Segment | None: def get(self, selector: Sequence[str], /) -> Segment | None:
""" """

View File

@ -1,34 +1,65 @@
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from typing import Any from typing import Any
from uuid import uuid4
from configs import dify_config from configs import dify_config
from core.file import File from core.file import File
from core.variables import ( from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment, ArrayAnySegment,
ArrayFileSegment, ArrayFileSegment,
ArrayNumberSegment, ArrayNumberSegment,
ArrayNumberVariable,
ArrayObjectSegment, ArrayObjectSegment,
ArrayObjectVariable,
ArraySegment, ArraySegment,
ArrayStringSegment, ArrayStringSegment,
ArrayStringVariable,
FileSegment, FileSegment,
FloatSegment, FloatSegment,
FloatVariable,
IntegerSegment, IntegerSegment,
IntegerVariable,
NoneSegment, NoneSegment,
ObjectSegment, ObjectSegment,
Segment,
StringSegment,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable, ObjectVariable,
SecretVariable, SecretVariable,
Segment,
SegmentType,
StringSegment,
StringVariable, StringVariable,
Variable, Variable,
) )
from core.variables.exc import VariableError
class InvalidSelectorError(ValueError):
pass
class UnsupportedSegmentTypeError(Exception):
pass
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
IntegerSegment: IntegerVariable,
FloatSegment: FloatVariable,
ObjectSegment: ObjectVariable,
FileSegment: FileVariable,
ArrayStringSegment: ArrayStringVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayAnySegment: ArrayAnyVariable,
NoneSegment: NoneVariable,
}
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
@ -96,3 +127,30 @@ def build_segment(value: Any, /) -> Segment:
case _: case _:
raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}")
raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}")
def segment_to_variable(
*,
segment: Segment,
selector: Sequence[str],
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
return segment
name = name or selector[-1]
id = id or str(uuid4())
segment_type = type(segment)
if segment_type not in SEGMENT_TO_VARIABLE_MAP:
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return variable_class(
id=id,
name=name,
description=description,
value=segment.value,
selector=selector,
)

View File

@ -1,5 +1,5 @@
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, StringSegment from core.variables import SecretVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
@ -54,4 +54,5 @@ def test_convert_variable_to_segment_group():
segments_group = variable_pool.convert_template(template) segments_group = variable_pool.convert_template(template)
assert segments_group.text == "fake-user-id" assert segments_group.text == "fake-user-id"
assert segments_group.log == "fake-user-id" assert segments_group.log == "fake-user-id"
assert segments_group.value == [StringSegment(value="fake-user-id")] assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"