refactor(code_executor): update input type annotations to use Mapping for better type safety (#10478)

This commit is contained in:
-LAN- 2024-11-11 13:10:39 +08:00 committed by GitHub
parent fbee41f8c7
commit b8b6cd409a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 5 deletions

View File

@ -0,0 +1,3 @@
from .code_executor import CodeExecutor, CodeLanguage
__all__ = ["CodeExecutor", "CodeLanguage"]

View File

@ -1,7 +1,8 @@
import logging
from collections.abc import Mapping
from enum import Enum
from threading import Lock
from typing import Optional
from typing import Any, Optional
from httpx import Timeout, post
from pydantic import BaseModel
@ -117,7 +118,7 @@ class CodeExecutor:
return response.data.stdout or ""
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
"""
Execute code
:param language: code language

View File

@ -2,6 +2,8 @@ import json
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from collections.abc import Mapping
from typing import Any
class TemplateTransformer(ABC):
@ -10,7 +12,7 @@ class TemplateTransformer(ABC):
_result_tag: str = "<<RESULT>>"
@classmethod
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
"""
Transform code to python runner
:param code: code
@ -48,13 +50,13 @@ class TemplateTransformer(ABC):
pass
@classmethod
def serialize_inputs(cls, inputs: dict) -> str:
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
@classmethod
def assemble_runner_script(cls, code: str, inputs: dict) -> str:
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
# assemble runner script
script = cls.get_runner_script()
script = script.replace(cls._code_placeholder, code)