Formatting
This commit is contained in:
		
							parent
							
								
									1751bf03ac
								
							
						
					
					
						commit
						06066437fd
					
				|  | @ -18,11 +18,7 @@ __version__ = "0.1.0" | ||||||
| 
 | 
 | ||||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||||
| 
 | 
 | ||||||
| from transformers.utils import ( | from transformers.utils import _LazyModule | ||||||
|     OptionalDependencyNotAvailable, |  | ||||||
|     _LazyModule, |  | ||||||
|     is_torch_available, |  | ||||||
| ) |  | ||||||
| from transformers.utils.import_utils import define_import_structure | from transformers.utils.import_utils import define_import_structure | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -43,4 +39,6 @@ else: | ||||||
|     import sys |     import sys | ||||||
| 
 | 
 | ||||||
|     _file = globals()["__file__"] |     _file = globals()["__file__"] | ||||||
|     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) |     sys.modules[__name__] = _LazyModule( | ||||||
|  |         __name__, _file, define_import_structure(_file), module_spec=__spec__ | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  | @ -79,8 +79,9 @@ class AgentGenerationError(AgentError): | ||||||
| 
 | 
 | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class ToolCall(): | class ToolCall: | ||||||
|     tool_name: str |     tool_name: str | ||||||
|     tool_arguments: Any |     tool_arguments: Any | ||||||
| 
 | 
 | ||||||
|  | @ -146,13 +147,17 @@ Here is a list of the team members that you can call:""" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def format_prompt_with_managed_agents_descriptions( | def format_prompt_with_managed_agents_descriptions( | ||||||
|     prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None |     prompt_template, | ||||||
|  |     managed_agents, | ||||||
|  |     agent_descriptions_placeholder: Optional[str] = None, | ||||||
| ) -> str: | ) -> str: | ||||||
|     if agent_descriptions_placeholder is None: |     if agent_descriptions_placeholder is None: | ||||||
|         agent_descriptions_placeholder = "{{managed_agents_descriptions}}" |         agent_descriptions_placeholder = "{{managed_agents_descriptions}}" | ||||||
|     if agent_descriptions_placeholder not in prompt_template: |     if agent_descriptions_placeholder not in prompt_template: | ||||||
|         print("PROMPT TEMPLLL", prompt_template) |         print("PROMPT TEMPLLL", prompt_template) | ||||||
|         raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'") |         raise ValueError( | ||||||
|  |             f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'" | ||||||
|  |         ) | ||||||
|     if len(managed_agents.keys()) > 0: |     if len(managed_agents.keys()) > 0: | ||||||
|         return prompt_template.replace( |         return prompt_template.replace( | ||||||
|             agent_descriptions_placeholder, show_agents_descriptions(managed_agents) |             agent_descriptions_placeholder, show_agents_descriptions(managed_agents) | ||||||
|  | @ -970,7 +975,9 @@ class CodeAgent(ReactAgent): | ||||||
|             error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" |             error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" | ||||||
|             raise AgentParsingError(error_msg) |             raise AgentParsingError(error_msg) | ||||||
| 
 | 
 | ||||||
|         log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action) |         log_entry.tool_call = ToolCall( | ||||||
|  |             tool_name="python_interpreter", tool_arguments=code_action | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|  | @ -1075,4 +1082,13 @@ And even if your task resolution is not successful, please return as much contex | ||||||
|         else: |         else: | ||||||
|             return output |             return output | ||||||
| 
 | 
 | ||||||
| __all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"] | 
 | ||||||
|  | __all__ = [ | ||||||
|  |     "AgentError", | ||||||
|  |     "BaseAgent", | ||||||
|  |     "ManagedAgent", | ||||||
|  |     "ReactAgent", | ||||||
|  |     "CodeAgent", | ||||||
|  |     "JsonAgent", | ||||||
|  |     "Toolbox", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | @ -127,7 +127,10 @@ class PythonInterpreterTool(Tool): | ||||||
|     name = "python_interpreter" |     name = "python_interpreter" | ||||||
|     description = "This is a tool that evaluates python code. It can be used to perform calculations." |     description = "This is a tool that evaluates python code. It can be used to perform calculations." | ||||||
|     inputs = { |     inputs = { | ||||||
|         "code": {"type": "string", "description": "The python code to run in interpreter"} |         "code": { | ||||||
|  |             "type": "string", | ||||||
|  |             "description": "The python code to run in interpreter", | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|     output_type = "string" |     output_type = "string" | ||||||
| 
 | 
 | ||||||
|  | @ -186,4 +189,5 @@ class UserInputTool(Tool): | ||||||
|         user_input = input(f"{question} => ") |         user_input = input(f"{question} => ") | ||||||
|         return user_input |         return user_input | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| __all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"] | __all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"] | ||||||
|  | @ -9,12 +9,13 @@ from typing import Optional, Dict, Tuple, Set, Any | ||||||
| import types | import types | ||||||
| from .default_tools import BASE_PYTHON_TOOLS | from .default_tools import BASE_PYTHON_TOOLS | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class StateManager: | class StateManager: | ||||||
|     def __init__(self, work_dir: Path): |     def __init__(self, work_dir: Path): | ||||||
|         self.work_dir = work_dir |         self.work_dir = work_dir | ||||||
|         self.state_file = work_dir / "interpreter_state.pickle" |         self.state_file = work_dir / "interpreter_state.pickle" | ||||||
|         self.imports_file = work_dir / "imports.txt" |         self.imports_file = work_dir / "imports.txt" | ||||||
|         self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$') |         self.import_pattern = re.compile(r"^(?:from\s+[\w.]+\s+)?import\s+.+$") | ||||||
|         self.imports: Set[str] = set() |         self.imports: Set[str] = set() | ||||||
| 
 | 
 | ||||||
|     def is_import_statement(self, code: str) -> bool: |     def is_import_statement(self, code: str) -> bool: | ||||||
|  | @ -23,7 +24,7 @@ class StateManager: | ||||||
| 
 | 
 | ||||||
|     def track_imports(self, code: str): |     def track_imports(self, code: str): | ||||||
|         """Track import statements for later use.""" |         """Track import statements for later use.""" | ||||||
|         for line in code.split('\n'): |         for line in code.split("\n"): | ||||||
|             if self.is_import_statement(line.strip()): |             if self.is_import_statement(line.strip()): | ||||||
|                 self.imports.add(line.strip()) |                 self.imports.add(line.strip()) | ||||||
| 
 | 
 | ||||||
|  | @ -37,20 +38,21 @@ class StateManager: | ||||||
|         """ |         """ | ||||||
|         # Filter out modules, functions, and special variables |         # Filter out modules, functions, and special variables | ||||||
|         state_dict = { |         state_dict = { | ||||||
|             'variables': { |             "variables": { | ||||||
|                 k: v for k, v in locals_dict.items() |                 k: v | ||||||
|  |                 for k, v in locals_dict.items() | ||||||
|                 if not ( |                 if not ( | ||||||
|                     k.startswith('_') |                     k.startswith("_") | ||||||
|                     or callable(v) |                     or callable(v) | ||||||
|                     or isinstance(v, type) |                     or isinstance(v, type) | ||||||
|                     or isinstance(v, types.ModuleType) |                     or isinstance(v, types.ModuleType) | ||||||
|                 ) |                 ) | ||||||
|             }, |             }, | ||||||
|             'imports': list(self.imports), |             "imports": list(self.imports), | ||||||
|             'source': executor |             "source": executor, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         with open(self.state_file, 'wb') as f: |         with open(self.state_file, "wb") as f: | ||||||
|             pickle.dump(state_dict, f) |             pickle.dump(state_dict, f) | ||||||
| 
 | 
 | ||||||
|     def load_state(self, executor: str) -> Dict[str, Any]: |     def load_state(self, executor: str) -> Dict[str, Any]: | ||||||
|  | @ -66,14 +68,14 @@ class StateManager: | ||||||
|         if not self.state_file.exists(): |         if not self.state_file.exists(): | ||||||
|             return {} |             return {} | ||||||
| 
 | 
 | ||||||
|         with open(self.state_file, 'rb') as f: |         with open(self.state_file, "rb") as f: | ||||||
|             state_dict = pickle.load(f) |             state_dict = pickle.load(f) | ||||||
| 
 | 
 | ||||||
|         # First handle imports |         # First handle imports | ||||||
|         for import_stmt in state_dict['imports']: |         for import_stmt in state_dict["imports"]: | ||||||
|             exec(import_stmt, globals()) |             exec(import_stmt, globals()) | ||||||
| 
 | 
 | ||||||
|         return state_dict['variables'] |         return state_dict["variables"] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def read_multiplexed_response(socket): | def read_multiplexed_response(socket): | ||||||
|  | @ -84,7 +86,7 @@ def read_multiplexed_response(socket): | ||||||
|     while True and i < 1000: |     while True and i < 1000: | ||||||
|         # Stream output from socket |         # Stream output from socket | ||||||
|         response_data = socket.recv(4096) |         response_data = socket.recv(4096) | ||||||
|         responses = response_data.split(b'\x01\x00\x00\x00\x00\x00') |         responses = response_data.split(b"\x01\x00\x00\x00\x00\x00") | ||||||
| 
 | 
 | ||||||
|         # The last non-empty chunk should be our JSON response |         # The last non-empty chunk should be our JSON response | ||||||
|         if len(responses) > 0: |         if len(responses) > 0: | ||||||
|  | @ -92,9 +94,9 @@ def read_multiplexed_response(socket): | ||||||
|                 if chunk and len(chunk.strip()) > 0: |                 if chunk and len(chunk.strip()) > 0: | ||||||
|                     try: |                     try: | ||||||
|                         # Find the start of valid JSON by looking for '{' |                         # Find the start of valid JSON by looking for '{' | ||||||
|                         json_start = chunk.find(b'{') |                         json_start = chunk.find(b"{") | ||||||
|                         if json_start != -1: |                         if json_start != -1: | ||||||
|                             decoded = chunk[json_start:].decode('utf-8') |                             decoded = chunk[json_start:].decode("utf-8") | ||||||
|                             result = json.loads(decoded) |                             result = json.loads(decoded) | ||||||
|                             if "output" in result: |                             if "output" in result: | ||||||
|                                 return decoded |                                 return decoded | ||||||
|  | @ -113,7 +115,6 @@ class DockerPythonInterpreter: | ||||||
|         self.socket = None |         self.socket = None | ||||||
|         self.state_manager = StateManager(work_dir) |         self.state_manager = StateManager(work_dir) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     def create_interpreter_script(self) -> str: |     def create_interpreter_script(self) -> str: | ||||||
|         """Create the interpreter script that will run inside the container""" |         """Create the interpreter script that will run inside the container""" | ||||||
|         script = """ |         script = """ | ||||||
|  | @ -230,9 +231,7 @@ if __name__ == '__main__': | ||||||
|         self.create_interpreter_script() |         self.create_interpreter_script() | ||||||
| 
 | 
 | ||||||
|         # Setup volume mapping |         # Setup volume mapping | ||||||
|         volumes = { |         volumes = {str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}} | ||||||
|             str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"} |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         for container in self.client.containers.list(all=True): |         for container in self.client.containers.list(all=True): | ||||||
|             if container_name == container.name: |             if container_name == container.name: | ||||||
|  | @ -250,22 +249,20 @@ if __name__ == '__main__': | ||||||
|                 tty=True, |                 tty=True, | ||||||
|                 stdin_open=True, |                 stdin_open=True, | ||||||
|                 working_dir="/workspace", |                 working_dir="/workspace", | ||||||
|                 volumes=volumes |                 volumes=volumes, | ||||||
|             ) |             ) | ||||||
|             # Install packages in the new container |             # Install packages in the new container | ||||||
|             print("Installing packages...") |             print("Installing packages...") | ||||||
|             packages = ["pandas", "numpy", "pickle5"]  # Add your required packages here |             packages = ["pandas", "numpy", "pickle5"]  # Add your required packages here | ||||||
| 
 | 
 | ||||||
|             result = self.container.exec_run( |             result = self.container.exec_run( | ||||||
|                 f"pip install {' '.join(packages)}", |                 f"pip install {' '.join(packages)}", workdir="/workspace" | ||||||
|                 workdir="/workspace" |  | ||||||
|             ) |             ) | ||||||
|             if result.exit_code != 0: |             if result.exit_code != 0: | ||||||
|                 print(f"Warning: Failed to install: {result.output.decode()}") |                 print(f"Warning: Failed to install: {result.output.decode()}") | ||||||
|             else: |             else: | ||||||
|                 print(f"Installed {packages}.") |                 print(f"Installed {packages}.") | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|         if not self.wait_for_ready(self.container): |         if not self.wait_for_ready(self.container): | ||||||
|             raise Exception("Failed to start container") |             raise Exception("Failed to start container") | ||||||
| 
 | 
 | ||||||
|  | @ -276,14 +273,12 @@ if __name__ == '__main__': | ||||||
|             stdin=True, |             stdin=True, | ||||||
|             stdout=True, |             stdout=True, | ||||||
|             stderr=True, |             stderr=True, | ||||||
|             tty=True |             tty=True, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Connect to the exec instance |         # Connect to the exec instance | ||||||
|         self.socket = self.client.api.exec_start( |         self.socket = self.client.api.exec_start( | ||||||
|             self.exec_id['Id'], |             self.exec_id["Id"], socket=True, demux=True | ||||||
|             socket=True, |  | ||||||
|             demux=True |  | ||||||
|         )._sock |         )._sock | ||||||
| 
 | 
 | ||||||
|     def _raw_execute(self, code: str) -> Tuple[str, bool]: |     def _raw_execute(self, code: str) -> Tuple[str, bool]: | ||||||
|  | @ -296,14 +291,14 @@ if __name__ == '__main__': | ||||||
|         if not self.socket: |         if not self.socket: | ||||||
|             raise Exception("Socket not started") |             raise Exception("Socket not started") | ||||||
| 
 | 
 | ||||||
|         command = json.dumps({'code': code}) + '\n' |         command = json.dumps({"code": code}) + "\n" | ||||||
|         self.socket.send(command.encode()) |         self.socket.send(command.encode()) | ||||||
| 
 | 
 | ||||||
|         response = read_multiplexed_response(self.socket) |         response = read_multiplexed_response(self.socket) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             result = json.loads(response) |             result = json.loads(response) | ||||||
|             return result['output'], result['more'] |             return result["output"], result["more"] | ||||||
|         except json.JSONDecodeError: |         except json.JSONDecodeError: | ||||||
|             return f"Error: Invalid response from interpreter: {response}", False |             return f"Error: Invalid response from interpreter: {response}", False | ||||||
| 
 | 
 | ||||||
|  | @ -311,7 +306,7 @@ if __name__ == '__main__': | ||||||
|         """Get the current locals dictionary from the interpreter by pickling directly from Docker.""" |         """Get the current locals dictionary from the interpreter by pickling directly from Docker.""" | ||||||
|         pickle_path = self.work_dir / "locals.pickle" |         pickle_path = self.work_dir / "locals.pickle" | ||||||
|         if pickle_path.exists(): |         if pickle_path.exists(): | ||||||
|             with open(pickle_path, 'rb') as f: |             with open(pickle_path, "rb") as f: | ||||||
|                 try: |                 try: | ||||||
|                     return pickle.load(f) |                     return pickle.load(f) | ||||||
|                 except Exception as e: |                 except Exception as e: | ||||||
|  | @ -326,10 +321,7 @@ if __name__ == '__main__': | ||||||
|         output, more = self._raw_execute(code) |         output, more = self._raw_execute(code) | ||||||
| 
 | 
 | ||||||
|         # Save state after execution |         # Save state after execution | ||||||
|         self.state_manager.save_state( |         self.state_manager.save_state(self.get_locals_dict(), "docker") | ||||||
|             self.get_locals_dict(), |  | ||||||
|             'docker' |  | ||||||
|         ) |  | ||||||
|         return output, more |         return output, more | ||||||
| 
 | 
 | ||||||
|     def stop(self, remove: bool = False): |     def stop(self, remove: bool = False): | ||||||
|  | @ -349,6 +341,7 @@ if __name__ == '__main__': | ||||||
|                 print(f"Error stopping container: {e}") |                 print(f"Error stopping container: {e}") | ||||||
|                 raise |                 raise | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: | def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: | ||||||
|     from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES |     from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES | ||||||
| 
 | 
 | ||||||
|  | @ -359,7 +352,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: | ||||||
|     state_manager.track_imports(code) |     state_manager.track_imports(code) | ||||||
| 
 | 
 | ||||||
|     # Load state from Docker if available |     # Load state from Docker if available | ||||||
|     locals_dict = state_manager.load_state('local') |     locals_dict = state_manager.load_state("local") | ||||||
| 
 | 
 | ||||||
|     # Execute in a new namespace with loaded state |     # Execute in a new namespace with loaded state | ||||||
|     namespace = {} |     namespace = {} | ||||||
|  | @ -374,7 +367,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     # Save state for Docker |     # Save state for Docker | ||||||
|     state_manager.save_state(namespace, 'local') |     state_manager.save_state(namespace, "local") | ||||||
|     return output |     return output | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -382,14 +375,17 @@ def create_tools_regex(tool_names): | ||||||
|     # Escape any special regex characters in tool names |     # Escape any special regex characters in tool names | ||||||
|     escaped_names = [re.escape(name) for name in tool_names] |     escaped_names = [re.escape(name) for name in tool_names] | ||||||
|     # Join with | and add word boundaries |     # Join with | and add word boundaries | ||||||
|     pattern = r'\b(' + '|'.join(escaped_names) + r')\b' |     pattern = r"\b(" + "|".join(escaped_names) + r")\b" | ||||||
|     return re.compile(pattern) |     return re.compile(pattern) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter): | def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter): | ||||||
|     """Execute code with automatic switching between Docker and local.""" |     """Execute code with automatic switching between Docker and local.""" | ||||||
|     lines = code.split('\n') |     lines = code.split("\n") | ||||||
|     current_block = [] |     current_block = [] | ||||||
|     tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing |     tool_regex = create_tools_regex( | ||||||
|  |         list(tools.keys()) + ["print"] | ||||||
|  |     )  # Added print for testing | ||||||
| 
 | 
 | ||||||
|     tools = { |     tools = { | ||||||
|         **BASE_PYTHON_TOOLS.copy(), |         **BASE_PYTHON_TOOLS.copy(), | ||||||
|  | @ -400,20 +396,20 @@ def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter): | ||||||
|         if tool_regex.search(line): |         if tool_regex.search(line): | ||||||
|             # Execute accumulated Docker code if any |             # Execute accumulated Docker code if any | ||||||
|             if current_block: |             if current_block: | ||||||
|                 output, more = interpreter.execute('\n'.join(current_block)) |                 output, more = interpreter.execute("\n".join(current_block)) | ||||||
|                 print(output, end='') |                 print(output, end="") | ||||||
|                 current_block = [] |                 current_block = [] | ||||||
| 
 | 
 | ||||||
|             output = execute_locally(line, work_dir, tools) |             output = execute_locally(line, work_dir, tools) | ||||||
|             if output: |             if output: | ||||||
|                 print(output, end='') |                 print(output, end="") | ||||||
|         else: |         else: | ||||||
|             current_block.append(line) |             current_block.append(line) | ||||||
| 
 | 
 | ||||||
|     # Execute any remaining Docker code |     # Execute any remaining Docker code | ||||||
|     if current_block: |     if current_block: | ||||||
|         output, more = interpreter.execute('\n'.join(current_block)) |         output, more = interpreter.execute("\n".join(current_block)) | ||||||
|         print(output, end='') |         print(output, end="") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| __all__ = ["DockerPythonInterpreter", "execute_code"] | __all__ = ["DockerPythonInterpreter", "execute_code"] | ||||||
|  | @ -111,4 +111,5 @@ class GradioUI: | ||||||
| 
 | 
 | ||||||
|         demo.launch() |         demo.launch() | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| __all__ = ["stream_to_gradio", "GradioUI"] | __all__ = ["stream_to_gradio", "GradioUI"] | ||||||
|  | @ -37,6 +37,7 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = { | ||||||
|     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>", |     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class MessageRole(str, Enum): | class MessageRole(str, Enum): | ||||||
|     USER = "user" |     USER = "user" | ||||||
|     ASSISTANT = "assistant" |     ASSISTANT = "assistant" | ||||||
|  | @ -48,6 +49,7 @@ class MessageRole(str, Enum): | ||||||
|     def roles(cls): |     def roles(cls): | ||||||
|         return [r.value for r in cls] |         return [r.value for r in cls] | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| openai_role_conversions = { | openai_role_conversions = { | ||||||
|     MessageRole.TOOL_RESPONSE: MessageRole.USER, |     MessageRole.TOOL_RESPONSE: MessageRole.USER, | ||||||
| } | } | ||||||
|  | @ -56,6 +58,7 @@ llama_role_conversions = { | ||||||
|     MessageRole.TOOL_RESPONSE: MessageRole.USER, |     MessageRole.TOOL_RESPONSE: MessageRole.USER, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def get_clean_message_list( | def get_clean_message_list( | ||||||
|     message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} |     message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} | ||||||
| ): | ): | ||||||
|  | @ -118,7 +121,7 @@ class HfEngine: | ||||||
|         messages: List[Dict[str, str]], |         messages: List[Dict[str, str]], | ||||||
|         stop_sequences: Optional[List[str]] = None, |         stop_sequences: Optional[List[str]] = None, | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500 |         max_tokens: int = 1500, | ||||||
|     ): |     ): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|  | @ -276,7 +279,12 @@ class TransformersEngine(HfEngine): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class OpenAIEngine: | class OpenAIEngine: | ||||||
|     def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None): |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model_name: Optional[str] = None, | ||||||
|  |         api_key: Optional[str] = None, | ||||||
|  |         base_url: Optional[str] = None, | ||||||
|  |     ): | ||||||
|         """Creates a LLM Engine that follows OpenAI format. |         """Creates a LLM Engine that follows OpenAI format. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|  | @ -301,7 +309,9 @@ class OpenAIEngine: | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         messages = get_clean_message_list(messages, role_conversions=openai_role_conversions) |         messages = get_clean_message_list( | ||||||
|  |             messages, role_conversions=openai_role_conversions | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         response = self.client.chat.completions.create( |         response = self.client.chat.completions.create( | ||||||
|             model=self.model_name, |             model=self.model_name, | ||||||
|  | @ -337,7 +347,9 @@ class AnthropicEngine: | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         messages = get_clean_message_list(messages, role_conversions=openai_role_conversions) |         messages = get_clean_message_list( | ||||||
|  |             messages, role_conversions=openai_role_conversions | ||||||
|  |         ) | ||||||
|         index_system_message, system_prompt = None, None |         index_system_message, system_prompt = None, None | ||||||
|         for index, message in enumerate(messages): |         for index, message in enumerate(messages): | ||||||
|             if message["role"] == MessageRole.SYSTEM: |             if message["role"] == MessageRole.SYSTEM: | ||||||
|  | @ -346,7 +358,9 @@ class AnthropicEngine: | ||||||
|         if system_prompt is None: |         if system_prompt is None: | ||||||
|             raise Exception("No system prompt found!") |             raise Exception("No system prompt found!") | ||||||
| 
 | 
 | ||||||
|         filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message] |         filtered_messages = [ | ||||||
|  |             message for i, message in enumerate(messages) if i != index_system_message | ||||||
|  |         ] | ||||||
|         if len(filtered_messages) == 0: |         if len(filtered_messages) == 0: | ||||||
|             print("Error, no user message:", messages) |             print("Error, no user message:", messages) | ||||||
|             assert False |             assert False | ||||||
|  | @ -366,4 +380,13 @@ class AnthropicEngine: | ||||||
|         return full_response_text |         return full_response_text | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| __all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"] | __all__ = [ | ||||||
|  |     "MessageRole", | ||||||
|  |     "llama_role_conversions", | ||||||
|  |     "get_clean_message_list", | ||||||
|  |     "HfEngine", | ||||||
|  |     "TransformersEngine", | ||||||
|  |     "HfApiEngine", | ||||||
|  |     "OpenAIEngine", | ||||||
|  |     "AnthropicEngine", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | @ -1000,4 +1000,5 @@ def evaluate_python_code( | ||||||
|         msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" |         msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" | ||||||
|         raise InterpreterError(msg) |         raise InterpreterError(msg) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| __all__ = ["evaluate_python_code"] | __all__ = ["evaluate_python_code"] | ||||||
|  | @ -44,4 +44,5 @@ class Monitor: | ||||||
|             console.print(f"- Input tokens: {self.total_input_token_count:,}") |             console.print(f"- Input tokens: {self.total_input_token_count:,}") | ||||||
|             console.print(f"- Output tokens: {self.total_output_token_count:,}") |             console.print(f"- Output tokens: {self.total_output_token_count:,}") | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| __all__ = ["Monitor"] | __all__ = ["Monitor"] | ||||||
|  | @ -491,4 +491,10 @@ Here is my new/updated plan of action to solve the task: | ||||||
| {plan_update} | {plan_update} | ||||||
| ```""" | ```""" | ||||||
| 
 | 
 | ||||||
| __all__ = ["USER_PROMPT_PLAN_UPDATE", "PLAN_UPDATE_FINAL_PLAN_REDACTION", "ONESHOT_CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT", "JSON_SYSTEM_PROMPT"] | __all__ = [ | ||||||
|  |     "USER_PROMPT_PLAN_UPDATE", | ||||||
|  |     "PLAN_UPDATE_FINAL_PLAN_REDACTION", | ||||||
|  |     "ONESHOT_CODE_SYSTEM_PROMPT", | ||||||
|  |     "CODE_SYSTEM_PROMPT", | ||||||
|  |     "JSON_SYSTEM_PROMPT", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | @ -78,4 +78,5 @@ class VisitWebpageTool(Tool): | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             return f"An unexpected error occurred: {str(e)}" |             return f"An unexpected error occurred: {str(e)}" | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| __all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"] | __all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"] | ||||||
|  | @ -39,12 +39,6 @@ from huggingface_hub import ( | ||||||
| from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session | from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session | ||||||
| from packaging import version | from packaging import version | ||||||
| 
 | 
 | ||||||
| from transformers.dynamic_module_utils import ( |  | ||||||
|     custom_object_save, |  | ||||||
|     get_class_from_dynamic_module, |  | ||||||
|     get_imports, |  | ||||||
| ) |  | ||||||
| from transformers import AutoProcessor |  | ||||||
| from transformers.utils import ( | from transformers.utils import ( | ||||||
|     TypeHintParsingException, |     TypeHintParsingException, | ||||||
|     cached_file, |     cached_file, | ||||||
|  | @ -62,11 +56,10 @@ logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if is_torch_available(): | if is_torch_available(): | ||||||
|     import torch |     pass | ||||||
| 
 | 
 | ||||||
| if is_accelerate_available(): | if is_accelerate_available(): | ||||||
|     from accelerate import PartialState |     pass | ||||||
|     from accelerate.utils import send_to_device |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| TOOL_CONFIG_FILE = "tool_config.json" | TOOL_CONFIG_FILE = "tool_config.json" | ||||||
|  | @ -123,6 +116,7 @@ def validate_after_init(cls, do_validate_forward: bool = True): | ||||||
|     cls.__init__ = new_init |     cls.__init__ = new_init | ||||||
|     return cls |     return cls | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def validate_args_are_self_contained(source_code): | def validate_args_are_self_contained(source_code): | ||||||
|     """Validates that all names in forward method are properly defined. |     """Validates that all names in forward method are properly defined. | ||||||
|     In particular it will check that all imports are done within the function.""" |     In particular it will check that all imports are done within the function.""" | ||||||
|  | @ -150,7 +144,7 @@ def validate_args_are_self_contained(source_code): | ||||||
| 
 | 
 | ||||||
|         def visit_ImportFrom(self, node): |         def visit_ImportFrom(self, node): | ||||||
|             """Handle from imports like 'from datetime import datetime'.""" |             """Handle from imports like 'from datetime import datetime'.""" | ||||||
|             module = node.module or '' |             module = node.module or "" | ||||||
|             for name in node.names: |             for name in node.names: | ||||||
|                 actual_name = name.asname or name.name |                 actual_name = name.asname or name.name | ||||||
|                 self.from_imports[actual_name] = (module, name.name, actual_name) |                 self.from_imports[actual_name] = (module, name.name, actual_name) | ||||||
|  | @ -187,9 +181,11 @@ def validate_args_are_self_contained(source_code): | ||||||
|             self.assigned_names.update(target_names) |             self.assigned_names.update(target_names) | ||||||
| 
 | 
 | ||||||
|             # Special handling for enumerate |             # Special handling for enumerate | ||||||
|             if (isinstance(node.iter, ast.Call) and  |             if ( | ||||||
|                 isinstance(node.iter.func, ast.Name) and  |                 isinstance(node.iter, ast.Call) | ||||||
|                 node.iter.func.id == 'enumerate'): |                 and isinstance(node.iter.func, ast.Name) | ||||||
|  |                 and node.iter.func.id == "enumerate" | ||||||
|  |             ): | ||||||
|                 # For enumerate, if we have "for i, x in enumerate(...)", |                 # For enumerate, if we have "for i, x in enumerate(...)", | ||||||
|                 # both i and x should be marked as assigned |                 # both i and x should be marked as assigned | ||||||
|                 if isinstance(node.target, ast.Tuple): |                 if isinstance(node.target, ast.Tuple): | ||||||
|  | @ -201,19 +197,19 @@ def validate_args_are_self_contained(source_code): | ||||||
|             self.generic_visit(node) |             self.generic_visit(node) | ||||||
| 
 | 
 | ||||||
|         def visit_Name(self, node): |         def visit_Name(self, node): | ||||||
|             if (isinstance(node.ctx, ast.Load) and not ( |             if isinstance(node.ctx, ast.Load) and not ( | ||||||
|                 node.id == "tool" or |                 node.id == "tool" | ||||||
|                 node.id in builtin_names or |                 or node.id in builtin_names | ||||||
|                 node.id in arg_names or  |                 or node.id in arg_names | ||||||
|                 node.id == 'self' or |                 or node.id == "self" | ||||||
|                 node.id in self.assigned_names |                 or node.id in self.assigned_names | ||||||
|             )): |             ): | ||||||
|                 if node.id not in self.from_imports and node.id not in self.imports: |                 if node.id not in self.from_imports and node.id not in self.imports: | ||||||
|                     self.undefined_names.add(node.id) |                     self.undefined_names.add(node.id) | ||||||
| 
 | 
 | ||||||
|         def visit_Attribute(self, node): |         def visit_Attribute(self, node): | ||||||
|             # Skip self.something |             # Skip self.something | ||||||
|             if not (isinstance(node.value, ast.Name) and node.value.id == 'self'): |             if not (isinstance(node.value, ast.Name) and node.value.id == "self"): | ||||||
|                 self.generic_visit(node) |                 self.generic_visit(node) | ||||||
| 
 | 
 | ||||||
|     checker = NameChecker() |     checker = NameChecker() | ||||||
|  | @ -226,6 +222,7 @@ def validate_args_are_self_contained(source_code): | ||||||
|             """ |             """ | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| AUTHORIZED_TYPES = [ | AUTHORIZED_TYPES = [ | ||||||
|     "string", |     "string", | ||||||
|     "boolean", |     "boolean", | ||||||
|  | @ -273,7 +270,6 @@ class Tool: | ||||||
|         super().__init_subclass__(**kwargs) |         super().__init_subclass__(**kwargs) | ||||||
|         validate_after_init(cls, do_validate_forward=False) |         validate_after_init(cls, do_validate_forward=False) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     def validate_arguments(self, do_validate_forward: bool = True): |     def validate_arguments(self, do_validate_forward: bool = True): | ||||||
|         required_attributes = { |         required_attributes = { | ||||||
|             "description": str, |             "description": str, | ||||||
|  | @ -359,13 +355,13 @@ class {class_name}(Tool): | ||||||
| 
 | 
 | ||||||
|         def add_self_argument(source_code: str) -> str: |         def add_self_argument(source_code: str) -> str: | ||||||
|             """Add 'self' as first argument to a function definition if not present.""" |             """Add 'self' as first argument to a function definition if not present.""" | ||||||
|             pattern = r'def forward\(((?!self)[^)]*)\)' |             pattern = r"def forward\(((?!self)[^)]*)\)" | ||||||
| 
 | 
 | ||||||
|             def replacement(match): |             def replacement(match): | ||||||
|                 args = match.group(1).strip() |                 args = match.group(1).strip() | ||||||
|                 if args:  # If there are other arguments |                 if args:  # If there are other arguments | ||||||
|                     return f'def forward(self, {args})' |                     return f"def forward(self, {args})" | ||||||
|                 return 'def forward(self)' |                 return "def forward(self)" | ||||||
| 
 | 
 | ||||||
|             return re.sub(pattern, replacement, source_code) |             return re.sub(pattern, replacement, source_code) | ||||||
| 
 | 
 | ||||||
|  | @ -391,11 +387,7 @@ class {class_name}(Tool): | ||||||
|         # Save app file |         # Save app file | ||||||
|         app_file = os.path.join(output_dir, "app.py") |         app_file = os.path.join(output_dir, "app.py") | ||||||
|         with open(app_file, "w", encoding="utf-8") as f: |         with open(app_file, "w", encoding="utf-8") as f: | ||||||
|             f.write( |             f.write(APP_FILE_TEMPLATE.format(class_name=class_name)) | ||||||
|                 APP_FILE_TEMPLATE.format( |  | ||||||
|                     class_name=class_name |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
| 
 | 
 | ||||||
|         # Save requirements file |         # Save requirements file | ||||||
|         requirements_file = os.path.join(output_dir, "requirements.txt") |         requirements_file = os.path.join(output_dir, "requirements.txt") | ||||||
|  | @ -457,7 +449,7 @@ class {class_name}(Tool): | ||||||
|             self.save(work_dir) |             self.save(work_dir) | ||||||
|             print(work_dir) |             print(work_dir) | ||||||
|             with open(work_dir + "/tool.py", "r") as f: |             with open(work_dir + "/tool.py", "r") as f: | ||||||
|                 print('\n'.join(f.readlines())) |                 print("\n".join(f.readlines())) | ||||||
|             logger.info( |             logger.info( | ||||||
|                 f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" |                 f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" | ||||||
|             ) |             ) | ||||||
|  | @ -575,7 +567,6 @@ class {class_name}(Tool): | ||||||
| 
 | 
 | ||||||
|         return tool_class(**kwargs) |         return tool_class(**kwargs) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_space( |     def from_space( | ||||||
|         space_id: str, |         space_id: str, | ||||||
|  | @ -702,7 +693,11 @@ class {class_name}(Tool): | ||||||
|                 return output |                 return output | ||||||
| 
 | 
 | ||||||
|         return SpaceToolWrapper( |         return SpaceToolWrapper( | ||||||
|             space_id=space_id, name=name, description=description, api_name=api_name, token=token |             space_id=space_id, | ||||||
|  |             name=name, | ||||||
|  |             description=description, | ||||||
|  |             api_name=api_name, | ||||||
|  |             token=token, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|  | @ -859,7 +854,7 @@ def load_tool( | ||||||
|     model_repo_id: Optional[str] = None, |     model_repo_id: Optional[str] = None, | ||||||
|     token: Optional[str] = None, |     token: Optional[str] = None, | ||||||
|     trust_remote_code: bool = False, |     trust_remote_code: bool = False, | ||||||
|         **kwargs |     **kwargs, | ||||||
| ): | ): | ||||||
|     """ |     """ | ||||||
|     Main function to quickly load a tool, be it on the Hub or in the Transformers library. |     Main function to quickly load a tool, be it on the Hub or in the Transformers library. | ||||||
|  | @ -909,7 +904,11 @@ def load_tool( | ||||||
|             f"code that you have checked." |             f"code that you have checked." | ||||||
|         ) |         ) | ||||||
|         return Tool.from_hub( |         return Tool.from_hub( | ||||||
|             task_or_repo_id, model_repo_id=model_repo_id, token=token, trust_remote_code=trust_remote_code, **kwargs |             task_or_repo_id, | ||||||
|  |             model_repo_id=model_repo_id, | ||||||
|  |             token=token, | ||||||
|  |             trust_remote_code=trust_remote_code, | ||||||
|  |             **kwargs, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -1028,7 +1027,7 @@ def tool(tool_function: Callable) -> Tool: | ||||||
|         raise TypeHintParsingException( |         raise TypeHintParsingException( | ||||||
|             "Tool return type not found: make sure your function has a return type hint!" |             "Tool return type not found: make sure your function has a return type hint!" | ||||||
|         ) |         ) | ||||||
|     class_name = ''.join([el.title() for el in parameters['name'].split('_')]) |     class_name = "".join([el.title() for el in parameters["name"].split("_")]) | ||||||
| 
 | 
 | ||||||
|     if parameters["return"]["type"] == "object": |     if parameters["return"]["type"] == "object": | ||||||
|         parameters["return"]["type"] = "any" |         parameters["return"]["type"] = "any" | ||||||
|  | @ -1086,7 +1085,9 @@ class Toolbox: | ||||||
|         """Get all tools currently in the toolbox""" |         """Get all tools currently in the toolbox""" | ||||||
|         return self._tools |         return self._tools | ||||||
| 
 | 
 | ||||||
|     def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str: |     def show_tool_descriptions( | ||||||
|  |         self, tool_description_template: Optional[str] = None | ||||||
|  |     ) -> str: | ||||||
|         """ |         """ | ||||||
|         Returns the description of all tools in the toolbox |         Returns the description of all tools in the toolbox | ||||||
| 
 | 
 | ||||||
|  | @ -1151,4 +1152,12 @@ class Toolbox: | ||||||
|             toolbox_description += f"\t{tool.name}: {tool.description}\n" |             toolbox_description += f"\t{tool.name}: {tool.description}\n" | ||||||
|         return toolbox_description |         return toolbox_description | ||||||
| 
 | 
 | ||||||
| __all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"] | 
 | ||||||
|  | __all__ = [ | ||||||
|  |     "AUTHORIZED_TYPES", | ||||||
|  |     "Tool", | ||||||
|  |     "tool", | ||||||
|  |     "load_tool", | ||||||
|  |     "launch_gradio_demo", | ||||||
|  |     "Toolbox", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | @ -267,4 +267,5 @@ def handle_agent_outputs(output, output_type=None): | ||||||
|                 return _v(output) |                 return _v(output) | ||||||
|         return output |         return output | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"] | __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"] | ||||||
|  | @ -18,6 +18,7 @@ import json | ||||||
| import re | import re | ||||||
| from typing import Tuple, Dict, Union | from typing import Tuple, Dict, Union | ||||||
| import ast | import ast | ||||||
|  | from rich.console import Console | ||||||
| 
 | 
 | ||||||
| from transformers.utils.import_utils import _is_package_available | from transformers.utils.import_utils import _is_package_available | ||||||
| 
 | 
 | ||||||
|  | @ -28,8 +29,6 @@ def is_pygments_available(): | ||||||
|     return _pygments_available |     return _pygments_available | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| from rich.console import Console |  | ||||||
| 
 |  | ||||||
| console = Console() | console = Console() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -111,6 +110,7 @@ def truncate_content( | ||||||
|             + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] |             + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class ImportFinder(ast.NodeVisitor): | class ImportFinder(ast.NodeVisitor): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.packages = set() |         self.packages = set() | ||||||
|  | @ -118,13 +118,14 @@ class ImportFinder(ast.NodeVisitor): | ||||||
|     def visit_Import(self, node): |     def visit_Import(self, node): | ||||||
|         for alias in node.names: |         for alias in node.names: | ||||||
|             # Get the base package name (before any dots) |             # Get the base package name (before any dots) | ||||||
|             base_package = alias.name.split('.')[0] |             base_package = alias.name.split(".")[0] | ||||||
|             self.packages.add(base_package) |             self.packages.add(base_package) | ||||||
| 
 | 
 | ||||||
|     def visit_ImportFrom(self, node): |     def visit_ImportFrom(self, node): | ||||||
|         if node.module:  # for "from x import y" statements |         if node.module:  # for "from x import y" statements | ||||||
|             # Get the base package name (before any dots) |             # Get the base package name (before any dots) | ||||||
|             base_package = node.module.split('.')[0] |             base_package = node.module.split(".")[0] | ||||||
|             self.packages.add(base_package) |             self.packages.add(base_package) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| __all__ = [] | __all__ = [] | ||||||
|  | @ -27,12 +27,13 @@ from agents.agents import ( | ||||||
|     CodeAgent, |     CodeAgent, | ||||||
|     JsonAgent, |     JsonAgent, | ||||||
|     Toolbox, |     Toolbox, | ||||||
|     ToolCall |     ToolCall, | ||||||
| ) | ) | ||||||
| from agents.tools import tool | from agents.tools import tool | ||||||
| from agents.default_tools import PythonInterpreterTool | from agents.default_tools import PythonInterpreterTool | ||||||
| from transformers.testing_utils import get_tests_dir | from transformers.testing_utils import get_tests_dir | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def get_new_path(suffix="") -> str: | def get_new_path(suffix="") -> str: | ||||||
|     directory = tempfile.mkdtemp() |     directory = tempfile.mkdtemp() | ||||||
|     return os.path.join(directory, str(uuid.uuid4()) + suffix) |     return os.path.join(directory, str(uuid.uuid4()) + suffix) | ||||||
|  | @ -60,6 +61,7 @@ Action: | ||||||
| } | } | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str: | def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
| 
 | 
 | ||||||
|  | @ -82,6 +84,7 @@ Action: | ||||||
| } | } | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str: | def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|  | @ -179,9 +182,7 @@ class AgentTests(unittest.TestCase): | ||||||
|         assert output == "7.2904" |         assert output == "7.2904" | ||||||
| 
 | 
 | ||||||
|     def test_fake_json_agent(self): |     def test_fake_json_agent(self): | ||||||
|         agent = JsonAgent( |         agent = JsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_json_llm) | ||||||
|             tools=[PythonInterpreterTool()], llm_engine=fake_json_llm |  | ||||||
|         ) |  | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, str) |         assert isinstance(output, str) | ||||||
|         assert output == "7.2904" |         assert output == "7.2904" | ||||||
|  | @ -209,9 +210,7 @@ Action: | ||||||
|             Args: |             Args: | ||||||
|                 prompt: The prompt |                 prompt: The prompt | ||||||
|             """ |             """ | ||||||
|             return Image.open( |             return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png") | ||||||
|                 Path(get_tests_dir("fixtures")) / "000000039769.png" |  | ||||||
|             ) |  | ||||||
| 
 | 
 | ||||||
|         agent = JsonAgent( |         agent = JsonAgent( | ||||||
|             tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image |             tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image | ||||||
|  | @ -221,9 +220,7 @@ Action: | ||||||
|         assert isinstance(agent.state["image.png"], Image.Image) |         assert isinstance(agent.state["image.png"], Image.Image) | ||||||
| 
 | 
 | ||||||
|     def test_fake_code_agent(self): |     def test_fake_code_agent(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm) | ||||||
|             tools=[PythonInterpreterTool()], llm_engine=fake_code_llm |  | ||||||
|         ) |  | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, float) |         assert isinstance(output, float) | ||||||
|         assert output == 7.2904 |         assert output == 7.2904 | ||||||
|  | @ -234,9 +231,7 @@ Action: | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def test_reset_conversations(self): |     def test_reset_conversations(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm) | ||||||
|             tools=[PythonInterpreterTool()], llm_engine=fake_code_llm |  | ||||||
|         ) |  | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?", reset=True) |         output = agent.run("What is 2 multiplied by 3.6452?", reset=True) | ||||||
|         assert output == 7.2904 |         assert output == 7.2904 | ||||||
|         assert len(agent.logs) == 4 |         assert len(agent.logs) == 4 | ||||||
|  | @ -299,9 +294,7 @@ Action: | ||||||
| 
 | 
 | ||||||
|         # check that python_interpreter base tool does not get added to code agents |         # check that python_interpreter base tool does not get added to code agents | ||||||
|         agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True) |         agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True) | ||||||
|         assert ( |         assert len(agent.toolbox.tools) == 2  # added final_answer tool + search | ||||||
|             len(agent.toolbox.tools) == 2 |  | ||||||
|         )  # added final_answer tool + search |  | ||||||
| 
 | 
 | ||||||
|     def test_function_persistence_across_steps(self): |     def test_function_persistence_across_steps(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|  |  | ||||||
|  | @ -25,6 +25,7 @@ from pathlib import Path | ||||||
| from typing import List | from typing import List | ||||||
| from dotenv import load_dotenv | from dotenv import load_dotenv | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class SubprocessCallException(Exception): | class SubprocessCallException(Exception): | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
|  | @ -59,7 +60,7 @@ class DocCodeExtractor: | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def extract_python_code(content: str) -> List[str]: |     def extract_python_code(content: str) -> List[str]: | ||||||
|         """Extract Python code blocks from markdown content.""" |         """Extract Python code blocks from markdown content.""" | ||||||
|         pattern = r'```(?:python|py)\n(.*?)\n```' |         pattern = r"```(?:python|py)\n(.*?)\n```" | ||||||
|         matches = re.finditer(pattern, content, re.DOTALL) |         matches = re.finditer(pattern, content, re.DOTALL) | ||||||
|         return [match.group(1).strip() for match in matches] |         return [match.group(1).strip() for match in matches] | ||||||
| 
 | 
 | ||||||
|  | @ -118,18 +119,27 @@ class TestDocs: | ||||||
| 
 | 
 | ||||||
|         # Create and execute test script |         # Create and execute test script | ||||||
|         try: |         try: | ||||||
|             excluded_snippets = ["ToolCollection", "image_generation_tool", "from_langchain"] |             excluded_snippets = [ | ||||||
|  |                 "ToolCollection", | ||||||
|  |                 "image_generation_tool", | ||||||
|  |                 "from_langchain", | ||||||
|  |             ] | ||||||
|             code_blocks = [ |             code_blocks = [ | ||||||
|                 block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token) for block in code_blocks |                 block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token) | ||||||
|                 if not any([snippet in block for snippet in excluded_snippets]) # Exclude these tools that take longer to run and add dependencies |                 for block in code_blocks | ||||||
|  |                 if not any( | ||||||
|  |                     [snippet in block for snippet in excluded_snippets] | ||||||
|  |                 )  # Exclude these tools that take longer to run and add dependencies | ||||||
|             ] |             ] | ||||||
|             test_script = self.extractor.create_test_script(code_blocks, self._tmpdir) |             test_script = self.extractor.create_test_script(code_blocks, self._tmpdir) | ||||||
|             run_command(self.launch_args + [str(test_script)]) |             run_command(self.launch_args + [str(test_script)]) | ||||||
| 
 | 
 | ||||||
|         except SubprocessCallException as e: |         except SubprocessCallException as e: | ||||||
|             pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}") |             pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}") | ||||||
|         except Exception as e: |         except Exception: | ||||||
|             pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}") |             pytest.fail( | ||||||
|  |                 f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}" | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     @pytest.fixture(autouse=True) |     @pytest.fixture(autouse=True) | ||||||
|     def _setup(self): |     def _setup(self): | ||||||
|  | @ -152,7 +162,5 @@ def pytest_generate_tests(metafunc): | ||||||
| 
 | 
 | ||||||
|         # Parameterize with the markdown files |         # Parameterize with the markdown files | ||||||
|         metafunc.parametrize( |         metafunc.parametrize( | ||||||
|             "doc_path", |             "doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files] | ||||||
|             test_class.md_files, |  | ||||||
|             ids=[f.stem for f in test_class.md_files] |  | ||||||
|         ) |         ) | ||||||
|  | @ -20,9 +20,10 @@ import numpy as np | ||||||
| from PIL import Image | from PIL import Image | ||||||
| 
 | 
 | ||||||
| from transformers import is_torch_available | from transformers import is_torch_available | ||||||
| from agents.types import AGENT_TYPE_MAPPING |  | ||||||
| from agents.default_tools import FinalAnswerTool |  | ||||||
| from transformers.testing_utils import get_tests_dir, require_torch | from transformers.testing_utils import get_tests_dir, require_torch | ||||||
|  | from agents.types import AGENT_TYPE_MAPPING | ||||||
|  | 
 | ||||||
|  | from agents.default_tools import FinalAnswerTool | ||||||
| 
 | 
 | ||||||
| from .test_tools_common import ToolTesterMixin | from .test_tools_common import ToolTesterMixin | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -98,6 +98,7 @@ class ToolTesterMixin: | ||||||
|             agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] |             agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||||
|             self.assertTrue(isinstance(output, agent_type)) |             self.assertTrue(isinstance(output, agent_type)) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class ToolTests(unittest.TestCase): | class ToolTests(unittest.TestCase): | ||||||
|     def test_tool_init_with_decorator(self): |     def test_tool_init_with_decorator(self): | ||||||
|         @tool |         @tool | ||||||
|  | @ -163,19 +164,22 @@ class ToolTests(unittest.TestCase): | ||||||
|             assert coolfunc.output_type == "number" |             assert coolfunc.output_type == "number" | ||||||
|         assert "docstring has no description for the argument" in str(e) |         assert "docstring has no description for the argument" in str(e) | ||||||
| 
 | 
 | ||||||
|     def test_tool_definition_needs_imports_in_function(self): |     def test_tool_definition_raises_error_imports_outside_function(self): | ||||||
|         with pytest.raises(Exception) as e: |         with pytest.raises(Exception) as e: | ||||||
|             from datetime import datetime |             from datetime import datetime | ||||||
|  | 
 | ||||||
|             @tool |             @tool | ||||||
|             def get_current_time() -> str: |             def get_current_time() -> str: | ||||||
|                 """ |                 """ | ||||||
|                 Gets the current time. |                 Gets the current time. | ||||||
|                 """ |                 """ | ||||||
|                 return str(datetime.now()) |                 return str(datetime.now()) | ||||||
|  | 
 | ||||||
|         assert "datetime" in str(e) |         assert "datetime" in str(e) | ||||||
| 
 | 
 | ||||||
|         # Also test with classic definition |         # Also test with classic definition | ||||||
|         with pytest.raises(Exception) as e: |         with pytest.raises(Exception) as e: | ||||||
|  | 
 | ||||||
|             class GetCurrentTimeTool(Tool): |             class GetCurrentTimeTool(Tool): | ||||||
|                 name = "get_current_time_tool" |                 name = "get_current_time_tool" | ||||||
|                 description = "Gets the current time" |                 description = "Gets the current time" | ||||||
|  | @ -184,14 +188,17 @@ class ToolTests(unittest.TestCase): | ||||||
| 
 | 
 | ||||||
|                 def forward(self): |                 def forward(self): | ||||||
|                     return str(datetime.now()) |                     return str(datetime.now()) | ||||||
|  | 
 | ||||||
|         assert "datetime" in str(e) |         assert "datetime" in str(e) | ||||||
| 
 | 
 | ||||||
|  |     def test_tool_definition_raises_no_error_imports_in_function(self): | ||||||
|         @tool |         @tool | ||||||
|         def get_current_time() -> str: |         def get_current_time() -> str: | ||||||
|             """ |             """ | ||||||
|             Gets the current time. |             Gets the current time. | ||||||
|             """ |             """ | ||||||
|             from datetime import datetime |             from datetime import datetime | ||||||
|  | 
 | ||||||
|             return str(datetime.now()) |             return str(datetime.now()) | ||||||
| 
 | 
 | ||||||
|         class GetCurrentTimeTool(Tool): |         class GetCurrentTimeTool(Tool): | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import tempfile | ||||||
| 
 | 
 | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def str_to_bool(value) -> int: | def str_to_bool(value) -> int: | ||||||
|     """ |     """ | ||||||
|     Converts a string representation of truth to `True` (1) or `False` (0). |     Converts a string representation of truth to `True` (1) or `False` (0). | ||||||
|  | @ -28,10 +29,13 @@ def get_int_from_env(env_keys, default): | ||||||
|             return val |             return val | ||||||
|     return default |     return default | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def parse_flag_from_env(key, default=False): | def parse_flag_from_env(key, default=False): | ||||||
|     """Returns truthy value for `key` from the env if available else the default.""" |     """Returns truthy value for `key` from the env if available else the default.""" | ||||||
|     value = os.environ.get(key, str(default)) |     value = os.environ.get(key, str(default)) | ||||||
|     return str_to_bool(value) == 1  # As its name indicates `str_to_bool` actually returns an int... |     return ( | ||||||
|  |         str_to_bool(value) == 1 | ||||||
|  |     )  # As its name indicates `str_to_bool` actually returns an int... | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) | _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue