Start fixing state transfer between local and docker executor
This commit is contained in:
		
							parent
							
								
									8e758fa130
								
							
						
					
					
						commit
						8ed03634b0
					
				|  | @ -47,7 +47,7 @@ def visit_webpage(url: str) -> str: | ||||||
| 
 | 
 | ||||||
| llm_engine = HfApiEngine(model) | llm_engine = HfApiEngine(model) | ||||||
| 
 | 
 | ||||||
| web_agent = JsonAgent( | web_agent = CodeAgent( | ||||||
|     tools=[DuckDuckGoSearchTool(), visit_webpage], |     tools=[DuckDuckGoSearchTool(), visit_webpage], | ||||||
|     llm_engine=llm_engine, |     llm_engine=llm_engine, | ||||||
|     max_iterations=10, |     max_iterations=10, | ||||||
|  |  | ||||||
|  | @ -38,7 +38,7 @@ from .prompts import ( | ||||||
|     SYSTEM_PROMPT_PLAN_UPDATE, |     SYSTEM_PROMPT_PLAN_UPDATE, | ||||||
|     SYSTEM_PROMPT_PLAN, |     SYSTEM_PROMPT_PLAN, | ||||||
| ) | ) | ||||||
| from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code | from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code | ||||||
| from .tools import ( | from .tools import ( | ||||||
|     DEFAULT_TOOL_DESCRIPTION_TEMPLATE, |     DEFAULT_TOOL_DESCRIPTION_TEMPLATE, | ||||||
|     Tool, |     Tool, | ||||||
|  |  | ||||||
|  | @ -23,7 +23,7 @@ from typing import Dict | ||||||
| from huggingface_hub import hf_hub_download, list_spaces | from huggingface_hub import hf_hub_download, list_spaces | ||||||
| 
 | 
 | ||||||
| from transformers.utils import is_offline_mode | from transformers.utils import is_offline_mode | ||||||
| from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code | from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code | ||||||
| from .tools import TOOL_CONFIG_FILE, Tool | from .tools import TOOL_CONFIG_FILE, Tool | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,13 +1,80 @@ | ||||||
| import sys |  | ||||||
| import json | import json | ||||||
| import traceback |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import docker | import docker | ||||||
| import time | import time | ||||||
| import uuid | import uuid | ||||||
| import signal | import pickle | ||||||
| from typing import Optional, Dict, Tuple, Any | import re | ||||||
| import subprocess | from typing import Optional, Dict, Tuple, Set, Any | ||||||
|  | import types | ||||||
|  | from .default_tools import BASE_PYTHON_TOOLS | ||||||
|  | 
 | ||||||
class StateManager:
    """Persist interpreter variables and import statements between executors.

    State is written as a pickle file (``interpreter_state.pickle``) inside
    ``work_dir`` so that one executor ('docker' or 'local') can save its
    variables and the other can restore them.

    NOTE(security): the state file is read back with ``pickle.load`` and the
    recorded import statements are run with ``exec``. Both execute arbitrary
    code, so ``work_dir`` must only ever contain trusted content.
    """

    def __init__(self, work_dir: Path):
        """
        Args:
            work_dir: Directory in which the state file is stored.
        """
        self.work_dir = work_dir
        self.state_file = work_dir / "interpreter_state.pickle"
        # Path reserved for an imports side-file; nothing in this class writes it.
        self.imports_file = work_dir / "imports.txt"
        self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$')
        self.imports: Set[str] = set()

    def is_import_statement(self, code: str) -> bool:
        """Check if a single line of code is an import statement."""
        return bool(self.import_pattern.match(code.strip()))

    def track_imports(self, code: str):
        """Record every line of ``code`` that looks like an import statement."""
        for line in code.splitlines():
            stripped = line.strip()
            if self.is_import_statement(stripped):
                self.imports.add(stripped)

    @staticmethod
    def _is_transferable(name: str, value: Any) -> bool:
        """Return True if ``value`` should (and can) be stored in the state file."""
        if (
            name.startswith('_')
            or callable(value)
            or isinstance(value, type)
            or isinstance(value, types.ModuleType)
        ):
            return False
        try:
            pickle.dumps(value)
        except Exception:
            # pickle raises many different exception types (PicklingError,
            # TypeError, AttributeError, ...) for objects it cannot handle.
            # One such object must not prevent the rest from being saved.
            return False
        return True

    def save_state(self, locals_dict: Dict[str, Any], executor: str):
        """
        Save the current state of variables and imports.

        Private names, callables, classes, modules and values that cannot be
        pickled are left out.

        Args:
            locals_dict: Dictionary of local variables
            executor: 'docker' or 'local' to indicate source
        """
        state_dict = {
            'variables': {
                k: v for k, v in locals_dict.items()
                if self._is_transferable(k, v)
            },
            'imports': list(self.imports),
            'source': executor
        }

        # The working directory may not have been created yet.
        self.work_dir.mkdir(parents=True, exist_ok=True)
        with open(self.state_file, 'wb') as f:
            pickle.dump(state_dict, f)

    def load_state(self, executor: str) -> Dict[str, Any]:
        """
        Load the saved state and handle imports.

        Recorded import statements are executed in this module's globals.
        Statements that cannot be executed here (module not installed, or
        only the first line of a multi-line import was recorded) are skipped.

        Args:
            executor: 'docker' or 'local' to indicate destination
                (currently informational only)

        Returns:
            Dictionary of variables to restore; empty if there is no usable
            state file.
        """
        if not self.state_file.exists():
            return {}

        try:
            with open(self.state_file, 'rb') as f:
                state_dict = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError, ImportError) as e:
            print(f"Error loading saved state: {e}")
            return {}
        if not isinstance(state_dict, dict):
            return {}

        # First handle imports
        for import_stmt in state_dict.get('imports', []):
            try:
                exec(import_stmt, globals())
            except (ImportError, SyntaxError) as e:
                print(f"Skipping import {import_stmt!r}: {e}")

        return state_dict.get('variables', {})
|  | 
 | ||||||
| 
 | 
 | ||||||
| def read_multiplexed_response(socket): | def read_multiplexed_response(socket): | ||||||
|     """Read and demultiplex all responses from Docker exec socket""" |     """Read and demultiplex all responses from Docker exec socket""" | ||||||
|  | @ -20,6 +87,7 @@ def read_multiplexed_response(socket): | ||||||
|         responses = response_data.split(b'\x01\x00\x00\x00\x00\x00') |         responses = response_data.split(b'\x01\x00\x00\x00\x00\x00') | ||||||
| 
 | 
 | ||||||
|         # The last non-empty chunk should be our JSON response |         # The last non-empty chunk should be our JSON response | ||||||
|  |         if len(responses) > 0: | ||||||
|             for chunk in reversed(responses): |             for chunk in reversed(responses): | ||||||
|                 if chunk and len(chunk.strip()) > 0: |                 if chunk and len(chunk.strip()) > 0: | ||||||
|                     try: |                     try: | ||||||
|  | @ -35,7 +103,7 @@ def read_multiplexed_response(socket): | ||||||
|         i+=1 |         i+=1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DockerInterpreter: | class DockerPythonInterpreter: | ||||||
|     def __init__(self, work_dir: Path = Path(".")): |     def __init__(self, work_dir: Path = Path(".")): | ||||||
|         self.client = docker.from_env() |         self.client = docker.from_env() | ||||||
|         self.work_dir = work_dir |         self.work_dir = work_dir | ||||||
|  | @ -43,6 +111,8 @@ class DockerInterpreter: | ||||||
|         self.container = None |         self.container = None | ||||||
|         self.exec_id = None |         self.exec_id = None | ||||||
|         self.socket = None |         self.socket = None | ||||||
|  |         self.state_manager = StateManager(work_dir) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
|     def create_interpreter_script(self) -> str: |     def create_interpreter_script(self) -> str: | ||||||
|         """Create the interpreter script that will run inside the container""" |         """Create the interpreter script that will run inside the container""" | ||||||
|  | @ -52,7 +122,9 @@ import code | ||||||
| import json | import json | ||||||
| import traceback | import traceback | ||||||
| import signal | import signal | ||||||
|  | import types | ||||||
| from threading import Lock | from threading import Lock | ||||||
|  | import pickle | ||||||
| 
 | 
 | ||||||
| class PersistentInterpreter(code.InteractiveInterpreter): | class PersistentInterpreter(code.InteractiveInterpreter): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|  | @ -67,6 +139,12 @@ class PersistentInterpreter(code.InteractiveInterpreter): | ||||||
|     def run_command(self, source): |     def run_command(self, source): | ||||||
|         with self.lock: |         with self.lock: | ||||||
|             self.output_buffer = [] |             self.output_buffer = [] | ||||||
|  |             pickle_path = self.work_dir / "locals.pickle" | ||||||
|  |             if pickle_path.exists(): | ||||||
|  |                 with open(pickle_path, 'rb') as f: | ||||||
|  |                     locals_dict_update = pickle.load(f)['variables'] | ||||||
|  |             self.locals_dict.update(locals_dict_update) | ||||||
|  | 
 | ||||||
|             try: |             try: | ||||||
|                 more = self.runsource(source) |                 more = self.runsource(source) | ||||||
|                 output = ''.join(self.output_buffer) |                 output = ''.join(self.output_buffer) | ||||||
|  | @ -78,11 +156,25 @@ class PersistentInterpreter(code.InteractiveInterpreter): | ||||||
|                             output = repr(result) + '\\n' |                             output = repr(result) + '\\n' | ||||||
|                     except: |                     except: | ||||||
|                         pass |                         pass | ||||||
|                 return json.dumps({'output': output, 'more': more, 'error': None}) + '\\n' |                 output = json.dumps({'output': output, 'more': more, 'error': None}) + '\\n' | ||||||
|             except KeyboardInterrupt: |             except KeyboardInterrupt: | ||||||
|                 return json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n' |                 output = json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n' | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|                 return json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n' |                 output = json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n' | ||||||
|  |             finally: | ||||||
|  |                 with open('/workspace/locals.pickle', 'wb') as f: | ||||||
|  |                     filtered_locals = { | ||||||
|  |                         k: v for k, v in self.locals_dict.items()  | ||||||
|  |                         if not ( | ||||||
|  |                             k.startswith('_') | ||||||
|  |                             or k in {'pickle', 'f'} | ||||||
|  |                             or callable(v) | ||||||
|  |                             or isinstance(v, type) | ||||||
|  |                             or isinstance(v, types.ModuleType) | ||||||
|  |                         ) | ||||||
|  |                     } | ||||||
|  |                     pickle.dump(filtered_locals, f) | ||||||
|  |             return output | ||||||
| 
 | 
 | ||||||
| def main(): | def main(): | ||||||
|     interpreter = PersistentInterpreter() |     interpreter = PersistentInterpreter() | ||||||
|  | @ -135,6 +227,8 @@ if __name__ == '__main__': | ||||||
|         if container_name is None: |         if container_name is None: | ||||||
|             container_name = f"python-interpreter-{uuid.uuid4().hex[:8]}" |             container_name = f"python-interpreter-{uuid.uuid4().hex[:8]}" | ||||||
| 
 | 
 | ||||||
|  |         self.create_interpreter_script() | ||||||
|  |          | ||||||
|         # Setup volume mapping |         # Setup volume mapping | ||||||
|         volumes = { |         volumes = { | ||||||
|             str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"} |             str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"} | ||||||
|  | @ -160,7 +254,7 @@ if __name__ == '__main__': | ||||||
|             ) |             ) | ||||||
|             # Install packages in the new container |             # Install packages in the new container | ||||||
|             print("Installing packages...") |             print("Installing packages...") | ||||||
|             packages = ["pandas", "numpy"]  # Add your required packages here |             packages = ["pandas", "numpy", "pickle5"]  # Add your required packages here | ||||||
| 
 | 
 | ||||||
|             result = self.container.exec_run( |             result = self.container.exec_run( | ||||||
|                 f"pip install {' '.join(packages)}", |                 f"pip install {' '.join(packages)}", | ||||||
|  | @ -192,7 +286,11 @@ if __name__ == '__main__': | ||||||
|             demux=True |             demux=True | ||||||
|         )._sock |         )._sock | ||||||
| 
 | 
 | ||||||
|     def execute(self, code: str) -> Tuple[str, bool]: |     def _raw_execute(self, code: str) -> Tuple[str, bool]: | ||||||
|  |         """ | ||||||
|  |         Execute code directly without state management. | ||||||
|  |         This is the original execute method functionality. | ||||||
|  |         """ | ||||||
|         if not self.container: |         if not self.container: | ||||||
|             raise Exception("Container not started") |             raise Exception("Container not started") | ||||||
|         if not self.socket: |         if not self.socket: | ||||||
|  | @ -209,6 +307,30 @@ if __name__ == '__main__': | ||||||
|         except json.JSONDecodeError: |         except json.JSONDecodeError: | ||||||
|             return f"Error: Invalid response from interpreter: {response}", False |             return f"Error: Invalid response from interpreter: {response}", False | ||||||
| 
 | 
 | ||||||
|  |     def get_locals_dict(self) -> Dict[str, Any]: | ||||||
|  |         """Get the current locals dictionary from the interpreter by pickling directly from Docker.""" | ||||||
|  |         pickle_path = self.work_dir / "locals.pickle" | ||||||
|  |         if pickle_path.exists(): | ||||||
|  |             with open(pickle_path, 'rb') as f: | ||||||
|  |                 try: | ||||||
|  |                     return pickle.load(f) | ||||||
|  |                 except Exception as e: | ||||||
|  |                     print(f"Error loading pickled locals: {e}") | ||||||
|  |                     return {} | ||||||
|  |         return {} | ||||||
|  | 
 | ||||||
|  |     def execute(self, code: str) -> Tuple[str, bool]: | ||||||
|  |         # Track imports before execution | ||||||
|  |         self.state_manager.track_imports(code) | ||||||
|  |          | ||||||
|  |         output, more = self._raw_execute(code) | ||||||
|  |          | ||||||
|  |         # Save state after execution | ||||||
|  |         self.state_manager.save_state( | ||||||
|  |             self.get_locals_dict(), | ||||||
|  |             'docker' | ||||||
|  |         ) | ||||||
|  |         return output, more | ||||||
| 
 | 
 | ||||||
|     def stop(self, remove: bool = False): |     def stop(self, remove: bool = False): | ||||||
|         if self.socket: |         if self.socket: | ||||||
|  | @ -227,33 +349,71 @@ if __name__ == '__main__': | ||||||
|                 print(f"Error stopping container: {e}") |                 print(f"Error stopping container: {e}") | ||||||
|                 raise |                 raise | ||||||
| 
 | 
 | ||||||
| def main(): | def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: | ||||||
|     work_dir = Path("interpreter_workspace") |     from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES | ||||||
|     interpreter = DockerInterpreter(work_dir) |  | ||||||
| 
 | 
 | ||||||
|     def signal_handler(signum, frame): |     """Execute code locally with state transfer.""" | ||||||
|         print("\nExiting...") |     state_manager = StateManager(work_dir) | ||||||
|         interpreter.stop(remove=True) |  | ||||||
|         sys.exit(0) |  | ||||||
|      |      | ||||||
|     signal.signal(signal.SIGINT, signal_handler) |     # Track imports | ||||||
|  |     state_manager.track_imports(code) | ||||||
|      |      | ||||||
|     print("Starting Python interpreter in Docker...") |     # Load state from Docker if available | ||||||
|     interpreter.start("persistent_python_interpreter2") |     locals_dict = state_manager.load_state('local') | ||||||
| 
 | 
 | ||||||
|     snippet = "import pandas as pd" |     # Execute in a new namespace with loaded state | ||||||
|     output, more = interpreter.execute(snippet) |     namespace = {} | ||||||
|     print("OUTPUT1") |     namespace.update(locals_dict) | ||||||
|  |      | ||||||
|  |     output = evaluate_python_code( | ||||||
|  |         code, | ||||||
|  |         tools, | ||||||
|  |         {}, | ||||||
|  |         namespace, | ||||||
|  |         LIST_SAFE_MODULES, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # Save state for Docker | ||||||
|  |     state_manager.save_state(namespace, 'local') | ||||||
|  |     return output | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def create_tools_regex(tool_names):
    """Compile a regex that matches any of ``tool_names`` as a whole word.

    Args:
        tool_names: Iterable of tool name strings.

    Returns:
        A compiled pattern whose group 1 is the matched tool name. If no
        non-empty names are given, the pattern never matches.
    """
    # Escape any special regex characters in tool names; drop empty names,
    # which would otherwise produce an alternative that matches everywhere.
    escaped_names = [re.escape(name) for name in tool_names if name]
    if not escaped_names:
        # An empty alternation (r'\b()\b') matches at every word boundary;
        # use a pattern that can never match instead.
        return re.compile(r'(?!)')
    # Join with | and add word boundaries
    pattern = r'\b(' + '|'.join(escaped_names) + r')\b'
    return re.compile(pattern)
|  | 
 | ||||||
|  | def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter): | ||||||
|  |     """Execute code with automatic switching between Docker and local.""" | ||||||
|  |     lines = code.split('\n') | ||||||
|  |     current_block = [] | ||||||
|  |     tool_regex = create_tools_regex(list(tools.keys()) + ["print"]) # Added print for testing | ||||||
|  | 
 | ||||||
|  |     tools = { | ||||||
|  |         **BASE_PYTHON_TOOLS.copy(), | ||||||
|  |         **tools, | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     for line in lines: | ||||||
|  |         if tool_regex.search(line): | ||||||
|  |             # Execute accumulated Docker code if any | ||||||
|  |             if current_block: | ||||||
|  |                 output, more = interpreter.execute('\n'.join(current_block)) | ||||||
|                 print(output, end='') |                 print(output, end='') | ||||||
|  |                 current_block = [] | ||||||
|              |              | ||||||
|     snippet = "pd.DataFrame()" |             output = execute_locally(line, work_dir, tools) | ||||||
|     output, more = interpreter.execute(snippet) |             if output: | ||||||
|     print("OUTPUT2") |                 print(output, end='') | ||||||
|  |         else: | ||||||
|  |             current_block.append(line) | ||||||
|  |      | ||||||
|  |     # Execute any remaining Docker code | ||||||
|  |     if current_block: | ||||||
|  |         output, more = interpreter.execute('\n'.join(current_block)) | ||||||
|         print(output, end='') |         print(output, end='') | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     print("\nStopping interpreter...") | __all__ = ["DockerPythonInterpreter", "execute_code"] | ||||||
|     interpreter.stop(remove=True) |  | ||||||
| 
 |  | ||||||
| if __name__ == '__main__': |  | ||||||
|     main() |  | ||||||
		Loading…
	
		Reference in New Issue