Ver código fonte

fix: In the output, the order of 'ta' is sometimes reversed as 'at'. #8015 (#8791)

Wei-shun Bao 1 ano atrás
pai
commit
fb32e5ca9a
1 arquivo alterado com 38 adições e 19 exclusões
  1. 38 19
      api/core/agent/output_parser/cot_output_parser.py

+ 38 - 19
api/core/agent/output_parser/cot_output_parser.py

@@ -62,6 +62,8 @@ class CotAgentOutputParser:
         thought_str = "thought:"
         thought_idx = 0
 
+        last_character = ""
+
         for response in llm_response:
             if response.delta.usage:
                 usage_dict["usage"] = response.delta.usage
@@ -74,35 +76,38 @@ class CotAgentOutputParser:
             while index < len(response):
                 steps = 1
                 delta = response[index : index + steps]
-                last_character = response[index - 1] if index > 0 else ""
+                yield_delta = False
 
                 if delta == "`":
+                    last_character = delta
                     code_block_cache += delta
                     code_block_delimiter_count += 1
                 else:
                     if not in_code_block:
                         if code_block_delimiter_count > 0:
+                            last_character = delta
                             yield code_block_cache
                         code_block_cache = ""
                     else:
+                        last_character = delta
                         code_block_cache += delta
                     code_block_delimiter_count = 0
 
                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
                         if last_character not in {"\n", " ", ""}:
+                            yield_delta = True
+                        else:
+                            last_character = delta
+                            action_cache += delta
+                            action_idx += 1
+                            if action_idx == len(action_str):
+                                action_cache = ""
+                                action_idx = 0
                             index += steps
-                            yield delta
                             continue
-
-                        action_cache += delta
-                        action_idx += 1
-                        if action_idx == len(action_str):
-                            action_cache = ""
-                            action_idx = 0
-                        index += steps
-                        continue
                     elif delta.lower() == action_str[action_idx] and action_idx > 0:
+                        last_character = delta
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
@@ -112,24 +117,25 @@ class CotAgentOutputParser:
                         continue
                     else:
                         if action_cache:
+                            last_character = delta
                             yield action_cache
                             action_cache = ""
                             action_idx = 0
 
                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
                         if last_character not in {"\n", " ", ""}:
+                            yield_delta = True
+                        else:
+                            last_character = delta
+                            thought_cache += delta
+                            thought_idx += 1
+                            if thought_idx == len(thought_str):
+                                thought_cache = ""
+                                thought_idx = 0
                             index += steps
-                            yield delta
                             continue
-
-                        thought_cache += delta
-                        thought_idx += 1
-                        if thought_idx == len(thought_str):
-                            thought_cache = ""
-                            thought_idx = 0
-                        index += steps
-                        continue
                     elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
+                        last_character = delta
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
@@ -139,12 +145,20 @@ class CotAgentOutputParser:
                         continue
                     else:
                         if thought_cache:
+                            last_character = delta
                             yield thought_cache
                             thought_cache = ""
                             thought_idx = 0
 
+                    if yield_delta:
+                        index += steps
+                        last_character = delta
+                        yield delta
+                        continue
+
                 if code_block_delimiter_count == 3:
                     if in_code_block:
+                        last_character = delta
                         yield from extra_json_from_code_block(code_block_cache)
                         code_block_cache = ""
 
@@ -156,8 +170,10 @@ class CotAgentOutputParser:
                     if delta == "{":
                         json_quote_count += 1
                         in_json = True
+                        last_character = delta
                         json_cache += delta
                     elif delta == "}":
+                        last_character = delta
                         json_cache += delta
                         if json_quote_count > 0:
                             json_quote_count -= 1
@@ -168,16 +184,19 @@ class CotAgentOutputParser:
                                 continue
                     else:
                         if in_json:
+                            last_character = delta
                             json_cache += delta
 
                     if got_json:
                         got_json = False
+                        last_character = delta
                         yield parse_action(json_cache)
                         json_cache = ""
                         json_quote_count = 0
                         in_json = False
 
                 if not in_code_block and not in_json:
+                    last_character = delta
                     yield delta.replace("`", "")
 
                 index += steps