Compare commits

...

2 Commits

Author SHA1 Message Date
-LAN-
fda47da6dc
Merge 63822bb7b2 into 4b2abf8ac2 2024-11-15 10:53:24 +08:00
-LAN-
63822bb7b2 feat(http_request): enhance error handling and response logging
- Introduced `ResponseNotSentError` to handle cases where logs are generated without a sent response.
- Refactored `to_log` method to simplify HTTP request logging and enhance security by masking authorization headers.
2024-11-11 22:02:58 +08:00
3 changed files with 33 additions and 60 deletions

View File

@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError):
class ResponseSizeError(HttpRequestNodeError):
"""Raised when the response size exceeds the allowed threshold."""
class ResponseNotSentError(HttpRequestNodeError):
"""Raised when the response is not sent and log generation is attempted."""

View File

@ -3,7 +3,6 @@ from collections.abc import Mapping
from copy import deepcopy
from random import randint
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
import httpx
@ -22,6 +21,7 @@ from .exc import (
AuthorizationConfigError,
FileFetchError,
InvalidHttpMethodError,
ResponseNotSentError,
ResponseSizeError,
)
@ -46,6 +46,7 @@ class Executor:
timeout: HttpRequestNodeTimeout
boundary: str
response: Response | None = None
def __init__(
self,
@ -218,71 +219,39 @@ class Executor:
# do http request
response = self._do_http_request(headers)
# validate response
return self._validate_and_parse_response(response)
self.response = self._validate_and_parse_response(response)
return self.response
def to_log(self):
url_parts = urlparse(self.url)
path = url_parts.path or "/"
if self.response is None:
raise ResponseNotSentError("Response not sent, cannot generate log.")
# Add query parameters
if self.params:
query_string = urlencode(self.params)
path += f"?{query_string}"
elif url_parts.query:
path += f"?{url_parts.query}"
response = self.response.response
request = response.request
encoding = response.encoding or "utf-8"
raw = f"{self.method.upper()} {path} HTTP/1.1\r\n"
raw += f"Host: {url_parts.netloc}\r\n"
headers = self._assembling_headers()
body = self.node_data.body
boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
if body:
if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
if body.type == "form-data":
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
for k, v in headers.items():
if self.auth.type == "api-key":
authorization_header = "Authorization"
authorization_header = b"authorization"
if self.auth.config and self.auth.config.header:
authorization_header = self.auth.config.header
authorization_header = self.auth.config.header.encode(encoding)
raw = f"{request.method.upper()} {request.url.raw_path.decode(encoding)} {response.http_version}\r\n".encode(
encoding
)
for k, v in request.headers.raw:
if k.lower() == authorization_header.lower():
raw += f'{k}: {"*" * len(v)}\r\n'
raw += k + b": " + b"*" * 16 + b"\r\n"
continue
raw += f"{k}: {v}\r\n"
raw += k + b": " + v + b"\r\n"
body = ""
if self.files:
for k, v in self.files.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body += f"{v[1]}\r\n"
body += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
body = self.content
elif isinstance(self.content, bytes):
body = self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":
for key, value in self.data.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body += f"{value}\r\n"
body += f"--{boundary}--\r\n"
elif self.json:
body = json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
body = self.node_data.body.data[0].value
if body:
raw += f"Content-Length: {len(body)}\r\n"
raw += "\r\n" # Empty line between headers and body
raw += body
raw += b"\r\n"
return raw
content = request.read()
raw += content
raw_text = raw.decode(encoding, errors="replace")
if len(raw_text) > 1000:
raw_text = raw_text[:500] + "(......)" + raw_text[-500:]
return raw_text
def _plain_text_to_dict(text: str, /) -> dict[str, str]:

View File

@ -63,9 +63,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
)
process_data["request"] = http_executor.to_log()
response = http_executor.invoke()
process_data["request"] = http_executor.to_log()
files = self.extract_files(url=http_executor.url, response=response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,