Fix: source code inspection in interactive shells (#281)
* Support interactive shells for source inspection * Add tool save e2e tests
This commit is contained in:
		
							parent
							
								
									5d6502ae1d
								
							
						
					
					
						commit
						6196958deb
					
				|  | @ -60,6 +60,7 @@ all = [ | ||||||
|   "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]", |   "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]", | ||||||
| ] | ] | ||||||
| test = [ | test = [ | ||||||
|  |   "ipython>=8.31.0", # for interactive environment tests | ||||||
|   "pytest>=8.1.0", |   "pytest>=8.1.0", | ||||||
|   "python-dotenv>=1.0.1",  # For test_all_docs |   "python-dotenv>=1.0.1",  # For test_all_docs | ||||||
|   "smolagents[all]", |   "smolagents[all]", | ||||||
|  |  | ||||||
|  | @ -1,10 +1,9 @@ | ||||||
| import ast | import ast | ||||||
| import builtins | import builtins | ||||||
| import inspect | import inspect | ||||||
| import textwrap |  | ||||||
| from typing import Set | from typing import Set | ||||||
| 
 | 
 | ||||||
| from .utils import BASE_BUILTIN_MODULES | from .utils import BASE_BUILTIN_MODULES, get_source | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| _BUILTIN_NAMES = set(vars(builtins)) | _BUILTIN_NAMES = set(vars(builtins)) | ||||||
|  | @ -132,7 +131,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: | ||||||
|     """ |     """ | ||||||
|     errors = [] |     errors = [] | ||||||
| 
 | 
 | ||||||
|     source = textwrap.dedent(inspect.getsource(cls)) |     source = get_source(cls) | ||||||
| 
 | 
 | ||||||
|     tree = ast.parse(source) |     tree = ast.parse(source) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -46,7 +46,7 @@ from ._function_type_hints_utils import ( | ||||||
| ) | ) | ||||||
| from .tool_validation import MethodChecker, validate_tool_attributes | from .tool_validation import MethodChecker, validate_tool_attributes | ||||||
| from .types import handle_agent_input_types, handle_agent_output_types | from .types import handle_agent_input_types, handle_agent_output_types | ||||||
| from .utils import _is_package_available, _is_pillow_available, instance_to_source | from .utils import _is_package_available, _is_pillow_available, get_source, instance_to_source | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  | @ -220,8 +220,8 @@ class Tool: | ||||||
|         # Save tool file |         # Save tool file | ||||||
|         if type(self).__name__ == "SimpleTool": |         if type(self).__name__ == "SimpleTool": | ||||||
|             # Check that imports are self-contained |             # Check that imports are self-contained | ||||||
|             source_code = inspect.getsource(self.forward).replace("@tool", "") |             source_code = get_source(self.forward).replace("@tool", "") | ||||||
|             forward_node = ast.parse(textwrap.dedent(source_code)) |             forward_node = ast.parse(source_code) | ||||||
|             # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code |             # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code | ||||||
|             method_checker = MethodChecker(set()) |             method_checker = MethodChecker(set()) | ||||||
|             method_checker.visit(forward_node) |             method_checker.visit(forward_node) | ||||||
|  | @ -229,7 +229,7 @@ class Tool: | ||||||
|             if len(method_checker.errors) > 0: |             if len(method_checker.errors) > 0: | ||||||
|                 raise (ValueError("\n".join(method_checker.errors))) |                 raise (ValueError("\n".join(method_checker.errors))) | ||||||
| 
 | 
 | ||||||
|             forward_source_code = inspect.getsource(self.forward) |             forward_source_code = get_source(self.forward) | ||||||
|             tool_code = textwrap.dedent( |             tool_code = textwrap.dedent( | ||||||
|                 f""" |                 f""" | ||||||
|             from smolagents import Tool |             from smolagents import Tool | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ import importlib.util | ||||||
| import inspect | import inspect | ||||||
| import json | import json | ||||||
| import re | import re | ||||||
|  | import textwrap | ||||||
| import types | import types | ||||||
| from enum import IntEnum | from enum import IntEnum | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
|  | @ -221,7 +222,7 @@ def get_method_source(method): | ||||||
|     """Get source code for a method, including bound methods.""" |     """Get source code for a method, including bound methods.""" | ||||||
|     if isinstance(method, types.MethodType): |     if isinstance(method, types.MethodType): | ||||||
|         method = method.__func__ |         method = method.__func__ | ||||||
|     return inspect.getsource(method).strip() |     return get_source(method) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def is_same_method(method1, method2): | def is_same_method(method1, method2): | ||||||
|  | @ -295,7 +296,7 @@ def instance_to_source(instance, base_cls=None): | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     for name, method in methods.items(): |     for name, method in methods.items(): | ||||||
|         method_source = inspect.getsource(method) |         method_source = get_source(method) | ||||||
|         # Clean up the indentation |         # Clean up the indentation | ||||||
|         method_lines = method_source.split("\n") |         method_lines = method_source.split("\n") | ||||||
|         first_line = method_lines[0] |         first_line = method_lines[0] | ||||||
|  | @ -330,4 +331,56 @@ def instance_to_source(instance, base_cls=None): | ||||||
|     return "\n".join(final_lines) |     return "\n".join(final_lines) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def get_source(obj) -> str: | ||||||
|  |     """Get the source code of a class or callable object (e.g.: function, method). | ||||||
|  |     First attempts to get the source code using `inspect.getsource`. | ||||||
|  |     In a dynamic environment (e.g.: Jupyter, IPython), if this fails, | ||||||
|  |     falls back to retrieving the source code from the current interactive shell session. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         obj: A class or callable object (e.g.: function, method) | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         str: The source code of the object, dedented and stripped | ||||||
|  | 
 | ||||||
|  |     Raises: | ||||||
|  |         TypeError: If object is not a class or callable | ||||||
|  |         OSError: If source code cannot be retrieved from any source | ||||||
|  |         ValueError: If source cannot be found in IPython history | ||||||
|  | 
 | ||||||
|  |     Note: | ||||||
|  |         TODO: handle Python standard REPL | ||||||
|  |     """ | ||||||
|  |     if not (isinstance(obj, type) or callable(obj)): | ||||||
|  |         raise TypeError(f"Expected class or callable, got {type(obj)}") | ||||||
|  | 
 | ||||||
|  |     inspect_error = None | ||||||
|  |     try: | ||||||
|  |         return textwrap.dedent(inspect.getsource(obj)).strip() | ||||||
|  |     except OSError as e: | ||||||
|  |         # let's keep track of the exception to raise it if all further methods fail | ||||||
|  |         inspect_error = e | ||||||
|  |     try: | ||||||
|  |         import IPython | ||||||
|  | 
 | ||||||
|  |         shell = IPython.get_ipython() | ||||||
|  |         if not shell: | ||||||
|  |             raise ImportError("No active IPython shell found") | ||||||
|  |         all_cells = "\n".join(shell.user_ns.get("In", [])).strip() | ||||||
|  |         if not all_cells: | ||||||
|  |             raise ValueError("No code cells found in IPython session") | ||||||
|  | 
 | ||||||
|  |         tree = ast.parse(all_cells) | ||||||
|  |         for node in ast.walk(tree): | ||||||
|  |             if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__: | ||||||
|  |                 return textwrap.dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip() | ||||||
|  |         raise ValueError(f"Could not find source code for {obj.__name__} in IPython history") | ||||||
|  |     except ImportError: | ||||||
|  |         # IPython is not available, let's just raise the original inspect error | ||||||
|  |         raise inspect_error | ||||||
|  |     except ValueError as e: | ||||||
|  |         # IPython is available but we couldn't find the source code, let's raise the error | ||||||
|  |         raise e from inspect_error | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| __all__ = ["AgentError"] | __all__ = ["AgentError"] | ||||||
|  |  | ||||||
|  | @ -12,11 +12,19 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | import inspect | ||||||
|  | import os | ||||||
|  | import pathlib | ||||||
|  | import tempfile | ||||||
|  | import textwrap | ||||||
| import unittest | import unittest | ||||||
| 
 | 
 | ||||||
| import pytest | import pytest | ||||||
|  | from IPython.core.interactiveshell import InteractiveShell | ||||||
| 
 | 
 | ||||||
| from smolagents.utils import parse_code_blobs | from smolagents import Tool | ||||||
|  | from smolagents.tools import tool | ||||||
|  | from smolagents.utils import get_source, parse_code_blobs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentTextTests(unittest.TestCase): | class AgentTextTests(unittest.TestCase): | ||||||
|  | @ -58,3 +66,287 @@ def multiply(a, b): | ||||||
|     return a * b""" |     return a * b""" | ||||||
|         result = parse_code_blobs(test_input) |         result = parse_code_blobs(test_input) | ||||||
|         assert result == expected_output |         assert result == expected_output | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture(scope="function") | ||||||
|  | def ipython_shell(): | ||||||
|  |     """Reset IPython shell before and after each test.""" | ||||||
|  |     shell = InteractiveShell.instance() | ||||||
|  |     shell.reset()  # Clean before test | ||||||
|  |     yield shell | ||||||
|  |     shell.reset()  # Clean after test | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "obj_name, code_blob", | ||||||
|  |     [ | ||||||
|  |         ("test_func", "def test_func():\n    return 42"), | ||||||
|  |         ("TestClass", "class TestClass:\n    ..."), | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | def test_get_source_ipython(ipython_shell, obj_name, code_blob): | ||||||
|  |     ipython_shell.run_cell(code_blob, store_history=True) | ||||||
|  |     obj = ipython_shell.user_ns[obj_name] | ||||||
|  |     assert get_source(obj) == code_blob | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_source_standard_class(): | ||||||
|  |     class TestClass: ... | ||||||
|  | 
 | ||||||
|  |     source = get_source(TestClass) | ||||||
|  |     assert source == "class TestClass: ..." | ||||||
|  |     assert source == textwrap.dedent(inspect.getsource(TestClass)).strip() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_source_standard_function(): | ||||||
|  |     def test_func(): ... | ||||||
|  | 
 | ||||||
|  |     source = get_source(test_func) | ||||||
|  |     assert source == "def test_func(): ..." | ||||||
|  |     assert source == textwrap.dedent(inspect.getsource(test_func)).strip() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_source_ipython_errors_empty_cells(ipython_shell): | ||||||
|  |     test_code = textwrap.dedent("""class TestClass:\n    ...""").strip() | ||||||
|  |     ipython_shell.user_ns["In"] = [""] | ||||||
|  |     exec(test_code) | ||||||
|  |     with pytest.raises(ValueError, match="No code cells found in IPython session"): | ||||||
|  |         get_source(locals()["TestClass"]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_source_ipython_errors_definition_not_found(ipython_shell): | ||||||
|  |     test_code = textwrap.dedent("""class TestClass:\n    ...""").strip() | ||||||
|  |     ipython_shell.user_ns["In"] = ["", "print('No class definition here')"] | ||||||
|  |     exec(test_code) | ||||||
|  |     with pytest.raises(ValueError, match="Could not find source code for TestClass in IPython history"): | ||||||
|  |         get_source(locals()["TestClass"]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_source_ipython_errors_type_error(): | ||||||
|  |     with pytest.raises(TypeError, match="Expected class or callable"): | ||||||
|  |         get_source(None) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_e2e_class_tool_save(): | ||||||
|  |     class TestTool(Tool): | ||||||
|  |         name = "test_tool" | ||||||
|  |         description = "Test tool description" | ||||||
|  |         inputs = { | ||||||
|  |             "task": { | ||||||
|  |                 "type": "string", | ||||||
|  |                 "description": "tool input", | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         output_type = "string" | ||||||
|  | 
 | ||||||
|  |         def forward(self, task: str): | ||||||
|  |             import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |             return task | ||||||
|  | 
 | ||||||
|  |     test_tool = TestTool() | ||||||
|  |     with tempfile.TemporaryDirectory() as tmp_dir: | ||||||
|  |         test_tool.save(tmp_dir) | ||||||
|  |         assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "tool.py").read_text() | ||||||
|  |             == """from smolagents.tools import Tool | ||||||
|  | import IPython | ||||||
|  | 
 | ||||||
|  | class TestTool(Tool): | ||||||
|  |     name = "test_tool" | ||||||
|  |     description = "Test tool description" | ||||||
|  |     inputs = {'task': {'type': 'string', 'description': 'tool input'}} | ||||||
|  |     output_type = "string" | ||||||
|  | 
 | ||||||
|  |     def forward(self, task: str): | ||||||
|  |         import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |         return task | ||||||
|  | 
 | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         self.is_initialized = False | ||||||
|  | """ | ||||||
|  |         ) | ||||||
|  |         requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | ||||||
|  |         assert requirements == {"IPython", "smolagents"} | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "app.py").read_text() | ||||||
|  |             == """from smolagents import launch_gradio_demo | ||||||
|  | from typing import Optional | ||||||
|  | from tool import TestTool | ||||||
|  | 
 | ||||||
|  | tool = TestTool() | ||||||
|  | 
 | ||||||
|  | launch_gradio_demo(tool) | ||||||
|  | """ | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_e2e_ipython_class_tool_save(): | ||||||
|  |     shell = InteractiveShell.instance() | ||||||
|  |     with tempfile.TemporaryDirectory() as tmp_dir: | ||||||
|  |         code_blob = textwrap.dedent(f""" | ||||||
|  |         from smolagents.tools import Tool | ||||||
|  |         class TestTool(Tool): | ||||||
|  |             name = "test_tool" | ||||||
|  |             description = "Test tool description" | ||||||
|  |             inputs = {{"task": {{"type": "string", | ||||||
|  |                     "description": "tool input", | ||||||
|  |                 }} | ||||||
|  |             }} | ||||||
|  |             output_type = "string" | ||||||
|  | 
 | ||||||
|  |             def forward(self, task: str): | ||||||
|  |                 import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |                 return task | ||||||
|  |         TestTool().save("{tmp_dir}") | ||||||
|  |     """) | ||||||
|  |         assert shell.run_cell(code_blob, store_history=True).success | ||||||
|  |         assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "tool.py").read_text() | ||||||
|  |             == """from smolagents.tools import Tool | ||||||
|  | import IPython | ||||||
|  | 
 | ||||||
|  | class TestTool(Tool): | ||||||
|  |     name = "test_tool" | ||||||
|  |     description = "Test tool description" | ||||||
|  |     inputs = {'task': {'type': 'string', 'description': 'tool input'}} | ||||||
|  |     output_type = "string" | ||||||
|  | 
 | ||||||
|  |     def forward(self, task: str): | ||||||
|  |         import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |         return task | ||||||
|  | 
 | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         self.is_initialized = False | ||||||
|  | """ | ||||||
|  |         ) | ||||||
|  |         requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | ||||||
|  |         assert requirements == {"IPython", "smolagents"} | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "app.py").read_text() | ||||||
|  |             == """from smolagents import launch_gradio_demo | ||||||
|  | from typing import Optional | ||||||
|  | from tool import TestTool | ||||||
|  | 
 | ||||||
|  | tool = TestTool() | ||||||
|  | 
 | ||||||
|  | launch_gradio_demo(tool) | ||||||
|  | """ | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_e2e_function_tool_save(): | ||||||
|  |     @tool | ||||||
|  |     def test_tool(task: str) -> str: | ||||||
|  |         """ | ||||||
|  |         Test tool description | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             task: tool input | ||||||
|  |         """ | ||||||
|  |         import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |         return task | ||||||
|  | 
 | ||||||
|  |     with tempfile.TemporaryDirectory() as tmp_dir: | ||||||
|  |         test_tool.save(tmp_dir) | ||||||
|  |         assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "tool.py").read_text() | ||||||
|  |             == """from smolagents import Tool | ||||||
|  | from typing import Optional | ||||||
|  | 
 | ||||||
|  | class SimpleTool(Tool): | ||||||
|  |     name = "test_tool" | ||||||
|  |     description = "Test tool description" | ||||||
|  |     inputs = {"task":{"type":"string","description":"tool input"}} | ||||||
|  |     output_type = "string" | ||||||
|  | 
 | ||||||
|  |     def forward(self, task: str) -> str: | ||||||
|  |         \""" | ||||||
|  |         Test tool description | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             task: tool input | ||||||
|  |         \""" | ||||||
|  |         import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |         return task""" | ||||||
|  |         ) | ||||||
|  |         requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | ||||||
|  |         assert requirements == {"smolagents"}  # FIXME: IPython should be in the requirements | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "app.py").read_text() | ||||||
|  |             == """from smolagents import launch_gradio_demo | ||||||
|  | from typing import Optional | ||||||
|  | from tool import SimpleTool | ||||||
|  | 
 | ||||||
|  | tool = SimpleTool() | ||||||
|  | 
 | ||||||
|  | launch_gradio_demo(tool) | ||||||
|  | """ | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_e2e_ipython_function_tool_save(): | ||||||
|  |     shell = InteractiveShell.instance() | ||||||
|  |     with tempfile.TemporaryDirectory() as tmp_dir: | ||||||
|  |         code_blob = textwrap.dedent(f""" | ||||||
|  |         from smolagents import tool | ||||||
|  | 
 | ||||||
|  |         @tool | ||||||
|  |         def test_tool(task: str) -> str: | ||||||
|  |             \""" | ||||||
|  |             Test tool description | ||||||
|  | 
 | ||||||
|  |             Args: | ||||||
|  |                 task: tool input | ||||||
|  |             \""" | ||||||
|  |             import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |             return task | ||||||
|  | 
 | ||||||
|  |         test_tool.save("{tmp_dir}") | ||||||
|  |         """) | ||||||
|  |         assert shell.run_cell(code_blob, store_history=True).success | ||||||
|  |         assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "tool.py").read_text() | ||||||
|  |             == """from smolagents import Tool | ||||||
|  | from typing import Optional | ||||||
|  | 
 | ||||||
|  | class SimpleTool(Tool): | ||||||
|  |     name = "test_tool" | ||||||
|  |     description = "Test tool description" | ||||||
|  |     inputs = {"task":{"type":"string","description":"tool input"}} | ||||||
|  |     output_type = "string" | ||||||
|  | 
 | ||||||
|  |     def forward(self, task: str) -> str: | ||||||
|  |         \""" | ||||||
|  |         Test tool description | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             task: tool input | ||||||
|  |         \""" | ||||||
|  |         import IPython  # noqa: F401 | ||||||
|  | 
 | ||||||
|  |         return task""" | ||||||
|  |         ) | ||||||
|  |         requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | ||||||
|  |         assert requirements == {"smolagents"}  # FIXME: IPython should be in the requirements | ||||||
|  |         assert ( | ||||||
|  |             pathlib.Path(tmp_dir, "app.py").read_text() | ||||||
|  |             == """from smolagents import launch_gradio_demo | ||||||
|  | from typing import Optional | ||||||
|  | from tool import SimpleTool | ||||||
|  | 
 | ||||||
|  | tool = SimpleTool() | ||||||
|  | 
 | ||||||
|  | launch_gradio_demo(tool) | ||||||
|  | """ | ||||||
|  |         ) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue