From 06066437fdb4bde4ace65400f88bd347893c92d0 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Mon, 16 Dec 2024 15:46:47 +0100 Subject: [PATCH] Formatting --- src/agents/__init__.py | 10 +-- src/agents/agents.py | 26 ++++-- src/agents/default_tools.py | 8 +- src/agents/docker_python_executor.py | 130 +++++++++++++-------------- src/agents/gradio_ui.py | 3 +- src/agents/llm_engines.py | 35 ++++++-- src/agents/local_python_executor.py | 3 +- src/agents/monitoring.py | 3 +- src/agents/prompts.py | 8 +- src/agents/search.py | 3 +- src/agents/tools.py | 125 ++++++++++++++------------ src/agents/types.py | 3 +- src/agents/utils.py | 15 ++-- tests/test_agents.py | 25 ++---- tests/test_all_docs.py | 50 ++++++----- tests/test_final_answer.py | 5 +- tests/test_tools_common.py | 19 ++-- tests/test_utils.py | 8 +- 18 files changed, 275 insertions(+), 204 deletions(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 812584f..b24e056 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -18,11 +18,7 @@ __version__ = "0.1.0" from typing import TYPE_CHECKING -from transformers.utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_torch_available, -) +from transformers.utils import _LazyModule from transformers.utils.import_utils import define_import_structure @@ -43,4 +39,6 @@ else: import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, _file, define_import_structure(_file), module_spec=__spec__ + ) diff --git a/src/agents/agents.py b/src/agents/agents.py index ecd3f87..2a67665 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -79,8 +79,9 @@ class AgentGenerationError(AgentError): pass + @dataclass -class ToolCall(): +class ToolCall: tool_name: str tool_arguments: Any @@ -146,13 +147,17 @@ Here is a list of the team members that you can call:""" def format_prompt_with_managed_agents_descriptions( - prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None + prompt_template, + managed_agents, + agent_descriptions_placeholder: Optional[str] = None, ) -> str: if agent_descriptions_placeholder is None: agent_descriptions_placeholder = "{{managed_agents_descriptions}}" if agent_descriptions_placeholder not in prompt_template: print("PROMPT TEMPLLL", prompt_template) - raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'") + raise ValueError( + f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'" + ) if len(managed_agents.keys()) > 0: return prompt_template.replace( agent_descriptions_placeholder, show_agents_descriptions(managed_agents) @@ -970,7 +975,9 @@ class CodeAgent(ReactAgent): error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" raise AgentParsingError(error_msg) - log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action) + log_entry.tool_call = ToolCall( + tool_name="python_interpreter", tool_arguments=code_action + ) # Execute if self.verbose: @@ -1075,4 +1082,13 @@ And even if your task resolution is not successful, please return as much contex else: return output -__all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"] \ No newline at end of file + +__all__ = [ + "AgentError", + "BaseAgent", + "ManagedAgent", + "ReactAgent", + "CodeAgent", + "JsonAgent", + "Toolbox", +] diff --git a/src/agents/default_tools.py b/src/agents/default_tools.py index ab3fc38..60704bd 100644 --- a/src/agents/default_tools.py +++ b/src/agents/default_tools.py @@ -127,7 +127,10 @@ class PythonInterpreterTool(Tool): name = "python_interpreter" description = "This is a tool that evaluates python code. It can be used to perform calculations." inputs = { - "code": {"type": "string", "description": "The python code to run in interpreter"} + "code": { + "type": "string", + "description": "The python code to run in interpreter", + } } output_type = "string" @@ -186,4 +189,5 @@ class UserInputTool(Tool): user_input = input(f"{question} => ") return user_input -__all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"] \ No newline at end of file + +__all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"] diff --git a/src/agents/docker_python_executor.py b/src/agents/docker_python_executor.py index 9b2119e..96d4446 100644 --- a/src/agents/docker_python_executor.py +++ b/src/agents/docker_python_executor.py @@ -9,71 +9,73 @@ from typing import Optional, Dict, Tuple, Set, Any import types from .default_tools import BASE_PYTHON_TOOLS + class StateManager: def __init__(self, work_dir: Path): self.work_dir = work_dir self.state_file = work_dir / "interpreter_state.pickle" self.imports_file = work_dir / "imports.txt" - self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$') + self.import_pattern = re.compile(r"^(?:from\s+[\w.]+\s+)?import\s+.+$") self.imports: Set[str] = set() - + def is_import_statement(self, code: str) -> bool: """Check if a line of code is an import statement.""" return bool(self.import_pattern.match(code.strip())) - + def track_imports(self, code: str): """Track import statements for later use.""" - for line in code.split('\n'): + for line in code.split("\n"): if self.is_import_statement(line.strip()): self.imports.add(line.strip()) def save_state(self, locals_dict: Dict[str, Any], executor: str): """ Save the current state of variables and imports. - + Args: locals_dict: Dictionary of local variables executor: 'docker' or 'local' to indicate source """ # Filter out modules, functions, and special variables state_dict = { - 'variables': { - k: v for k, v in locals_dict.items() + "variables": { + k: v + for k, v in locals_dict.items() if not ( - k.startswith('_') + k.startswith("_") or callable(v) or isinstance(v, type) or isinstance(v, types.ModuleType) ) }, - 'imports': list(self.imports), - 'source': executor + "imports": list(self.imports), + "source": executor, } - with open(self.state_file, 'wb') as f: + with open(self.state_file, "wb") as f: pickle.dump(state_dict, f) - + def load_state(self, executor: str) -> Dict[str, Any]: """ Load the saved state and handle imports. - + Args: executor: 'docker' or 'local' to indicate destination - + Returns: Dictionary of variables to restore """ if not self.state_file.exists(): return {} - - with open(self.state_file, 'rb') as f: + + with open(self.state_file, "rb") as f: state_dict = pickle.load(f) - + # First handle imports - for import_stmt in state_dict['imports']: + for import_stmt in state_dict["imports"]: exec(import_stmt, globals()) - - return state_dict['variables'] + + return state_dict["variables"] def read_multiplexed_response(socket): @@ -81,10 +83,10 @@ def read_multiplexed_response(socket): socket.settimeout(10.0) i = 0 - while True and i<1000: + while True and i < 1000: # Stream output from socket response_data = socket.recv(4096) - responses = response_data.split(b'\x01\x00\x00\x00\x00\x00') + responses = response_data.split(b"\x01\x00\x00\x00\x00\x00") # The last non-empty chunk should be our JSON response if len(responses) > 0: @@ -92,15 +94,15 @@ def read_multiplexed_response(socket): if chunk and len(chunk.strip()) > 0: try: # Find the start of valid JSON by looking for '{' - json_start = chunk.find(b'{') + json_start = chunk.find(b"{") if json_start != -1: - decoded = chunk[json_start:].decode('utf-8') + decoded = chunk[json_start:].decode("utf-8") result = json.loads(decoded) if "output" in result: return decoded except json.JSONDecodeError: continue - i+=1 + i += 1 class DockerPythonInterpreter: @@ -113,7 +115,6 @@ class DockerPythonInterpreter: self.socket = None self.state_manager = StateManager(work_dir) - def create_interpreter_script(self) -> str: """Create the interpreter script that will run inside the container""" script = """ @@ -228,11 +229,9 @@ if __name__ == '__main__': container_name = f"python-interpreter-{uuid.uuid4().hex[:8]}" self.create_interpreter_script() - + # Setup volume mapping - volumes = { - str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"} - } + volumes = {str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}} for container in self.client.containers.list(all=True): if container_name == container.name: @@ -241,7 +240,7 @@ if __name__ == '__main__': container.start() self.container = container break - else: # Create new container + else: # Create new container self.container = self.client.containers.run( "python:3.9", name=container_name, @@ -250,22 +249,20 @@ if __name__ == '__main__': tty=True, stdin_open=True, working_dir="/workspace", - volumes=volumes + volumes=volumes, ) # Install packages in the new container print("Installing packages...") packages = ["pandas", "numpy", "pickle5"] # Add your required packages here result = self.container.exec_run( - f"pip install {' '.join(packages)}", - workdir="/workspace" + f"pip install {' '.join(packages)}", workdir="/workspace" ) if result.exit_code != 0: print(f"Warning: Failed to install: {result.output.decode()}") else: print(f"Installed {packages}.") - if not self.wait_for_ready(self.container): raise Exception("Failed to start container") @@ -276,14 +273,12 @@ if __name__ == '__main__': stdin=True, stdout=True, stderr=True, - tty=True + tty=True, ) - + # Connect to the exec instance self.socket = self.client.api.exec_start( - self.exec_id['Id'], - socket=True, - demux=True + self.exec_id["Id"], socket=True, demux=True )._sock def _raw_execute(self, code: str) -> Tuple[str, bool]: @@ -296,14 +291,14 @@ if __name__ == '__main__': if not self.socket: raise Exception("Socket not started") - command = json.dumps({'code': code}) + '\n' + command = json.dumps({"code": code}) + "\n" self.socket.send(command.encode()) response = read_multiplexed_response(self.socket) try: result = json.loads(response) - return result['output'], result['more'] + return result["output"], result["more"] except json.JSONDecodeError: return f"Error: Invalid response from interpreter: {response}", False @@ -311,7 +306,7 @@ if __name__ == '__main__': """Get the current locals dictionary from the interpreter by pickling directly from Docker.""" pickle_path = self.work_dir / "locals.pickle" if pickle_path.exists(): - with open(pickle_path, 'rb') as f: + with open(pickle_path, "rb") as f: try: return pickle.load(f) except Exception as e: @@ -322,14 +317,11 @@ if __name__ == '__main__': def execute(self, code: str) -> Tuple[str, bool]: # Track imports before execution self.state_manager.track_imports(code) - + output, more = self._raw_execute(code) - + # Save state after execution - self.state_manager.save_state( - self.get_locals_dict(), - 'docker' - ) + self.state_manager.save_state(self.get_locals_dict(), "docker") return output, more def stop(self, remove: bool = False): @@ -338,7 +330,7 @@ if __name__ == '__main__': self.socket.close() except: pass - + if self.container: try: self.container.stop() @@ -349,22 +341,23 @@ if __name__ == '__main__': print(f"Error stopping container: {e}") raise + def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES """Execute code locally with state transfer.""" state_manager = StateManager(work_dir) - + # Track imports state_manager.track_imports(code) - + # Load state from Docker if available - locals_dict = state_manager.load_state('local') + locals_dict = state_manager.load_state("local") # Execute in a new namespace with loaded state namespace = {} namespace.update(locals_dict) - + output = evaluate_python_code( code, tools, @@ -374,7 +367,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: ) # Save state for Docker - state_manager.save_state(namespace, 'local') + state_manager.save_state(namespace, "local") return output @@ -382,38 +375,41 @@ def create_tools_regex(tool_names): # Escape any special regex characters in tool names escaped_names = [re.escape(name) for name in tool_names] # Join with | and add word boundaries - pattern = r'\b(' + '|'.join(escaped_names) + r')\b' + pattern = r"\b(" + "|".join(escaped_names) + r")\b" return re.compile(pattern) + def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter): """Execute code with automatic switching between Docker and local.""" - lines = code.split('\n') + lines = code.split("\n") current_block = [] - tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing + tool_regex = create_tools_regex( + list(tools.keys()) + ["print"] + ) # Added print for testing tools = { **BASE_PYTHON_TOOLS.copy(), **tools, } - + for line in lines: if tool_regex.search(line): # Execute accumulated Docker code if any if current_block: - output, more = interpreter.execute('\n'.join(current_block)) - print(output, end='') + output, more = interpreter.execute("\n".join(current_block)) + print(output, end="") current_block = [] - + output = execute_locally(line, work_dir, tools) if output: - print(output, end='') + print(output, end="") else: current_block.append(line) - + # Execute any remaining Docker code if current_block: - output, more = interpreter.execute('\n'.join(current_block)) - print(output, end='') + output, more = interpreter.execute("\n".join(current_block)) + print(output, end="") -__all__ = ["DockerPythonInterpreter", "execute_code"] \ No newline at end of file +__all__ = ["DockerPythonInterpreter", "execute_code"] diff --git a/src/agents/gradio_ui.py b/src/agents/gradio_ui.py index 6566853..176fda4 100644 --- a/src/agents/gradio_ui.py +++ b/src/agents/gradio_ui.py @@ -111,4 +111,5 @@ class GradioUI: demo.launch() -__all__ = ["stream_to_gradio", "GradioUI"] \ No newline at end of file + +__all__ = ["stream_to_gradio", "GradioUI"] diff --git a/src/agents/llm_engines.py b/src/agents/llm_engines.py index b4a3303..005cdf9 100644 --- a/src/agents/llm_engines.py +++ b/src/agents/llm_engines.py @@ -37,6 +37,7 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = { "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```", } + class MessageRole(str, Enum): USER = "user" ASSISTANT = "assistant" @@ -48,6 +49,7 @@ class MessageRole(str, Enum): def roles(cls): return [r.value for r in cls] + openai_role_conversions = { MessageRole.TOOL_RESPONSE: MessageRole.USER, } @@ -56,6 +58,7 @@ llama_role_conversions = { MessageRole.TOOL_RESPONSE: MessageRole.USER, } + def get_clean_message_list( message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} ): @@ -118,7 +121,7 @@ class HfEngine: messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None, - max_tokens: int = 1500 + max_tokens: int = 1500, ): raise NotImplementedError @@ -276,7 +279,12 @@ class TransformersEngine(HfEngine): class OpenAIEngine: - def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None): + def __init__( + self, + model_name: Optional[str] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + ): """Creates a LLM Engine that follows OpenAI format. Args: @@ -301,7 +309,9 @@ class OpenAIEngine: grammar: Optional[str] = None, max_tokens: int = 1500, ) -> str: - messages = get_clean_message_list(messages, role_conversions=openai_role_conversions) + messages = get_clean_message_list( + messages, role_conversions=openai_role_conversions + ) response = self.client.chat.completions.create( model=self.model_name, @@ -337,7 +347,9 @@ class AnthropicEngine: grammar: Optional[str] = None, max_tokens: int = 1500, ) -> str: - messages = get_clean_message_list(messages, role_conversions=openai_role_conversions) + messages = get_clean_message_list( + messages, role_conversions=openai_role_conversions + ) index_system_message, system_prompt = None, None for index, message in enumerate(messages): if message["role"] == MessageRole.SYSTEM: @@ -346,7 +358,9 @@ class AnthropicEngine: if system_prompt is None: raise Exception("No system prompt found!") - filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message] + filtered_messages = [ + message for i, message in enumerate(messages) if i != index_system_message + ] if len(filtered_messages) == 0: print("Error, no user message:", messages) assert False @@ -366,4 +380,13 @@ class AnthropicEngine: return full_response_text -__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"] \ No newline at end of file +__all__ = [ + "MessageRole", + "llama_role_conversions", + "get_clean_message_list", + "HfEngine", + "TransformersEngine", + "HfApiEngine", + "OpenAIEngine", + "AnthropicEngine", +] diff --git a/src/agents/local_python_executor.py b/src/agents/local_python_executor.py index 0db13e5..938e184 100644 --- a/src/agents/local_python_executor.py +++ b/src/agents/local_python_executor.py @@ -1000,4 +1000,5 @@ def evaluate_python_code( msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" raise InterpreterError(msg) -__all__ = ["evaluate_python_code"] \ No newline at end of file + +__all__ = ["evaluate_python_code"] diff --git a/src/agents/monitoring.py b/src/agents/monitoring.py index e1e9882..2f57268 100644 --- a/src/agents/monitoring.py +++ b/src/agents/monitoring.py @@ -44,4 +44,5 @@ class Monitor: console.print(f"- Input tokens: {self.total_input_token_count:,}") console.print(f"- Output tokens: {self.total_output_token_count:,}") -__all__ = ["Monitor"] \ No newline at end of file + +__all__ = ["Monitor"] diff --git a/src/agents/prompts.py b/src/agents/prompts.py index ac4dd25..1d5a7db 100644 --- a/src/agents/prompts.py +++ b/src/agents/prompts.py @@ -491,4 +491,10 @@ Here is my new/updated plan of action to solve the task: {plan_update} ```""" -__all__ = ["USER_PROMPT_PLAN_UPDATE", "PLAN_UPDATE_FINAL_PLAN_REDACTION", "ONESHOT_CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT", "JSON_SYSTEM_PROMPT"] \ No newline at end of file +__all__ = [ + "USER_PROMPT_PLAN_UPDATE", + "PLAN_UPDATE_FINAL_PLAN_REDACTION", + "ONESHOT_CODE_SYSTEM_PROMPT", + "CODE_SYSTEM_PROMPT", + "JSON_SYSTEM_PROMPT", +] diff --git a/src/agents/search.py b/src/agents/search.py index 837d6b0..bdfebff 100644 --- a/src/agents/search.py +++ b/src/agents/search.py @@ -78,4 +78,5 @@ class VisitWebpageTool(Tool): except Exception as e: return f"An unexpected error occurred: {str(e)}" -__all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"] \ No newline at end of file + +__all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"] diff --git a/src/agents/tools.py b/src/agents/tools.py index 54deb9e..55e5890 100644 --- a/src/agents/tools.py +++ b/src/agents/tools.py @@ -39,12 +39,6 @@ from huggingface_hub import ( from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session from packaging import version -from transformers.dynamic_module_utils import ( - custom_object_save, - get_class_from_dynamic_module, - get_imports, -) -from transformers import AutoProcessor from transformers.utils import ( TypeHintParsingException, cached_file, @@ -62,11 +56,10 @@ logger = logging.getLogger(__name__) if is_torch_available(): - import torch + pass if is_accelerate_available(): - from accelerate import PartialState - from accelerate.utils import send_to_device + pass TOOL_CONFIG_FILE = "tool_config.json" @@ -123,12 +116,13 @@ def validate_after_init(cls, do_validate_forward: bool = True): cls.__init__ = new_init return cls + def validate_args_are_self_contained(source_code): """Validates that all names in forward method are properly defined. In particular it will check that all imports are done within the function.""" print("CODDDD", source_code) tree = ast.parse(textwrap.dedent(source_code)) - + # Get function arguments func_node = tree.body[0] arg_names = {arg.arg for arg in func_node.args.args} | {"kwargs"} @@ -147,10 +141,10 @@ def validate_args_are_self_contained(source_code): for name in node.names: actual_name = name.asname or name.name self.imports[actual_name] = (name.name, actual_name) - + def visit_ImportFrom(self, node): """Handle from imports like 'from datetime import datetime'.""" - module = node.module or '' + module = node.module or "" for name in node.names: actual_name = name.asname or name.name self.from_imports[actual_name] = (module, name.name, actual_name) @@ -161,7 +155,7 @@ def validate_args_are_self_contained(source_code): if isinstance(target, ast.Name): self.assigned_names.add(target.id) self.visit(node.value) - + def visit_AnnAssign(self, node): """Track annotated assignments.""" if isinstance(node.target, ast.Name): @@ -179,46 +173,48 @@ def validate_args_are_self_contained(source_code): if isinstance(elt, ast.Name): names.add(elt.id) return names - + def visit_For(self, node): """Track for-loop target variables and handle enumerate specially.""" # Add names from the target target_names = self._handle_for_target(node.target) self.assigned_names.update(target_names) - + # Special handling for enumerate - if (isinstance(node.iter, ast.Call) and - isinstance(node.iter.func, ast.Name) and - node.iter.func.id == 'enumerate'): - # For enumerate, if we have "for i, x in enumerate(...)", + if ( + isinstance(node.iter, ast.Call) + and isinstance(node.iter.func, ast.Name) + and node.iter.func.id == "enumerate" + ): + # For enumerate, if we have "for i, x in enumerate(...)", # both i and x should be marked as assigned if isinstance(node.target, ast.Tuple): for elt in node.target.elts: if isinstance(elt, ast.Name): self.assigned_names.add(elt.id) - + # Visit the rest of the node self.generic_visit(node) def visit_Name(self, node): - if (isinstance(node.ctx, ast.Load) and not ( - node.id == "tool" or - node.id in builtin_names or - node.id in arg_names or - node.id == 'self' or - node.id in self.assigned_names - )): + if isinstance(node.ctx, ast.Load) and not ( + node.id == "tool" + or node.id in builtin_names + or node.id in arg_names + or node.id == "self" + or node.id in self.assigned_names + ): if node.id not in self.from_imports and node.id not in self.imports: self.undefined_names.add(node.id) - + def visit_Attribute(self, node): # Skip self.something - if not (isinstance(node.value, ast.Name) and node.value.id == 'self'): + if not (isinstance(node.value, ast.Name) and node.value.id == "self"): self.generic_visit(node) - + checker = NameChecker() checker.visit(tree) - + if checker.undefined_names: raise ValueError( f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}. @@ -226,6 +222,7 @@ def validate_args_are_self_contained(source_code): """ ) + AUTHORIZED_TYPES = [ "string", "boolean", @@ -273,7 +270,6 @@ class Tool: super().__init_subclass__(**kwargs) validate_after_init(cls, do_validate_forward=False) - def validate_arguments(self, do_validate_forward: bool = True): required_attributes = { "description": str, @@ -359,14 +355,14 @@ class {class_name}(Tool): def add_self_argument(source_code: str) -> str: """Add 'self' as first argument to a function definition if not present.""" - pattern = r'def forward\(((?!self)[^)]*)\)' - + pattern = r"def forward\(((?!self)[^)]*)\)" + def replacement(match): args = match.group(1).strip() if args: # If there are other arguments - return f'def forward(self, {args})' - return 'def forward(self)' - + return f"def forward(self, {args})" + return "def forward(self)" + return re.sub(pattern, replacement, source_code) forward_source_code = forward_source_code.replace(self.name, "forward") @@ -391,11 +387,7 @@ class {class_name}(Tool): # Save app file app_file = os.path.join(output_dir, "app.py") with open(app_file, "w", encoding="utf-8") as f: - f.write( - APP_FILE_TEMPLATE.format( - class_name=class_name - ) - ) + f.write(APP_FILE_TEMPLATE.format(class_name=class_name)) # Save requirements file requirements_file = os.path.join(output_dir, "requirements.txt") @@ -457,7 +449,7 @@ class {class_name}(Tool): self.save(work_dir) print(work_dir) with open(work_dir + "/tool.py", "r") as f: - print('\n'.join(f.readlines())) + print("\n".join(f.readlines())) logger.info( f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" ) @@ -544,8 +536,8 @@ class {class_name}(Tool): ) with open(resolved_tool_file, encoding="utf-8") as reader: - tool_code = "".join(reader.readlines()) - + tool_code = "".join(reader.readlines()) + # Find the Tool subclass in the namespace with tempfile.TemporaryDirectory() as temp_dir: # Save the code to a file @@ -569,13 +561,12 @@ class {class_name}(Tool): if tool_class is None: raise ValueError("No Tool subclass found in the code.") - + if not isinstance(tool_class.inputs, dict): tool_class.inputs = ast.literal_eval(tool_class.inputs) return tool_class(**kwargs) - @staticmethod def from_space( space_id: str, @@ -702,7 +693,11 @@ class {class_name}(Tool): return output return SpaceToolWrapper( - space_id=space_id, name=name, description=description, api_name=api_name, token=token + space_id=space_id, + name=name, + description=description, + api_name=api_name, + token=token, ) @staticmethod @@ -855,12 +850,12 @@ TOOL_MAPPING = { def load_tool( - task_or_repo_id, - model_repo_id: Optional[str] = None, - token: Optional[str] = None, - trust_remote_code: bool=False, - **kwargs - ): + task_or_repo_id, + model_repo_id: Optional[str] = None, + token: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, +): """ Main function to quickly load a tool, be it on the Hub or in the Transformers library. @@ -909,7 +904,11 @@ def load_tool( f"code that you have checked." ) return Tool.from_hub( - task_or_repo_id, model_repo_id=model_repo_id, token=token, trust_remote_code=trust_remote_code, **kwargs + task_or_repo_id, + model_repo_id=model_repo_id, + token=token, + trust_remote_code=trust_remote_code, + **kwargs, ) @@ -1028,7 +1027,7 @@ def tool(tool_function: Callable) -> Tool: raise TypeHintParsingException( "Tool return type not found: make sure your function has a return type hint!" ) - class_name = ''.join([el.title() for el in parameters['name'].split('_')]) + class_name = "".join([el.title() for el in parameters["name"].split("_")]) if parameters["return"]["type"] == "object": parameters["return"]["type"] = "any" @@ -1086,7 +1085,9 @@ class Toolbox: """Get all tools currently in the toolbox""" return self._tools - def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str: + def show_tool_descriptions( + self, tool_description_template: Optional[str] = None + ) -> str: """ Returns the description of all tools in the toolbox @@ -1151,4 +1152,12 @@ class Toolbox: toolbox_description += f"\t{tool.name}: {tool.description}\n" return toolbox_description -__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"] + +__all__ = [ + "AUTHORIZED_TYPES", + "Tool", + "tool", + "load_tool", + "launch_gradio_demo", + "Toolbox", +] diff --git a/src/agents/types.py b/src/agents/types.py index 99007c8..33970d7 100644 --- a/src/agents/types.py +++ b/src/agents/types.py @@ -267,4 +267,5 @@ def handle_agent_outputs(output, output_type=None): return _v(output) return output -__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"] \ No newline at end of file + +__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"] diff --git a/src/agents/utils.py b/src/agents/utils.py index 47e9e01..081f6f2 100644 --- a/src/agents/utils.py +++ b/src/agents/utils.py @@ -18,6 +18,7 @@ import json import re from typing import Tuple, Dict, Union import ast +from rich.console import Console from transformers.utils.import_utils import _is_package_available @@ -28,8 +29,6 @@ def is_pygments_available(): return _pygments_available -from rich.console import Console - console = Console() @@ -110,21 +109,23 @@ def truncate_content( + f"\n..._This content has been truncated to stay below {max_length} characters_...\n" + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] ) - + + class ImportFinder(ast.NodeVisitor): def __init__(self): self.packages = set() - + def visit_Import(self, node): for alias in node.names: # Get the base package name (before any dots) - base_package = alias.name.split('.')[0] + base_package = alias.name.split(".")[0] self.packages.add(base_package) def visit_ImportFrom(self, node): if node.module: # for "from x import y" statements # Get the base package name (before any dots) - base_package = node.module.split('.')[0] + base_package = node.module.split(".")[0] self.packages.add(base_package) -__all__ = [] \ No newline at end of file + +__all__ = [] diff --git a/tests/test_agents.py b/tests/test_agents.py index b48a32e..16b281f 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -27,12 +27,13 @@ from agents.agents import ( CodeAgent, JsonAgent, Toolbox, - ToolCall + ToolCall, ) from agents.tools import tool from agents.default_tools import PythonInterpreterTool from transformers.testing_utils import get_tests_dir + def get_new_path(suffix="") -> str: directory = tempfile.mkdtemp() return os.path.join(directory, str(uuid.uuid4()) + suffix) @@ -60,6 +61,7 @@ Action: } """ + def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) @@ -82,6 +84,7 @@ Action: } """ + def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) if "special_marker" not in prompt: @@ -179,9 +182,7 @@ class AgentTests(unittest.TestCase): assert output == "7.2904" def test_fake_json_agent(self): - agent = JsonAgent( - tools=[PythonInterpreterTool()], llm_engine=fake_json_llm - ) + agent = JsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, str) assert output == "7.2904" @@ -209,9 +210,7 @@ Action: Args: prompt: The prompt """ - return Image.open( - Path(get_tests_dir("fixtures")) / "000000039769.png" - ) + return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png") agent = JsonAgent( tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image @@ -221,9 +220,7 @@ Action: assert isinstance(agent.state["image.png"], Image.Image) def test_fake_code_agent(self): - agent = CodeAgent( - tools=[PythonInterpreterTool()], llm_engine=fake_code_llm - ) + agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, float) assert output == 7.2904 @@ -234,9 +231,7 @@ Action: ) def test_reset_conversations(self): - agent = CodeAgent( - tools=[PythonInterpreterTool()], llm_engine=fake_code_llm - ) + agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm) output = agent.run("What is 2 multiplied by 3.6452?", reset=True) assert output == 7.2904 assert len(agent.logs) == 4 @@ -299,9 +294,7 @@ Action: # check that python_interpreter base tool does not get added to code agents agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True) - assert ( - len(agent.toolbox.tools) == 2 - ) # added final_answer tool + search + assert len(agent.toolbox.tools) == 2 # added final_answer tool + search def test_function_persistence_across_steps(self): agent = CodeAgent( diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index 772eac5..e5b46ff 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -25,6 +25,7 @@ from pathlib import Path from typing import List from dotenv import load_dotenv + class SubprocessCallException(Exception): pass @@ -37,10 +38,10 @@ def run_command(command: List[str], return_stdout=False, env=None): for i, c in enumerate(command): if isinstance(c, Path): command[i] = str(c) - + if env is None: env = os.environ.copy() - + try: output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env) if return_stdout: @@ -55,14 +56,14 @@ def run_command(command: List[str], return_stdout=False, env=None): class DocCodeExtractor: """Handles extraction and validation of Python code from markdown files.""" - + @staticmethod def extract_python_code(content: str) -> List[str]: """Extract Python code blocks from markdown content.""" - pattern = r'```(?:python|py)\n(.*?)\n```' + pattern = r"```(?:python|py)\n(.*?)\n```" matches = re.finditer(pattern, content, re.DOTALL) return [match.group(1).strip() for match in matches] - + @staticmethod def create_test_script(code_blocks: List[str], tmp_dir: str) -> Path: """Create a temporary Python script from code blocks.""" @@ -74,13 +75,13 @@ class DocCodeExtractor: with open(tmp_file, "w", encoding="utf-8") as f: f.write(combined_code) - + return tmp_file class TestDocs: """Test case for documentation code testing.""" - + @classmethod def setup_class(cls): cls._tmpdir = tempfile.mkdtemp() @@ -93,7 +94,7 @@ class TestDocs: load_dotenv() cls.hf_token = os.getenv("HF_TOKEN") - + cls.md_files = list(cls.docs_dir.rglob("*.md")) if not cls.md_files: raise ValueError(f"No markdown files found in {cls.docs_dir}") @@ -107,29 +108,38 @@ class TestDocs: """Test a single documentation file.""" with open(doc_path, "r", encoding="utf-8") as f: content = f.read() - + code_blocks = self.extractor.extract_python_code(content) if not code_blocks: pytest.skip(f"No Python code blocks found in {doc_path.name}") - + # Validate syntax of each block individually by parsing it for i, block in enumerate(code_blocks, 1): ast.parse(block) - + # Create and execute test script try: - excluded_snippets = ["ToolCollection", "image_generation_tool", "from_langchain"] + excluded_snippets = [ + "ToolCollection", + "image_generation_tool", + "from_langchain", + ] code_blocks = [ - block.replace("", self.hf_token) for block in code_blocks - if not any([snippet in block for snippet in excluded_snippets]) # Exclude these tools that take longer to run and add dependencies + block.replace("", self.hf_token) + for block in code_blocks + if not any( + [snippet in block for snippet in excluded_snippets] + ) # Exclude these tools that take longer to run and add dependencies ] test_script = self.extractor.create_test_script(code_blocks, self._tmpdir) run_command(self.launch_args + [str(test_script)]) - + except SubprocessCallException as e: pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}") - except Exception as e: - pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}") + except Exception: + pytest.fail( + f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}" + ) @pytest.fixture(autouse=True) def _setup(self): @@ -152,7 +162,5 @@ def pytest_generate_tests(metafunc): # Parameterize with the markdown files metafunc.parametrize( - "doc_path", - test_class.md_files, - ids=[f.stem for f in test_class.md_files] - ) \ No newline at end of file + "doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files] + ) diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py index 96d6c8f..5b6e50f 100644 --- a/tests/test_final_answer.py +++ b/tests/test_final_answer.py @@ -20,9 +20,10 @@ import numpy as np from PIL import Image from transformers import is_torch_available -from agents.types import AGENT_TYPE_MAPPING -from agents.default_tools import FinalAnswerTool from transformers.testing_utils import get_tests_dir, require_torch +from agents.types import AGENT_TYPE_MAPPING + +from agents.default_tools import FinalAnswerTool from .test_tools_common import ToolTesterMixin diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py index 36c70be..ae3b622 100644 --- a/tests/test_tools_common.py +++ b/tests/test_tools_common.py @@ -98,6 +98,7 @@ class ToolTesterMixin: agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] self.assertTrue(isinstance(output, agent_type)) + class ToolTests(unittest.TestCase): def test_tool_init_with_decorator(self): @tool @@ -163,40 +164,46 @@ class ToolTests(unittest.TestCase): assert coolfunc.output_type == "number" assert "docstring has no description for the argument" in str(e) - def test_tool_definition_needs_imports_in_function(self): + def test_tool_definition_raises_error_imports_outside_function(self): with pytest.raises(Exception) as e: from datetime import datetime + @tool def get_current_time() -> str: """ Gets the current time. """ return str(datetime.now()) + assert "datetime" in str(e) # Also test with classic definition with pytest.raises(Exception) as e: + class GetCurrentTimeTool(Tool): - name="get_current_time_tool" - description="Gets the current time" + name = "get_current_time_tool" + description = "Gets the current time" inputs = {} output_type = "string" def forward(self): return str(datetime.now()) + assert "datetime" in str(e) + def test_tool_definition_raises_no_error_imports_in_function(self): @tool def get_current_time() -> str: """ Gets the current time. """ from datetime import datetime + return str(datetime.now()) - + class GetCurrentTimeTool(Tool): - name="get_current_time_tool" - description="Gets the current time" + name = "get_current_time_tool" + description = "Gets the current time" inputs = {} output_type = "string" diff --git a/tests/test_utils.py b/tests/test_utils.py index 12b3266..34ca2db 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,7 @@ import tempfile from pathlib import Path + def str_to_bool(value) -> int: """ Converts a string representation of truth to `True` (1) or `False` (0). @@ -28,10 +29,13 @@ def get_int_from_env(env_keys, default): return val return default + def parse_flag_from_env(key, default=False): """Returns truthy value for `key` from the env if available else the default.""" value = os.environ.get(key, str(default)) - return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int... + return ( + str_to_bool(value) == 1 + ) # As its name indicates `str_to_bool` actually returns an int... _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) @@ -80,4 +84,4 @@ class TempDirTestCase(unittest.TestCase): if path.is_file(): path.unlink() elif path.is_dir(): - shutil.rmtree(path) \ No newline at end of file + shutil.rmtree(path)