Pass tests
This commit is contained in:
		
							parent
							
								
									a50f9284b3
								
							
						
					
					
						commit
						54d6857da2
					
				
							
								
								
									
										6
									
								
								Makefile
								
								
								
								
							
							
						
						
									
										6
									
								
								Makefile
								
								
								
								
							|  | @ -8,19 +8,19 @@ extra_quality_checks: | ||||||
| 	python utils/check_copies.py | 	python utils/check_copies.py | ||||||
| 	python utils/check_dummies.py | 	python utils/check_dummies.py | ||||||
| 	python utils/check_repo.py | 	python utils/check_repo.py | ||||||
| 	doc-builder style agents docs/source --max_len 119 | 	doc-builder style smolagents docs/source --max_len 119 | ||||||
| 
 | 
 | ||||||
| # this target runs checks on all files
 | # this target runs checks on all files
 | ||||||
| quality: | quality: | ||||||
| 	ruff check $(check_dirs) | 	ruff check $(check_dirs) | ||||||
| 	ruff format --check $(check_dirs) | 	ruff format --check $(check_dirs) | ||||||
| 	doc-builder style agents docs/source --max_len 119 --check_only | 	doc-builder style smolagents docs/source --max_len 119 --check_only | ||||||
| 
 | 
 | ||||||
| # Format source code automatically and check is there are any problems left that need manual fixing
 | # Format source code automatically and check is there are any problems left that need manual fixing
 | ||||||
| style: | style: | ||||||
| 	ruff check $(check_dirs) --fix | 	ruff check $(check_dirs) --fix | ||||||
| 	ruff format $(check_dirs) | 	ruff format $(check_dirs) | ||||||
| 	doc-builder style agents docs/source --max_len 119 | 	doc-builder style smolagents docs/source --max_len 119 | ||||||
| 	 | 	 | ||||||
| # Run tests for the library
 | # Run tests for the library
 | ||||||
| test_big_modeling: | test_big_modeling: | ||||||
|  |  | ||||||
|  | @ -370,8 +370,7 @@ class MultiStepAgent: | ||||||
|         try: |         try: | ||||||
|             return self.model(self.input_messages) |             return self.model(self.input_messages) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = f"Error in generating final LLM output:\n{e}" |             return f"Error in generating final LLM output:\n{e}" | ||||||
|             raise AgentGenerationError(error_msg) |  | ||||||
| 
 | 
 | ||||||
|     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: |     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|  | @ -153,8 +153,8 @@ class DuckDuckGoSearchTool(Tool): | ||||||
|     } |     } | ||||||
|     output_type = "any" |     output_type = "any" | ||||||
| 
 | 
 | ||||||
|     def __init__(self): |     def __init__(self, **kwargs): | ||||||
|         super().__init__(self) |         super().__init__(self, **kwargs) | ||||||
|         try: |         try: | ||||||
|             from duckduckgo_search import DDGS |             from duckduckgo_search import DDGS | ||||||
|         except ImportError: |         except ImportError: | ||||||
|  |  | ||||||
|  | @ -410,7 +410,12 @@ class TransformersModel(Model): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class LiteLLMModel(Model): | class LiteLLMModel(Model): | ||||||
|     def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None): |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model_id="anthropic/claude-3-5-sonnet-20240620", | ||||||
|  |         api_base=None, | ||||||
|  |         api_key=None, | ||||||
|  |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.model_id = model_id |         self.model_id = model_id | ||||||
|         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs |         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs | ||||||
|  |  | ||||||
|  | @ -517,5 +517,6 @@ __all__ = [ | ||||||
|     "PLAN_UPDATE_FINAL_PLAN_REDACTION", |     "PLAN_UPDATE_FINAL_PLAN_REDACTION", | ||||||
|     "SINGLE_STEP_CODE_SYSTEM_PROMPT", |     "SINGLE_STEP_CODE_SYSTEM_PROMPT", | ||||||
|     "CODE_SYSTEM_PROMPT", |     "CODE_SYSTEM_PROMPT", | ||||||
|  |     "TOOL_CALLING_SYSTEM_PROMPT", | ||||||
|     "MANAGED_AGENT_PROMPT", |     "MANAGED_AGENT_PROMPT", | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  | @ -7,8 +7,6 @@ from .utils import BASE_BUILTIN_MODULES | ||||||
| 
 | 
 | ||||||
| _BUILTIN_NAMES = set(vars(builtins)) | _BUILTIN_NAMES = set(vars(builtins)) | ||||||
| 
 | 
 | ||||||
| IMPORTED_PACKAGES = BASE_BUILTIN_MODULES |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class MethodChecker(ast.NodeVisitor): | class MethodChecker(ast.NodeVisitor): | ||||||
|     """ |     """ | ||||||
|  | @ -91,7 +89,7 @@ class MethodChecker(ast.NodeVisitor): | ||||||
|         if isinstance(node.ctx, ast.Load): |         if isinstance(node.ctx, ast.Load): | ||||||
|             if not ( |             if not ( | ||||||
|                 node.id in _BUILTIN_NAMES |                 node.id in _BUILTIN_NAMES | ||||||
|                 or node.id in IMPORTED_PACKAGES |                 or node.id in BASE_BUILTIN_MODULES | ||||||
|                 or node.id in self.arg_names |                 or node.id in self.arg_names | ||||||
|                 or node.id == "self" |                 or node.id == "self" | ||||||
|                 or node.id in self.class_attributes |                 or node.id in self.class_attributes | ||||||
|  | @ -105,7 +103,7 @@ class MethodChecker(ast.NodeVisitor): | ||||||
|         if isinstance(node.func, ast.Name): |         if isinstance(node.func, ast.Name): | ||||||
|             if not ( |             if not ( | ||||||
|                 node.func.id in _BUILTIN_NAMES |                 node.func.id in _BUILTIN_NAMES | ||||||
|                 or node.func.id in IMPORTED_PACKAGES |                 or node.func.id in BASE_BUILTIN_MODULES | ||||||
|                 or node.func.id in self.arg_names |                 or node.func.id in self.arg_names | ||||||
|                 or node.func.id == "self" |                 or node.func.id == "self" | ||||||
|                 or node.func.id in self.class_attributes |                 or node.func.id in self.class_attributes | ||||||
|  |  | ||||||
|  | @ -854,7 +854,7 @@ def load_tool( | ||||||
|         main_module = importlib.import_module("smolagents") |         main_module = importlib.import_module("smolagents") | ||||||
|         tools_module = main_module |         tools_module = main_module | ||||||
|         tool_class = getattr(tools_module, tool_class_name) |         tool_class = getattr(tools_module, tool_class_name) | ||||||
|         return tool_class(model_repo_id, token=token, **kwargs) |         return tool_class(token=token, **kwargs) | ||||||
|     else: |     else: | ||||||
|         return Tool.from_hub( |         return Tool.from_hub( | ||||||
|             task_or_repo_id, |             task_or_repo_id, | ||||||
|  |  | ||||||
|  | @ -104,7 +104,9 @@ class AgentImage(AgentType, ImageType): | ||||||
|         self._raw = None |         self._raw = None | ||||||
|         self._tensor = None |         self._tensor = None | ||||||
| 
 | 
 | ||||||
|         if isinstance(value, ImageType): |         if isinstance(value, AgentImage): | ||||||
|  |             self._raw, self._path, self._tensor = value._raw, value._path, value._tensor | ||||||
|  |         elif isinstance(value, ImageType): | ||||||
|             self._raw = value |             self._raw = value | ||||||
|         elif isinstance(value, bytes): |         elif isinstance(value, bytes): | ||||||
|             self._raw = Image.open(BytesIO(value)) |             self._raw = Image.open(BytesIO(value)) | ||||||
|  |  | ||||||
|  | @ -123,6 +123,7 @@ class TestDocs: | ||||||
|                 "ToolCollection", |                 "ToolCollection", | ||||||
|                 "image_generation_tool", |                 "image_generation_tool", | ||||||
|                 "from_langchain", |                 "from_langchain", | ||||||
|  |                 "while llm_should_continue(memory):", | ||||||
|             ] |             ] | ||||||
|             code_blocks = [ |             code_blocks = [ | ||||||
|                 block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace( |                 block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace( | ||||||
|  |  | ||||||
|  | @ -59,6 +59,6 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
|     def test_agent_type_output(self): |     def test_agent_type_output(self): | ||||||
|         inputs = self.create_inputs() |         inputs = self.create_inputs() | ||||||
|         for input_type, input in inputs.items(): |         for input_type, input in inputs.items(): | ||||||
|             output = self.tool(**input) |             output = self.tool(**input, sanitize_inputs_outputs=True) | ||||||
|             agent_type = AGENT_TYPE_MAPPING[input_type] |             agent_type = AGENT_TYPE_MAPPING[input_type] | ||||||
|             self.assertTrue(isinstance(output, agent_type)) |             self.assertTrue(isinstance(output, agent_type)) | ||||||
|  |  | ||||||
|  | @ -55,8 +55,8 @@ final_answer('This is the final answer.') | ||||||
|                 self.last_input_token_count = 10 |                 self.last_input_token_count = 10 | ||||||
|                 self.last_output_token_count = 20 |                 self.last_output_token_count = 20 | ||||||
| 
 | 
 | ||||||
|             def __call__(self, prompt, **kwargs): |             def get_tool_call(self, prompt, **kwargs): | ||||||
|                 return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' |                 return "final_answer", {"answer": "image"}, "fake_id" | ||||||
| 
 | 
 | ||||||
|         agent = ToolCallingAgent( |         agent = ToolCallingAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|  | @ -96,18 +96,21 @@ final_answer('This is the final answer.') | ||||||
|                 self.last_output_token_count = 20 |                 self.last_output_token_count = 20 | ||||||
| 
 | 
 | ||||||
|             def __call__(self, prompt, **kwargs): |             def __call__(self, prompt, **kwargs): | ||||||
|                 raise AgentError |                 self.last_input_token_count = 10 | ||||||
|  |                 self.last_output_token_count = 0 | ||||||
|  |                 raise Exception("Cannot generate") | ||||||
| 
 | 
 | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=FakeLLMModel(), |             model=FakeLLMModel(), | ||||||
|             max_iterations=1, |             max_iterations=1, | ||||||
|         ) |         ) | ||||||
| 
 |  | ||||||
|         agent.run("Fake task") |         agent.run("Fake task") | ||||||
| 
 | 
 | ||||||
|         self.assertEqual(agent.monitor.total_input_token_count, 20) |         self.assertEqual( | ||||||
|         self.assertEqual(agent.monitor.total_output_token_count, 40) |             agent.monitor.total_input_token_count, 20 | ||||||
|  |         )  # Should have done two monitoring callbacks | ||||||
|  |         self.assertEqual(agent.monitor.total_output_token_count, 0) | ||||||
| 
 | 
 | ||||||
|     def test_streaming_agent_text_output(self): |     def test_streaming_agent_text_output(self): | ||||||
|         def dummy_model(prompt, **kwargs): |         def dummy_model(prompt, **kwargs): | ||||||
|  | @ -132,14 +135,16 @@ final_answer('This is the final answer.') | ||||||
|         self.assertIn("This is the final answer.", final_message.content) |         self.assertIn("This is the final answer.", final_message.content) | ||||||
| 
 | 
 | ||||||
|     def test_streaming_agent_image_output(self): |     def test_streaming_agent_image_output(self): | ||||||
|         def dummy_model(prompt, **kwargs): |         class FakeLLM: | ||||||
|             return ( |             def __init__(self): | ||||||
|                 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' |                 pass | ||||||
|             ) | 
 | ||||||
|  |             def get_tool_call(self, messages, **kwargs): | ||||||
|  |                 return "final_answer", {"answer": "image"}, "fake_id" | ||||||
| 
 | 
 | ||||||
|         agent = ToolCallingAgent( |         agent = ToolCallingAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=dummy_model, |             model=FakeLLM(), | ||||||
|             max_iterations=1, |             max_iterations=1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | @ -148,7 +153,7 @@ final_answer('This is the final answer.') | ||||||
|             stream_to_gradio( |             stream_to_gradio( | ||||||
|                 agent, |                 agent, | ||||||
|                 task="Test task", |                 task="Test task", | ||||||
|                 image=AgentImage(value="path.png"), |                 additional_args=dict(image=AgentImage(value="path.png")), | ||||||
|                 test_mode=True, |                 test_mode=True, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  | @ -41,17 +41,16 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
| 
 | 
 | ||||||
|     def test_exact_match_arg(self): |     def test_exact_match_arg(self): | ||||||
|         result = self.tool("(2 / 2) * 4") |         result = self.tool("(2 / 2) * 4") | ||||||
|         self.assertEqual(result, "4.0") |         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||||
| 
 | 
 | ||||||
|     def test_exact_match_kwarg(self): |     def test_exact_match_kwarg(self): | ||||||
|         result = self.tool(code="(2 / 2) * 4") |         result = self.tool(code="(2 / 2) * 4") | ||||||
|         self.assertEqual(result, "4.0") |         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||||
| 
 | 
 | ||||||
|     def test_agent_type_output(self): |     def test_agent_type_output(self): | ||||||
|         inputs = ["2 * 2"] |         inputs = ["2 * 2"] | ||||||
|         output = self.tool(*inputs) |         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] |         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||||
|         print("OKK", type(output), output_type, AGENT_TYPE_MAPPING) |  | ||||||
|         self.assertTrue(isinstance(output, output_type)) |         self.assertTrue(isinstance(output, output_type)) | ||||||
| 
 | 
 | ||||||
|     def test_agent_types_inputs(self): |     def test_agent_types_inputs(self): | ||||||
|  | @ -71,7 +70,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
|                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) |                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) | ||||||
| 
 | 
 | ||||||
|         # Should not raise an error |         # Should not raise an error | ||||||
|         output = self.tool(*inputs) |         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] |         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||||
|         self.assertTrue(isinstance(output, output_type)) |         self.assertTrue(isinstance(output, output_type)) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -27,4 +27,4 @@ class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
| 
 | 
 | ||||||
|     def test_exact_match_arg(self): |     def test_exact_match_arg(self): | ||||||
|         result = self.tool("Agents") |         result = self.tool("Agents") | ||||||
|         assert isinstance(result, list) and isinstance(result[0], dict) |         assert isinstance(result, str) | ||||||
|  |  | ||||||
|  | @ -93,7 +93,7 @@ class ToolTesterMixin: | ||||||
| 
 | 
 | ||||||
|     def test_agent_type_output(self): |     def test_agent_type_output(self): | ||||||
|         inputs = create_inputs(self.tool.inputs) |         inputs = create_inputs(self.tool.inputs) | ||||||
|         output = self.tool(**inputs) |         output = self.tool(**inputs, sanitize_inputs_outputs=True) | ||||||
|         if self.tool.output_type != "any": |         if self.tool.output_type != "any": | ||||||
|             agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] |             agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||||
|             self.assertTrue(isinstance(output, agent_type)) |             self.assertTrue(isinstance(output, agent_type)) | ||||||
|  | @ -164,20 +164,20 @@ class ToolTests(unittest.TestCase): | ||||||
|             assert coolfunc.output_type == "number" |             assert coolfunc.output_type == "number" | ||||||
|         assert "docstring has no description for the argument" in str(e) |         assert "docstring has no description for the argument" in str(e) | ||||||
| 
 | 
 | ||||||
|     def test_tool_definition_raises_error_imports_outside_function(self): |     def test_saving_tool_raises_error_imports_outside_function(self): | ||||||
|         with pytest.raises(Exception) as e: |         with pytest.raises(Exception) as e: | ||||||
|             from datetime import datetime |             import numpy as np | ||||||
| 
 | 
 | ||||||
|             @tool |             @tool | ||||||
|             def get_current_time() -> str: |             def get_current_time() -> str: | ||||||
|                 """ |                 """ | ||||||
|                 Gets the current time. |                 Gets the current time. | ||||||
|                 """ |                 """ | ||||||
|                 return str(datetime.now()) |                 return str(np.random.random()) | ||||||
| 
 | 
 | ||||||
|             get_current_time.save("output") |             get_current_time.save("output") | ||||||
| 
 | 
 | ||||||
|         assert "datetime" in str(e) |         assert "np" in str(e) | ||||||
| 
 | 
 | ||||||
|         # Also test with classic definition |         # Also test with classic definition | ||||||
|         with pytest.raises(Exception) as e: |         with pytest.raises(Exception) as e: | ||||||
|  | @ -189,12 +189,12 @@ class ToolTests(unittest.TestCase): | ||||||
|                 output_type = "string" |                 output_type = "string" | ||||||
| 
 | 
 | ||||||
|                 def forward(self): |                 def forward(self): | ||||||
|                     return str(datetime.now()) |                     return str(np.random.random()) | ||||||
| 
 | 
 | ||||||
|             get_current_time = GetCurrentTimeTool() |             get_current_time = GetCurrentTimeTool() | ||||||
|             get_current_time.save("output") |             get_current_time.save("output") | ||||||
| 
 | 
 | ||||||
|         assert "datetime" in str(e) |         assert "np" in str(e) | ||||||
| 
 | 
 | ||||||
|     def test_tool_definition_raises_no_error_imports_in_function(self): |     def test_tool_definition_raises_no_error_imports_in_function(self): | ||||||
|         @tool |         @tool | ||||||
|  |  | ||||||
|  | @ -20,7 +20,6 @@ from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from smolagents.types import AgentAudio, AgentImage, AgentText | from smolagents.types import AgentAudio, AgentImage, AgentText | ||||||
| from transformers.testing_utils import ( | from transformers.testing_utils import ( | ||||||
|     get_tests_dir, |  | ||||||
|     require_soundfile, |     require_soundfile, | ||||||
|     require_torch, |     require_torch, | ||||||
|     require_vision, |     require_vision, | ||||||
|  | @ -91,7 +90,7 @@ class AgentImageTests(unittest.TestCase): | ||||||
|         self.assertTrue(os.path.exists(path)) |         self.assertTrue(os.path.exists(path)) | ||||||
| 
 | 
 | ||||||
|     def test_from_string(self): |     def test_from_string(self): | ||||||
|         path = Path(get_tests_dir("fixtures/")) / "000000039769.png" |         path = Path("tests/fixtures/000000039769.png") | ||||||
|         image = Image.open(path) |         image = Image.open(path) | ||||||
|         agent_type = AgentImage(path) |         agent_type = AgentImage(path) | ||||||
| 
 | 
 | ||||||
|  | @ -103,7 +102,7 @@ class AgentImageTests(unittest.TestCase): | ||||||
|         self.assertTrue(os.path.exists(path)) |         self.assertTrue(os.path.exists(path)) | ||||||
| 
 | 
 | ||||||
|     def test_from_image(self): |     def test_from_image(self): | ||||||
|         path = Path(get_tests_dir("fixtures/")) / "000000039769.png" |         path = Path("tests/fixtures/000000039769.png") | ||||||
|         image = Image.open(path) |         image = Image.open(path) | ||||||
|         agent_type = AgentImage(image) |         agent_type = AgentImage(image) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue