Pass tests

This commit is contained in:
Aymeric 2024-12-30 18:03:53 +01:00
parent a50f9284b3
commit 54d6857da2
15 changed files with 52 additions and 43 deletions

View File

@ -8,19 +8,19 @@ extra_quality_checks:
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_repo.py
doc-builder style agents docs/source --max_len 119
doc-builder style smolagents docs/source --max_len 119
# this target runs checks on all files
quality:
ruff check $(check_dirs)
ruff format --check $(check_dirs)
doc-builder style agents docs/source --max_len 119 --check_only
doc-builder style smolagents docs/source --max_len 119 --check_only
# Format source code automatically and check is there are any problems left that need manual fixing
style:
ruff check $(check_dirs) --fix
ruff format $(check_dirs)
doc-builder style agents docs/source --max_len 119
doc-builder style smolagents docs/source --max_len 119
# Run tests for the library
test_big_modeling:

View File

@ -370,8 +370,7 @@ class MultiStepAgent:
try:
return self.model(self.input_messages)
except Exception as e:
error_msg = f"Error in generating final LLM output:\n{e}"
raise AgentGenerationError(error_msg)
return f"Error in generating final LLM output:\n{e}"
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
"""

View File

@ -153,8 +153,8 @@ class DuckDuckGoSearchTool(Tool):
}
output_type = "any"
def __init__(self):
super().__init__(self)
def __init__(self, **kwargs):
super().__init__(self, **kwargs)
try:
from duckduckgo_search import DDGS
except ImportError:

View File

@ -410,7 +410,12 @@ class TransformersModel(Model):
class LiteLLMModel(Model):
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None):
def __init__(
self,
model_id="anthropic/claude-3-5-sonnet-20240620",
api_base=None,
api_key=None,
):
super().__init__()
self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs

View File

@ -517,5 +517,6 @@ __all__ = [
"PLAN_UPDATE_FINAL_PLAN_REDACTION",
"SINGLE_STEP_CODE_SYSTEM_PROMPT",
"CODE_SYSTEM_PROMPT",
"TOOL_CALLING_SYSTEM_PROMPT",
"MANAGED_AGENT_PROMPT",
]

View File

@ -7,8 +7,6 @@ from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins))
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
class MethodChecker(ast.NodeVisitor):
"""
@ -91,7 +89,7 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.ctx, ast.Load):
if not (
node.id in _BUILTIN_NAMES
or node.id in IMPORTED_PACKAGES
or node.id in BASE_BUILTIN_MODULES
or node.id in self.arg_names
or node.id == "self"
or node.id in self.class_attributes
@ -105,7 +103,7 @@ class MethodChecker(ast.NodeVisitor):
if isinstance(node.func, ast.Name):
if not (
node.func.id in _BUILTIN_NAMES
or node.func.id in IMPORTED_PACKAGES
or node.func.id in BASE_BUILTIN_MODULES
or node.func.id in self.arg_names
or node.func.id == "self"
or node.func.id in self.class_attributes

View File

@ -854,7 +854,7 @@ def load_tool(
main_module = importlib.import_module("smolagents")
tools_module = main_module
tool_class = getattr(tools_module, tool_class_name)
return tool_class(model_repo_id, token=token, **kwargs)
return tool_class(token=token, **kwargs)
else:
return Tool.from_hub(
task_or_repo_id,

View File

@ -104,7 +104,9 @@ class AgentImage(AgentType, ImageType):
self._raw = None
self._tensor = None
if isinstance(value, ImageType):
if isinstance(value, AgentImage):
self._raw, self._path, self._tensor = value._raw, value._path, value._tensor
elif isinstance(value, ImageType):
self._raw = value
elif isinstance(value, bytes):
self._raw = Image.open(BytesIO(value))

View File

@ -123,6 +123,7 @@ class TestDocs:
"ToolCollection",
"image_generation_tool",
"from_langchain",
"while llm_should_continue(memory):",
]
code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(

View File

@ -59,6 +59,6 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
def test_agent_type_output(self):
inputs = self.create_inputs()
for input_type, input in inputs.items():
output = self.tool(**input)
output = self.tool(**input, sanitize_inputs_outputs=True)
agent_type = AGENT_TYPE_MAPPING[input_type]
self.assertTrue(isinstance(output, agent_type))

View File

@ -55,8 +55,8 @@ final_answer('This is the final answer.')
self.last_input_token_count = 10
self.last_output_token_count = 20
def __call__(self, prompt, **kwargs):
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
def get_tool_call(self, prompt, **kwargs):
return "final_answer", {"answer": "image"}, "fake_id"
agent = ToolCallingAgent(
tools=[],
@ -96,18 +96,21 @@ final_answer('This is the final answer.')
self.last_output_token_count = 20
def __call__(self, prompt, **kwargs):
raise AgentError
self.last_input_token_count = 10
self.last_output_token_count = 0
raise Exception("Cannot generate")
agent = CodeAgent(
tools=[],
model=FakeLLMModel(),
max_iterations=1,
)
agent.run("Fake task")
self.assertEqual(agent.monitor.total_input_token_count, 20)
self.assertEqual(agent.monitor.total_output_token_count, 40)
self.assertEqual(
agent.monitor.total_input_token_count, 20
) # Should have done two monitoring callbacks
self.assertEqual(agent.monitor.total_output_token_count, 0)
def test_streaming_agent_text_output(self):
def dummy_model(prompt, **kwargs):
@ -132,14 +135,16 @@ final_answer('This is the final answer.')
self.assertIn("This is the final answer.", final_message.content)
def test_streaming_agent_image_output(self):
def dummy_model(prompt, **kwargs):
return (
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
)
class FakeLLM:
def __init__(self):
pass
def get_tool_call(self, messages, **kwargs):
return "final_answer", {"answer": "image"}, "fake_id"
agent = ToolCallingAgent(
tools=[],
model=dummy_model,
model=FakeLLM(),
max_iterations=1,
)
@ -148,7 +153,7 @@ final_answer('This is the final answer.')
stream_to_gradio(
agent,
task="Test task",
image=AgentImage(value="path.png"),
additional_args=dict(image=AgentImage(value="path.png")),
test_mode=True,
)
)

View File

@ -41,17 +41,16 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "4.0")
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_exact_match_kwarg(self):
result = self.tool(code="(2 / 2) * 4")
self.assertEqual(result, "4.0")
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_agent_type_output(self):
inputs = ["2 * 2"]
output = self.tool(*inputs)
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
print("OKK", type(output), output_type, AGENT_TYPE_MAPPING)
self.assertTrue(isinstance(output, output_type))
def test_agent_types_inputs(self):
@ -71,7 +70,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error
output = self.tool(*inputs)
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))

View File

@ -27,4 +27,4 @@ class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
def test_exact_match_arg(self):
result = self.tool("Agents")
assert isinstance(result, list) and isinstance(result[0], dict)
assert isinstance(result, str)

View File

@ -93,7 +93,7 @@ class ToolTesterMixin:
def test_agent_type_output(self):
inputs = create_inputs(self.tool.inputs)
output = self.tool(**inputs)
output = self.tool(**inputs, sanitize_inputs_outputs=True)
if self.tool.output_type != "any":
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type))
@ -164,20 +164,20 @@ class ToolTests(unittest.TestCase):
assert coolfunc.output_type == "number"
assert "docstring has no description for the argument" in str(e)
def test_tool_definition_raises_error_imports_outside_function(self):
def test_saving_tool_raises_error_imports_outside_function(self):
with pytest.raises(Exception) as e:
from datetime import datetime
import numpy as np
@tool
def get_current_time() -> str:
"""
Gets the current time.
"""
return str(datetime.now())
return str(np.random.random())
get_current_time.save("output")
assert "datetime" in str(e)
assert "np" in str(e)
# Also test with classic definition
with pytest.raises(Exception) as e:
@ -189,12 +189,12 @@ class ToolTests(unittest.TestCase):
output_type = "string"
def forward(self):
return str(datetime.now())
return str(np.random.random())
get_current_time = GetCurrentTimeTool()
get_current_time.save("output")
assert "datetime" in str(e)
assert "np" in str(e)
def test_tool_definition_raises_no_error_imports_in_function(self):
@tool

View File

@ -20,7 +20,6 @@ from pathlib import Path
from smolagents.types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import (
get_tests_dir,
require_soundfile,
require_torch,
require_vision,
@ -91,7 +90,7 @@ class AgentImageTests(unittest.TestCase):
self.assertTrue(os.path.exists(path))
def test_from_string(self):
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
path = Path("tests/fixtures/000000039769.png")
image = Image.open(path)
agent_type = AgentImage(path)
@ -103,7 +102,7 @@ class AgentImageTests(unittest.TestCase):
self.assertTrue(os.path.exists(path))
def test_from_image(self):
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
path = Path("tests/fixtures/000000039769.png")
image = Image.open(path)
agent_type = AgentImage(image)