diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 99876b2f5e..085bac8601 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -62,6 +62,8 @@ class CotAgentOutputParser: thought_str = "thought:" thought_idx = 0 + last_character = "" + for response in llm_response: if response.delta.usage: usage_dict["usage"] = response.delta.usage @@ -74,35 +76,38 @@ class CotAgentOutputParser: while index < len(response): steps = 1 delta = response[index : index + steps] - last_character = response[index - 1] if index > 0 else "" + yield_delta = False if delta == "`": + last_character = delta code_block_cache += delta code_block_delimiter_count += 1 else: if not in_code_block: if code_block_delimiter_count > 0: + last_character = delta yield code_block_cache code_block_cache = "" else: + last_character = delta code_block_cache += delta code_block_delimiter_count = 0 if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: if last_character not in {"\n", " ", ""}: + yield_delta = True + else: + last_character = delta + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = "" + action_idx = 0 index += steps - yield delta continue - - action_cache += delta - action_idx += 1 - if action_idx == len(action_str): - action_cache = "" - action_idx = 0 - index += steps - continue elif delta.lower() == action_str[action_idx] and action_idx > 0: + last_character = delta action_cache += delta action_idx += 1 if action_idx == len(action_str): @@ -112,24 +117,25 @@ class CotAgentOutputParser: continue else: if action_cache: + last_character = delta yield action_cache action_cache = "" action_idx = 0 if delta.lower() == thought_str[thought_idx] and thought_idx == 0: if last_character not in {"\n", " ", ""}: + yield_delta = True + else: + last_character = delta + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = "" + thought_idx = 0 index += steps - yield delta continue - - thought_cache += delta - thought_idx += 1 - if thought_idx == len(thought_str): - thought_cache = "" - thought_idx = 0 - index += steps - continue elif delta.lower() == thought_str[thought_idx] and thought_idx > 0: + last_character = delta thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): @@ -139,12 +145,20 @@ class CotAgentOutputParser: continue else: if thought_cache: + last_character = delta yield thought_cache thought_cache = "" thought_idx = 0 + if yield_delta: + index += steps + last_character = delta + yield delta + continue + if code_block_delimiter_count == 3: if in_code_block: + last_character = delta yield from extra_json_from_code_block(code_block_cache) code_block_cache = "" @@ -156,8 +170,10 @@ class CotAgentOutputParser: if delta == "{": json_quote_count += 1 in_json = True + last_character = delta json_cache += delta elif delta == "}": + last_character = delta json_cache += delta if json_quote_count > 0: json_quote_count -= 1 @@ -168,16 +184,19 @@ class CotAgentOutputParser: continue else: if in_json: + last_character = delta json_cache += delta if got_json: got_json = False + last_character = delta yield parse_action(json_cache) json_cache = "" json_quote_count = 0 in_json = False if not in_code_block and not in_json: + last_character = delta yield delta.replace("`", "") index += steps