Support multiple code blobs (#128)
This commit is contained in:
		
							parent
							
								
									e1414f6653
								
							
						
					
					
						commit
						067ae9bc90
					
				|  | @ -60,7 +60,7 @@ from .utils import ( | ||||||
|     AgentMaxStepsError, |     AgentMaxStepsError, | ||||||
|     AgentParsingError, |     AgentParsingError, | ||||||
|     console, |     console, | ||||||
|     parse_code_blob, |     parse_code_blobs, | ||||||
|     parse_json_tool_call, |     parse_json_tool_call, | ||||||
|     truncate_content, |     truncate_content, | ||||||
| ) | ) | ||||||
|  | @ -894,7 +894,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|             ) |             ) | ||||||
|             llm_output = self.model( |             llm_output = self.model( | ||||||
|                 self.input_messages, |                 self.input_messages, | ||||||
|                 stop_sequences=["<end_action>", "Observation:"], |                 stop_sequences=["<end_code>", "Observation:"], | ||||||
|                 **additional_args, |                 **additional_args, | ||||||
|             ) |             ) | ||||||
|             log_entry.llm_output = llm_output |             log_entry.llm_output = llm_output | ||||||
|  | @ -920,7 +920,7 @@ class CodeAgent(MultiStepAgent): | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         try: |         try: | ||||||
|             code_action = fix_final_answer_code(parse_code_blob(llm_output)) |             code_action = fix_final_answer_code(parse_code_blobs(llm_output)) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = ( |             error_msg = ( | ||||||
|                 f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." |                 f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." | ||||||
|  |  | ||||||
|  | @ -39,12 +39,12 @@ logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| DEFAULT_JSONAGENT_REGEX_GRAMMAR = { | DEFAULT_JSONAGENT_REGEX_GRAMMAR = { | ||||||
|     "type": "regex", |     "type": "regex", | ||||||
|     "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>', |     "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_code>', | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| DEFAULT_CODEAGENT_REGEX_GRAMMAR = { | DEFAULT_CODEAGENT_REGEX_GRAMMAR = { | ||||||
|     "type": "regex", |     "type": "regex", | ||||||
|     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>", |     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|  |  | ||||||
|  | @ -105,11 +105,11 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: | ||||||
|         raise ValueError(f"Error in parsing the JSON blob: {e}") |         raise ValueError(f"Error in parsing the JSON blob: {e}") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_code_blob(code_blob: str) -> str: | def parse_code_blobs(code_blob: str) -> str: | ||||||
|     """Parses the LLM's output to get any code blob inside. Will return the code directly if it's code.""" |     """Parses the LLM's output to get any code blob inside. Will return the code directly if it's code.""" | ||||||
|     pattern = r"```(?:py|python)?\n(.*?)\n```" |     pattern = r"```(?:py|python)?\n(.*?)\n```" | ||||||
|     match = re.search(pattern, code_blob, re.DOTALL) |     matches = re.findall(pattern, code_blob, re.DOTALL) | ||||||
|     if match is None: |     if len(matches) == 0: | ||||||
|         try:  # Maybe the LLM outputted a code blob directly |         try:  # Maybe the LLM outputted a code blob directly | ||||||
|             ast.parse(code_blob) |             ast.parse(code_blob) | ||||||
|             return code_blob |             return code_blob | ||||||
|  | @ -123,7 +123,7 @@ The code blob is invalid, because the regex pattern {pattern} was not found in { | ||||||
| Code: | Code: | ||||||
| ```py | ```py | ||||||
| final_answer("YOUR FINAL ANSWER HERE") | final_answer("YOUR FINAL ANSWER HERE") | ||||||
| ```<end_action>""".strip() | ```<end_code>""".strip() | ||||||
|             ) |             ) | ||||||
|         raise ValueError( |         raise ValueError( | ||||||
|             f""" |             f""" | ||||||
|  | @ -132,9 +132,9 @@ Thoughts: Your thoughts | ||||||
| Code: | Code: | ||||||
| ```py | ```py | ||||||
| # Your python code here | # Your python code here | ||||||
| ```<end_action>""".strip() | ```<end_code>""".strip() | ||||||
|         ) |         ) | ||||||
|     return match.group(1).strip() |     return "\n\n".join(match.strip() for match in matches) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: | def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: | ||||||
|  |  | ||||||
|  | @ -15,16 +15,16 @@ | ||||||
| import unittest | import unittest | ||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
| from smolagents.utils import parse_code_blob | from smolagents.utils import parse_code_blobs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentTextTests(unittest.TestCase): | class AgentTextTests(unittest.TestCase): | ||||||
|     def test_parse_code_blob(self): |     def test_parse_code_blobs(self): | ||||||
|         with pytest.raises(ValueError): |         with pytest.raises(ValueError): | ||||||
|             parse_code_blob("Wrong blob!") |             parse_code_blobs("Wrong blob!") | ||||||
| 
 | 
 | ||||||
|         # Parsing markdown with code blobs should work |         # Parsing markdown with code blobs should work | ||||||
|         output = parse_code_blob(""" |         output = parse_code_blobs(""" | ||||||
| Here is how to solve the problem: | Here is how to solve the problem: | ||||||
| Code: | Code: | ||||||
| ```py | ```py | ||||||
|  | @ -35,5 +35,25 @@ import numpy as np | ||||||
| 
 | 
 | ||||||
|         # Parsing code blobs should work |         # Parsing code blobs should work | ||||||
|         code_blob = "import numpy as np" |         code_blob = "import numpy as np" | ||||||
|         output = parse_code_blob(code_blob) |         output = parse_code_blobs(code_blob) | ||||||
|         assert output == code_blob |         assert output == code_blob | ||||||
|  | 
 | ||||||
|  |     def test_multiple_code_blobs(self): | ||||||
|  |         test_input = """Here's a function that adds numbers: | ||||||
|  | ```python | ||||||
|  | def add(a, b): | ||||||
|  |     return a + b | ||||||
|  | ``` | ||||||
|  | And here's a function that multiplies them: | ||||||
|  | ```py | ||||||
|  | def multiply(a, b): | ||||||
|  |     return a * b | ||||||
|  | ```""" | ||||||
|  | 
 | ||||||
|  |         expected_output = """def add(a, b): | ||||||
|  |     return a + b | ||||||
|  | 
 | ||||||
|  | def multiply(a, b): | ||||||
|  |     return a * b""" | ||||||
|  |         result = parse_code_blobs(test_input) | ||||||
|  |         assert result == expected_output | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue