refactor: update stop parameter type to use Sequence instead of list

This commit is contained in:
-LAN- 2024-11-13 16:21:29 +08:00
parent 1fdaea29aa
commit 4f89214d89
4 changed files with 17 additions and 16 deletions

View File

@ -103,7 +103,7 @@ class ModelInstance:
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -31,7 +32,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -60,7 +61,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@ -90,7 +91,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -120,7 +121,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

View File

@ -2,7 +2,7 @@ import logging
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union
from pydantic import ConfigDict
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -212,7 +212,7 @@ if you are not sure about the structure.
)
model_parameters.pop("response_format")
stop = stop or []
stop = list(stop) if stop is not None else []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
@ -408,7 +408,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -479,7 +479,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -601,7 +601,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -647,7 +647,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -694,7 +694,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -742,7 +742,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,

View File

@ -204,7 +204,7 @@ class LLMNode(BaseNode[LLMNodeData]):
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()