From d45c63555fa3e139e812f1a94cba05a61bf869e8 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Mon, 6 Jan 2025 22:04:00 +0100 Subject: [PATCH] Pass more tests --- pyproject.toml | 1 + src/smolagents/agents.py | 12 +++++++++--- tests/test_agents.py | 18 +++++++++--------- tests/test_all_docs.py | 8 ++++---- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 13099ff..8120416 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,4 +38,5 @@ test = [ "sqlalchemy", "ruff>=0.5.0", "accelerate", + "soundfile", ] diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index f388c04..accf00e 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -372,7 +372,9 @@ class MultiStepAgent: except Exception as e: return f"Error in generating final LLM output:\n{e}" - def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any: + def execute_tool_call( + self, tool_name: str, arguments: Union[Dict[str, str], str] + ) -> Any: """ Execute tool with the provided input and returns the result. This method replaces arguments with the actual values from the state if they refer to state variables. @@ -515,7 +517,9 @@ You have been provided with these additional arguments, that you can access usin self.planning_interval is not None and step_number % self.planning_interval == 0 ): - self.planning_step(task, is_first_step=(step_number == 0), step=step_number) + self.planning_step( + task, is_first_step=(step_number == 0), step=step_number + ) console.print( Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX) ) @@ -562,7 +566,9 @@ You have been provided with these additional arguments, that you can access usin self.planning_interval is not None and step_number % self.planning_interval == 0 ): - self.planning_step(task, is_first_step=(step_number == 0), step=step_number) + self.planning_step( + task, is_first_step=(step_number == 0), step=step_number + ) console.print( Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX) ) diff --git a/tests/test_agents.py b/tests/test_agents.py index 3ac4bed..5486b2d 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -62,7 +62,7 @@ class FakeToolCallModelImage: else: # We're at step 2 return "final_answer", "image.png", "call_1" - + def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) @@ -364,9 +364,7 @@ class AgentTests(unittest.TestCase): def test_multiagents(self): class FakeModelMultiagentsManagerAgent: - def __call__( - self, messages, stop_sequences=None, grammar=None - ): + def __call__(self, messages, stop_sequences=None, grammar=None): if len(messages) < 3: return """ Thought: Let's call our search agent. @@ -397,10 +395,11 @@ final_answer("Final report.") else: assert "Report on the current US president" in str(messages) return ( - "final_answer", - "Final report.", - "call_0", - ) + "final_answer", + "Final report.", + "call_0", + ) + manager_model = FakeModelMultiagentsManagerAgent() class FakeModelMultiagentsManagedAgent: @@ -412,6 +411,7 @@ final_answer("Final report.") {"report": "Report on the current US president"}, "call_0", ) + managed_model = FakeModelMultiagentsManagedAgent() web_agent = ToolCallingAgent( @@ -435,7 +435,7 @@ final_answer("Final report.") report = manager_code_agent.run("Fake question.") assert report == "Final report." - + manager_toolcalling_agent = ToolCallingAgent( tools=[], model=manager_model, diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index 9177df2..3433352 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -111,10 +111,10 @@ class TestDocs: code_blocks = self.extractor.extract_python_code(content) excluded_snippets = [ "ToolCollection", - "image_generation_tool", # We don't want to run this expensive operation - "from_langchain", # Langchain is not a dependency - "while llm_should_continue(memory):", # This is pseudo code - "ollama_chat/llama3.2" # Exclude ollama building in guided tour + "image_generation_tool", # We don't want to run this expensive operation + "from_langchain", # Langchain is not a dependency + "while llm_should_continue(memory):", # This is pseudo code + "ollama_chat/llama3.2", # Exclude ollama building in guided tour ] code_blocks = [ block