mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat: support comfyui workflow tool image generate image (#9871)
This commit is contained in:
parent
eec63b112f
commit
ace7ffab5f
|
@ -1,3 +1,5 @@
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -6,45 +8,48 @@ import httpx
|
||||||
from websocket import WebSocket
|
from websocket import WebSocket
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
|
from core.file.file_manager import _get_encoded_string
|
||||||
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
class ComfyUiClient:
|
class ComfyUiClient:
|
||||||
def __init__(self, base_url: str):
|
def __init__(self, base_url: str):
|
||||||
self.base_url = URL(base_url)
|
self.base_url = URL(base_url)
|
||||||
|
|
||||||
def get_history(self, prompt_id: str):
|
def get_history(self, prompt_id: str) -> dict:
|
||||||
res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
|
res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
|
||||||
history = res.json()[prompt_id]
|
history = res.json()[prompt_id]
|
||||||
return history
|
return history
|
||||||
|
|
||||||
def get_image(self, filename: str, subfolder: str, folder_type: str):
|
def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
|
||||||
response = httpx.get(
|
response = httpx.get(
|
||||||
str(self.base_url / "view"),
|
str(self.base_url / "view"),
|
||||||
params={"filename": filename, "subfolder": subfolder, "type": folder_type},
|
params={"filename": filename, "subfolder": subfolder, "type": folder_type},
|
||||||
)
|
)
|
||||||
return response.content
|
return response.content
|
||||||
|
|
||||||
def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False):
|
def upload_image(self, image_file: File) -> dict:
|
||||||
# plan to support img2img in dify 0.10.0
|
image_content = base64.b64decode(_get_encoded_string(image_file))
|
||||||
with open(input_path, "rb") as file:
|
file = io.BytesIO(image_content)
|
||||||
files = {"image": (name, file, "image/png")}
|
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
|
||||||
data = {"type": image_type, "overwrite": str(overwrite).lower()}
|
res = httpx.post(str(self.base_url / "upload/image"), files=files)
|
||||||
|
return res.json()
|
||||||
|
|
||||||
res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files)
|
def queue_prompt(self, client_id: str, prompt: dict) -> str:
|
||||||
return res
|
|
||||||
|
|
||||||
def queue_prompt(self, client_id: str, prompt: dict):
|
|
||||||
res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
|
res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
|
||||||
prompt_id = res.json()["prompt_id"]
|
prompt_id = res.json()["prompt_id"]
|
||||||
return prompt_id
|
return prompt_id
|
||||||
|
|
||||||
def open_websocket_connection(self):
|
def open_websocket_connection(self) -> tuple[WebSocket, str]:
|
||||||
client_id = str(uuid.uuid4())
|
client_id = str(uuid.uuid4())
|
||||||
ws = WebSocket()
|
ws = WebSocket()
|
||||||
ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
|
ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
|
||||||
ws.connect(ws_address)
|
ws.connect(ws_address)
|
||||||
return ws, client_id
|
return ws, client_id
|
||||||
|
|
||||||
def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""):
|
def set_prompt(
|
||||||
|
self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
find the first KSampler, then can find the prompt node through it.
|
find the first KSampler, then can find the prompt node through it.
|
||||||
"""
|
"""
|
||||||
|
@ -58,6 +63,10 @@ class ComfyUiClient:
|
||||||
if negative_prompt != "":
|
if negative_prompt != "":
|
||||||
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
||||||
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
||||||
|
|
||||||
|
if image_name != "":
|
||||||
|
image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
|
||||||
|
prompt.get(image_loader)["inputs"]["image"] = image_name
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
||||||
|
@ -89,7 +98,7 @@ class ComfyUiClient:
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def generate_image_by_prompt(self, prompt: dict):
|
def generate_image_by_prompt(self, prompt: dict) -> list[bytes]:
|
||||||
try:
|
try:
|
||||||
ws, client_id = self.open_websocket_connection()
|
ws, client_id = self.open_websocket_connection()
|
||||||
prompt_id = self.queue_prompt(client_id, prompt)
|
prompt_id = self.queue_prompt(client_id, prompt)
|
||||||
|
|
|
@ -2,10 +2,9 @@ import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
from .comfyui_client import ComfyUiClient
|
|
||||||
|
|
||||||
|
|
||||||
class ComfyUIWorkflowTool(BuiltinTool):
|
class ComfyUIWorkflowTool(BuiltinTool):
|
||||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||||
|
@ -14,13 +13,16 @@ class ComfyUIWorkflowTool(BuiltinTool):
|
||||||
positive_prompt = tool_parameters.get("positive_prompt")
|
positive_prompt = tool_parameters.get("positive_prompt")
|
||||||
negative_prompt = tool_parameters.get("negative_prompt")
|
negative_prompt = tool_parameters.get("negative_prompt")
|
||||||
workflow = tool_parameters.get("workflow_json")
|
workflow = tool_parameters.get("workflow_json")
|
||||||
|
image_name = ""
|
||||||
|
if image := tool_parameters.get("image"):
|
||||||
|
image_name = comfyui.upload_image(image).get("name")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
origin_prompt = json.loads(workflow)
|
origin_prompt = json.loads(workflow)
|
||||||
except:
|
except:
|
||||||
return self.create_text_message("the Workflow JSON is not correct")
|
return self.create_text_message("the Workflow JSON is not correct")
|
||||||
|
|
||||||
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt)
|
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
|
||||||
images = comfyui.generate_image_by_prompt(prompt)
|
images = comfyui.generate_image_by_prompt(prompt)
|
||||||
result = []
|
result = []
|
||||||
for img in images:
|
for img in images:
|
||||||
|
|
|
@ -24,6 +24,13 @@ parameters:
|
||||||
zh_Hans: 负面提示词
|
zh_Hans: 负面提示词
|
||||||
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||||
form: llm
|
form: llm
|
||||||
|
- name: image
|
||||||
|
type: file
|
||||||
|
label:
|
||||||
|
en_US: Input Image
|
||||||
|
zh_Hans: 输入的图片
|
||||||
|
llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
|
||||||
|
form: llm
|
||||||
- name: workflow_json
|
- name: workflow_json
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: true
|
||||||
|
|
Loading…
Reference in New Issue
Block a user