diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 05be772..d8c8c4f 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -60,7 +60,7 @@ from .utils import ( AgentMaxStepsError, AgentParsingError, console, - parse_code_blob, + parse_code_blobs, parse_json_tool_call, truncate_content, ) @@ -894,7 +894,7 @@ class CodeAgent(MultiStepAgent): ) llm_output = self.model( self.input_messages, - stop_sequences=["", "Observation:"], + stop_sequences=["", "Observation:"], **additional_args, ) log_entry.llm_output = llm_output @@ -920,7 +920,7 @@ class CodeAgent(MultiStepAgent): # Parse try: - code_action = fix_final_answer_code(parse_code_blob(llm_output)) + code_action = fix_final_answer_code(parse_code_blobs(llm_output)) except Exception as e: error_msg = ( f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 4c66b23..32d08f4 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -39,12 +39,12 @@ logger = logging.getLogger(__name__) DEFAULT_JSONAGENT_REGEX_GRAMMAR = { "type": "regex", - "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n', + "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n', } DEFAULT_CODEAGENT_REGEX_GRAMMAR = { "type": "regex", - "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```", + "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```", } try: diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index fe006fd..4e3e727 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -105,11 +105,11 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: raise ValueError(f"Error in parsing the JSON blob: {e}") -def parse_code_blob(code_blob: str) -> str: +def parse_code_blobs(code_blob: str) -> str: """Parses the LLM's output to get any code blob inside. Will retrun the code directly if it's code.""" pattern = r"```(?:py|python)?\n(.*?)\n```" - match = re.search(pattern, code_blob, re.DOTALL) - if match is None: + matches = re.findall(pattern, code_blob, re.DOTALL) + if len(matches) == 0: try: # Maybe the LLM outputted a code blob directly ast.parse(code_blob) return code_blob @@ -123,7 +123,7 @@ The code blob is invalid, because the regex pattern {pattern} was not found in { Code: ```py final_answer("YOUR FINAL ANSWER HERE") -```""".strip() +```""".strip() ) raise ValueError( f""" @@ -132,9 +132,9 @@ Thoughts: Your thoughts Code: ```py # Your python code here -```""".strip() +```""".strip() ) - return match.group(1).strip() + return "\n\n".join(match.strip() for match in matches) def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: diff --git a/tests/test_utils.py b/tests/test_utils.py index 1ec6343..0a661a3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,16 +15,16 @@ import unittest import pytest -from smolagents.utils import parse_code_blob +from smolagents.utils import parse_code_blobs class AgentTextTests(unittest.TestCase): - def test_parse_code_blob(self): + def test_parse_code_blobs(self): with pytest.raises(ValueError): - parse_code_blob("Wrong blob!") + parse_code_blobs("Wrong blob!") # Parsing mardkwon with code blobs should work - output = parse_code_blob(""" + output = parse_code_blobs(""" Here is how to solve the problem: Code: ```py @@ -35,5 +35,25 @@ import numpy as np # Parsing code blobs should work code_blob = "import numpy as np" - output = parse_code_blob(code_blob) + output = parse_code_blobs(code_blob) assert output == code_blob + + def test_multiple_code_blobs(self): + test_input = """Here's a function that adds numbers: +```python +def add(a, b): + return a + b +``` +And here's a function that multiplies them: +```py +def multiply(a, b): + return a * b +```""" + + expected_output = """def add(a, b): + return a + b + +def multiply(a, b): + return a * b""" + result = parse_code_blobs(test_input) + assert result == expected_output