From 3271e3e8031da577d56e5cdaf647e28d826a934f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 14 May 2024 16:11:12 +0800 Subject: [PATCH] improve the code readability of http_executor node (#4360) --- .../nodes/http_request/http_executor.py | 194 ++++++------------ 1 file changed, 67 insertions(+), 127 deletions(-) diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 0b07ad8e82..97cb59d02d 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -14,28 +14,18 @@ from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.http_request.entities import HttpRequestNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser -MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', str(1024 * 1024 * 10))) # 10MB +MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', 1024 * 1024 * 10)) # 10MB READABLE_MAX_BINARY_SIZE = f'{MAX_BINARY_SIZE / 1024 / 1024:.2f}MB' -MAX_TEXT_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_TEXT_SIZE', str(1024 * 1024))) # 10MB # 1MB +MAX_TEXT_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_TEXT_SIZE', 1024 * 1024)) # 1MB READABLE_MAX_TEXT_SIZE = f'{MAX_TEXT_SIZE / 1024 / 1024:.2f}MB' + class HttpExecutorResponse: headers: dict[str, str] response: Union[httpx.Response, requests.Response] def __init__(self, response: Union[httpx.Response, requests.Response] = None): - """ - init - """ - headers = {} - if isinstance(response, httpx.Response): - for k, v in response.headers.items(): - headers[k] = v - elif isinstance(response, requests.Response): - for k, v in response.headers.items(): - headers[k] = v - - self.headers = headers + self.headers = response.headers self.response = response @property @@ -45,21 +35,11 @@ class HttpExecutorResponse: """ content_type = self.get_content_type() file_content_types = ['image', 'audio', 'video'] - for v in file_content_types: - if v in content_type: - return True - - return False + + return any(v in content_type for v in file_content_types) def get_content_type(self) -> str: - """ - get content type - """ - for key, val in self.headers.items(): - if key.lower() == 'content-type': - return val - - return '' + return self.headers.get('content-type') def extract_file(self) -> tuple[str, bytes]: """ @@ -67,29 +47,25 @@ class HttpExecutorResponse: """ if self.is_file: return self.get_content_type(), self.body - + return '', b'' - + @property def content(self) -> str: """ get content """ - if isinstance(self.response, httpx.Response): - return self.response.text - elif isinstance(self.response, requests.Response): + if isinstance(self.response, httpx.Response | requests.Response): return self.response.text else: raise ValueError(f'Invalid response type {type(self.response)}') - + @property def body(self) -> bytes: """ get body """ - if isinstance(self.response, httpx.Response): - return self.response.content - elif isinstance(self.response, requests.Response): + if isinstance(self.response, httpx.Response | requests.Response): return self.response.content else: raise ValueError(f'Invalid response type {type(self.response)}') @@ -99,20 +75,18 @@ class HttpExecutorResponse: """ get status code """ - if isinstance(self.response, httpx.Response): - return self.response.status_code - elif isinstance(self.response, requests.Response): + if isinstance(self.response, httpx.Response | requests.Response): return self.response.status_code else: raise ValueError(f'Invalid response type {type(self.response)}') - + @property def size(self) -> int: """ get size """ return len(self.body) - + @property def readable_size(self) -> str: """ @@ -138,10 +112,8 @@ class HttpExecutor: variable_selectors: list[VariableSelector] timeout: HttpRequestNodeData.Timeout - def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, variable_pool: Optional[VariablePool] = None): - """ - init - """ + def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, + variable_pool: Optional[VariablePool] = None): self.server_url = node_data.url self.method = node_data.method self.authorization = node_data.authorization @@ -155,7 +127,8 @@ class HttpExecutor: self.variable_selectors = [] self._init_template(node_data, variable_pool) - def _is_json_body(self, body: HttpRequestNodeData.Body): + @staticmethod + def _is_json_body(body: HttpRequestNodeData.Body): """ check if body is json """ @@ -165,55 +138,46 @@ class HttpExecutor: return True except: return False - + return False + @staticmethod + def _to_dict(convert_item: str, convert_text: str, maxsplit: int = -1): + """ + Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` + :param convert_item: A label for what item to be converted, params, headers or body. + :param convert_text: The string containing key-value pairs separated by '\n'. + :param maxsplit: The maximum number of splits allowed for the ':' character in each key-value pair. Default is -1 (no limit). + :return: A dictionary containing the key-value pairs from the input string. + """ + kv_paris = convert_text.split('\n') + result = {} + for kv in kv_paris: + if not kv.strip(): + continue + + kv = kv.split(':', maxsplit=maxsplit) + if len(kv) == 2: + k, v = kv + elif len(kv) == 1: + k, v = kv[0], '' + else: + raise ValueError(f'Invalid {convert_item} {kv}') + result[k.strip()] = v + return result + def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None): - """ - init template - """ - variable_selectors = [] # extract all template in url self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) # extract all template in params params, params_variable_selectors = self._format_template(node_data.params, variable_pool) - - # fill in params - kv_paris = params.split('\n') - for kv in kv_paris: - if not kv.strip(): - continue - - kv = kv.split(':') - if len(kv) == 2: - k, v = kv - elif len(kv) == 1: - k, v = kv[0], '' - else: - raise ValueError(f'Invalid params {kv}') - - self.params[k.strip()] = v + self.params = self._to_dict("params", params) # extract all template in headers headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) - - # fill in headers - kv_paris = headers.split('\n') - for kv in kv_paris: - if not kv.strip(): - continue - - kv = kv.split(':') - if len(kv) == 2: - k, v = kv - elif len(kv) == 1: - k, v = kv[0], '' - else: - raise ValueError(f'Invalid headers {kv}') - - self.headers[k.strip()] = v.strip() + self.headers = self._to_dict("headers", headers) # extract all template in body body_data_variable_selectors = [] @@ -231,18 +195,7 @@ class HttpExecutor: self.headers['Content-Type'] = 'application/x-www-form-urlencoded' if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: - body = {} - kv_paris = body_data.split('\n') - for kv in kv_paris: - if not kv.strip(): - continue - kv = kv.split(':', 1) - if len(kv) == 2: - body[kv[0].strip()] = kv[1] - elif len(kv) == 1: - body[kv[0].strip()] = '' - else: - raise ValueError(f'Invalid body {kv}') + body = self._to_dict("body", body_data, 1) if node_data.body.type == 'form-data': self.files = { @@ -261,14 +214,14 @@ class HttpExecutor: self.variable_selectors = (server_url_variable_selectors + params_variable_selectors + headers_variable_selectors + body_data_variable_selectors) - + def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) headers = deepcopy(self.headers) or {} if self.authorization.type == 'api-key': if self.authorization.config.api_key is None: raise ValueError('api_key is required') - + if not self.authorization.config.header: authorization.config.header = 'Authorization' @@ -278,9 +231,9 @@ class HttpExecutor: headers[authorization.config.header] = f'Basic {authorization.config.api_key}' elif self.authorization.config.type == 'custom': headers[authorization.config.header] = authorization.config.api_key - + return headers - + def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse: """ validate the response @@ -289,21 +242,22 @@ class HttpExecutor: executor_response = HttpExecutorResponse(response) else: raise ValueError(f'Invalid response type {type(response)}') - + if executor_response.is_file: if executor_response.size > MAX_BINARY_SIZE: - raise ValueError(f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.') + raise ValueError( + f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.') else: if executor_response.size > MAX_TEXT_SIZE: - raise ValueError(f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.') - + raise ValueError( + f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.') + return executor_response - + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: """ do http request depending on api bundle """ - # do http request kwargs = { 'url': self.server_url, 'headers': headers, @@ -312,25 +266,14 @@ class HttpExecutor: 'follow_redirects': True } - if self.method == 'get': - response = ssrf_proxy.get(**kwargs) - elif self.method == 'post': - response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs) - elif self.method == 'put': - response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs) - elif self.method == 'delete': - response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs) - elif self.method == 'patch': - response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs) - elif self.method == 'head': - response = ssrf_proxy.head(**kwargs) - elif self.method == 'options': - response = ssrf_proxy.options(**kwargs) + if self.method in ('get', 'head', 'options'): + response = getattr(ssrf_proxy, self.method)(**kwargs) + elif self.method in ('post', 'put', 'delete', 'patch'): + response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: raise ValueError(f'Invalid http method {self.method}') - return response - + def invoke(self) -> HttpExecutorResponse: """ invoke http request @@ -343,14 +286,11 @@ class HttpExecutor: # validate response return self._validate_and_parse_response(response) - + def to_raw_request(self, mask_authorization_header: Optional[bool] = True) -> str: """ convert to raw request """ - if mask_authorization_header == None: - mask_authorization_header = True - server_url = self.server_url if self.params: server_url += f'?{urlencode(self.params)}' @@ -365,11 +305,11 @@ class HttpExecutor: authorization_header = 'Authorization' if self.authorization.config and self.authorization.config.header: authorization_header = self.authorization.config.header - + if k.lower() == authorization_header.lower(): raw_request += f'{k}: {"*" * len(v)}\n' continue - + raw_request += f'{k}: {v}\n' raw_request += '\n'