Start fixing state transfer between local and docker executor

This commit is contained in:
Aymeric 2024-12-13 13:58:26 +01:00
parent 8e758fa130
commit 8ed03634b0
4 changed files with 210 additions and 50 deletions

View File

@ -47,7 +47,7 @@ def visit_webpage(url: str) -> str:
llm_engine = HfApiEngine(model) llm_engine = HfApiEngine(model)
web_agent = JsonAgent( web_agent = CodeAgent(
tools=[DuckDuckGoSearchTool(), visit_webpage], tools=[DuckDuckGoSearchTool(), visit_webpage],
llm_engine=llm_engine, llm_engine=llm_engine,
max_iterations=10, max_iterations=10,

View File

@ -38,7 +38,7 @@ from .prompts import (
SYSTEM_PROMPT_PLAN_UPDATE, SYSTEM_PROMPT_PLAN_UPDATE,
SYSTEM_PROMPT_PLAN, SYSTEM_PROMPT_PLAN,
) )
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import ( from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE, DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool, Tool,

View File

@ -23,7 +23,7 @@ from typing import Dict
from huggingface_hub import hf_hub_download, list_spaces from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode from transformers.utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, Tool from .tools import TOOL_CONFIG_FILE, Tool

View File

@ -1,13 +1,80 @@
import sys
import json import json
import traceback
from pathlib import Path from pathlib import Path
import docker import docker
import time import time
import uuid import uuid
import signal import pickle
from typing import Optional, Dict, Tuple, Any import re
import subprocess from typing import Optional, Dict, Tuple, Set, Any
import types
from .default_tools import BASE_PYTHON_TOOLS
class StateManager:
    """Persist interpreter variables and import statements in a work directory.

    The state is stored as a pickle file (``interpreter_state.pickle``) inside
    ``work_dir`` and holds three keys: ``'variables'``, ``'imports'`` and
    ``'source'``.
    """

    def __init__(self, work_dir: Path):
        """
        Args:
            work_dir: Directory in which the state file is read and written.
        """
        self.work_dir = work_dir
        self.state_file = work_dir / "interpreter_state.pickle"
        self.imports_file = work_dir / "imports.txt"
        # Cheap pre-filter only; is_import_statement() confirms with the parser.
        self.import_pattern = re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$')
        self.imports: Set[str] = set()

    def is_import_statement(self, code: str) -> bool:
        """Check if a line of code consists solely of import statements.

        The regex alone also accepts lines such as ``import os; x = 1``
        (anything may follow ``import``), which would later be replayed through
        ``exec``. The line is therefore also parsed and every statement on it
        must be an ``import`` / ``from ... import``.
        """
        import ast  # local import: only needed here

        candidate = code.strip()
        if not self.import_pattern.match(candidate):
            return False
        try:
            tree = ast.parse(candidate)
        except (SyntaxError, ValueError):
            # e.g. the first line of a parenthesised multi-line import
            return False
        return bool(tree.body) and all(
            isinstance(node, (ast.Import, ast.ImportFrom)) for node in tree.body
        )

    def track_imports(self, code: str):
        """Track import statements found in ``code`` for later use."""
        for line in code.split('\n'):
            stripped = line.strip()
            if self.is_import_statement(stripped):
                self.imports.add(stripped)

    @staticmethod
    def _is_picklable(value: Any) -> bool:
        """Return True when ``value`` can be serialised with pickle."""
        try:
            pickle.dumps(value)
        except Exception:
            # pickle raises PicklingError, TypeError, AttributeError, ...
            return False
        return True

    def save_state(self, locals_dict: Dict[str, Any], executor: str):
        """
        Save the current state of variables and imports.

        Values that cannot be pickled are skipped instead of aborting the
        save, and the file is replaced atomically so a failed write never
        leaves a truncated state file behind.

        Args:
            locals_dict: Dictionary of local variables
            executor: 'docker' or 'local' to indicate source
        """
        # Filter out modules, functions, classes and special variables
        variables = {
            name: value
            for name, value in locals_dict.items()
            if not (
                name.startswith('_')
                or callable(value)
                or isinstance(value, type)
                or isinstance(value, types.ModuleType)
            )
        }
        state_dict = {
            'variables': variables,
            'imports': sorted(self.imports),
            'source': executor,
        }
        try:
            payload = pickle.dumps(state_dict)
        except Exception:
            # At least one value is unpicklable: drop only the offenders.
            state_dict['variables'] = {
                name: value
                for name, value in variables.items()
                if self._is_picklable(value)
            }
            payload = pickle.dumps(state_dict)

        tmp_file = self.state_file.with_name(self.state_file.name + '.tmp')
        tmp_file.write_bytes(payload)
        tmp_file.replace(self.state_file)

    def load_state(self, executor: str) -> Dict[str, Any]:
        """
        Load the saved state and handle imports.

        Persisted import statements are re-validated, executed in this
        module's globals and merged into ``self.imports`` so that a following
        ``save_state`` does not drop them.

        Args:
            executor: 'docker' or 'local' to indicate destination
                (currently unused)

        Returns:
            Dictionary of variables to restore; empty when no state file
            exists or the file does not contain a state dictionary.
        """
        if not self.state_file.exists():
            return {}

        # NOTE(security): pickle.load can execute arbitrary code. Only use a
        # work_dir whose contents are trusted; a directory that is writable by
        # sandboxed code defeats the sandbox.
        with open(self.state_file, 'rb') as f:
            state_dict = pickle.load(f)
        if not isinstance(state_dict, dict):
            return {}

        # First handle imports
        for import_stmt in state_dict.get('imports', []):
            if not isinstance(import_stmt, str) or not self.is_import_statement(import_stmt):
                continue  # never exec anything that is not a plain import
            exec(import_stmt, globals())
            self.imports.add(import_stmt.strip())

        return state_dict.get('variables', {})
def read_multiplexed_response(socket): def read_multiplexed_response(socket):
"""Read and demultiplex all responses from Docker exec socket""" """Read and demultiplex all responses from Docker exec socket"""
@ -20,22 +87,23 @@ def read_multiplexed_response(socket):
responses = response_data.split(b'\x01\x00\x00\x00\x00\x00') responses = response_data.split(b'\x01\x00\x00\x00\x00\x00')
# The last non-empty chunk should be our JSON response # The last non-empty chunk should be our JSON response
for chunk in reversed(responses): if len(responses) > 0:
if chunk and len(chunk.strip()) > 0: for chunk in reversed(responses):
try: if chunk and len(chunk.strip()) > 0:
# Find the start of valid JSON by looking for '{' try:
json_start = chunk.find(b'{') # Find the start of valid JSON by looking for '{'
if json_start != -1: json_start = chunk.find(b'{')
decoded = chunk[json_start:].decode('utf-8') if json_start != -1:
result = json.loads(decoded) decoded = chunk[json_start:].decode('utf-8')
if "output" in result: result = json.loads(decoded)
return decoded if "output" in result:
except json.JSONDecodeError: return decoded
continue except json.JSONDecodeError:
continue
i+=1 i+=1
class DockerInterpreter: class DockerPythonInterpreter:
def __init__(self, work_dir: Path = Path(".")): def __init__(self, work_dir: Path = Path(".")):
self.client = docker.from_env() self.client = docker.from_env()
self.work_dir = work_dir self.work_dir = work_dir
@ -43,6 +111,8 @@ class DockerInterpreter:
self.container = None self.container = None
self.exec_id = None self.exec_id = None
self.socket = None self.socket = None
self.state_manager = StateManager(work_dir)
def create_interpreter_script(self) -> str: def create_interpreter_script(self) -> str:
"""Create the interpreter script that will run inside the container""" """Create the interpreter script that will run inside the container"""
@ -52,7 +122,9 @@ import code
import json import json
import traceback import traceback
import signal import signal
import types
from threading import Lock from threading import Lock
import pickle
class PersistentInterpreter(code.InteractiveInterpreter): class PersistentInterpreter(code.InteractiveInterpreter):
def __init__(self): def __init__(self):
@ -67,6 +139,12 @@ class PersistentInterpreter(code.InteractiveInterpreter):
def run_command(self, source): def run_command(self, source):
with self.lock: with self.lock:
self.output_buffer = [] self.output_buffer = []
pickle_path = self.work_dir / "locals.pickle"
if pickle_path.exists():
with open(pickle_path, 'rb') as f:
locals_dict_update = pickle.load(f)['variables']
self.locals_dict.update(locals_dict_update)
try: try:
more = self.runsource(source) more = self.runsource(source)
output = ''.join(self.output_buffer) output = ''.join(self.output_buffer)
@ -78,11 +156,25 @@ class PersistentInterpreter(code.InteractiveInterpreter):
output = repr(result) + '\\n' output = repr(result) + '\\n'
except: except:
pass pass
return json.dumps({'output': output, 'more': more, 'error': None}) + '\\n' output = json.dumps({'output': output, 'more': more, 'error': None}) + '\\n'
except KeyboardInterrupt: except KeyboardInterrupt:
return json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n' output = json.dumps({'output': '\\nKeyboardInterrupt\\n', 'more': False, 'error': 'interrupt'}) + '\\n'
except Exception as e: except Exception as e:
return json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n' output = json.dumps({'output': f"Error: {str(e)}\\n", 'more': False, 'error': str(e)}) + '\\n'
finally:
with open('/workspace/locals.pickle', 'wb') as f:
filtered_locals = {
k: v for k, v in self.locals_dict.items()
if not (
k.startswith('_')
or k in {'pickle', 'f'}
or callable(v)
or isinstance(v, type)
or isinstance(v, types.ModuleType)
)
}
pickle.dump(filtered_locals, f)
return output
def main(): def main():
interpreter = PersistentInterpreter() interpreter = PersistentInterpreter()
@ -135,6 +227,8 @@ if __name__ == '__main__':
if container_name is None: if container_name is None:
container_name = f"python-interpreter-{uuid.uuid4().hex[:8]}" container_name = f"python-interpreter-{uuid.uuid4().hex[:8]}"
self.create_interpreter_script()
# Setup volume mapping # Setup volume mapping
volumes = { volumes = {
str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"} str(self.work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}
@ -160,7 +254,7 @@ if __name__ == '__main__':
) )
# Install packages in the new container # Install packages in the new container
print("Installing packages...") print("Installing packages...")
packages = ["pandas", "numpy"] # Add your required packages here packages = ["pandas", "numpy", "pickle5"] # Add your required packages here
result = self.container.exec_run( result = self.container.exec_run(
f"pip install {' '.join(packages)}", f"pip install {' '.join(packages)}",
@ -192,8 +286,12 @@ if __name__ == '__main__':
demux=True demux=True
)._sock )._sock
def execute(self, code: str) -> Tuple[str, bool]: def _raw_execute(self, code: str) -> Tuple[str, bool]:
if not self.container : """
Execute code directly without state management.
This is the original execute method functionality.
"""
if not self.container:
raise Exception("Container not started") raise Exception("Container not started")
if not self.socket: if not self.socket:
raise Exception("Socket not started") raise Exception("Socket not started")
@ -209,6 +307,30 @@ if __name__ == '__main__':
except json.JSONDecodeError: except json.JSONDecodeError:
return f"Error: Invalid response from interpreter: {response}", False return f"Error: Invalid response from interpreter: {response}", False
def get_locals_dict(self) -> Dict[str, Any]:
    """Return the locals pickled into ``<work_dir>/locals.pickle``.

    Returns:
        The unpickled dictionary, or an empty dict when the file is missing,
        cannot be unpickled, or does not contain a dictionary.
    """
    pickle_path = self.work_dir / "locals.pickle"
    # NOTE(security): this file sits in the shared workspace; unpickling it
    # can execute arbitrary code, so the workspace contents must be trusted.
    try:
        # Open directly instead of checking exists() first: the file may
        # appear or vanish between the check and the open.
        with open(pickle_path, 'rb') as f:
            loaded = pickle.load(f)
    except FileNotFoundError:
        return {}
    except Exception as e:
        # Unpickling can fail in many ways (truncated file, missing class, ...)
        print(f"Error loading pickled locals: {e}")
        return {}
    if not isinstance(loaded, dict):
        print(f"Error loading pickled locals: expected dict, got {type(loaded).__name__}")
        return {}
    return loaded
def execute(self, code: str) -> Tuple[str, bool]:
    """Run ``code`` through ``_raw_execute`` and persist the resulting state.

    Import statements in ``code`` are recorded first; after the run, the
    dictionary returned by ``get_locals_dict`` is saved with source 'docker'.

    Returns:
        The ``(output, more)`` pair produced by ``_raw_execute``.
    """
    manager = self.state_manager

    # Record imports ahead of the run
    manager.track_imports(code)

    text, needs_more = self._raw_execute(code)

    # Snapshot the interpreter's variables once the run is over
    snapshot = self.get_locals_dict()
    manager.save_state(snapshot, 'docker')

    return text, needs_more
def stop(self, remove: bool = False): def stop(self, remove: bool = False):
if self.socket: if self.socket:
@ -227,33 +349,71 @@ if __name__ == '__main__':
print(f"Error stopping container: {e}") print(f"Error stopping container: {e}")
raise raise
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
    """Execute code locally with state transfer.

    Args:
        code: Source code to evaluate.
        work_dir: Directory holding the shared state file.
        tools: Mapping passed to ``evaluate_python_code`` as its second
            positional argument (presumably name -> tool callable; confirm
            against ``local_python_executor``).

    Returns:
        Whatever ``evaluate_python_code`` returns for ``code``.
    """
    # Kept function-local, as in the original. The docstring now precedes it
    # so that it is the function's real docstring, not a stray string.
    from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES

    state_manager = StateManager(work_dir)

    # Track imports
    state_manager.track_imports(code)

    # Load state from Docker if available and execute in a fresh namespace
    namespace = dict(state_manager.load_state('local'))

    output = evaluate_python_code(
        code,
        tools,
        {},
        namespace,
        LIST_SAFE_MODULES,
    )

    # Save state for Docker
    state_manager.save_state(namespace, 'local')
    return output
def create_tools_regex(tool_names):
    """Compile a regex matching any of ``tool_names`` as a whole word.

    Args:
        tool_names: Iterable of tool names; regex metacharacters are escaped.

    Returns:
        A compiled pattern. When ``tool_names`` is empty the pattern never
        matches (an empty alternation would match at every word boundary).
    """
    # Escape any special regex characters in tool names
    escaped_names = [re.escape(name) for name in tool_names]
    if not escaped_names:
        return re.compile(r'(?!)')  # empty negative lookahead: always fails
    # Join with | and add word boundaries
    pattern = r'\b(' + '|'.join(escaped_names) + r')\b'
    return re.compile(pattern)
def _split_top_level_statements(code: str) -> list:
    """Split ``code`` into the source text of each top-level statement."""
    import ast  # local import: only needed here

    lines = code.split('\n')
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # Unparseable input: fall back to one chunk per line
        return lines

    spans = []
    for node in tree.body:
        decorators = getattr(node, 'decorator_list', [])
        first = min([node.lineno] + [d.lineno for d in decorators])
        last = node.end_lineno
        if spans and first <= spans[-1][1]:
            # Shares a line with the previous statement (``a = 1; b = 2``)
            spans[-1][1] = max(spans[-1][1], last)
        else:
            spans.append([first, last])

    if not spans:
        # Only blank lines / comments: keep the text as a single chunk
        return [code]
    return ['\n'.join(lines[first - 1:last]) for first, last in spans]


def execute_code(code: str, tools: Dict[str, Any], work_dir: Path, interpreter):
    """Execute code with automatic switching between Docker and local.

    The code is split into top-level statements. A statement that mentions a
    tool name (or ``print``) is run with ``execute_locally``; consecutive
    other statements are batched and sent to ``interpreter.execute``. Routing
    whole statements keeps a tool call inside a loop, function or other
    multi-line construct together with its enclosing block. If the code
    cannot be parsed, routing falls back to line-by-line.

    Args:
        code: Source code to run.
        tools: Mapping of tool name to tool; merged over BASE_PYTHON_TOOLS
            before being handed to ``execute_locally``.
        work_dir: Directory used for state transfer.
        interpreter: Object exposing ``execute(code) -> (output, more)``.
    """
    tool_regex = create_tools_regex(list(tools.keys()) + ["print"])  # Added print for testing

    tools = {
        **BASE_PYTHON_TOOLS,
        **tools,
    }

    current_block = []

    def flush_docker_block():
        # Send the accumulated non-tool statements to the Docker interpreter
        if current_block:
            output, _ = interpreter.execute('\n'.join(current_block))
            print(output, end='')
            current_block.clear()

    for chunk in _split_top_level_statements(code):
        if tool_regex.search(chunk):
            # Execute accumulated Docker code first so ordering is preserved
            flush_docker_block()
            output = execute_locally(chunk, work_dir, tools)
            if output:
                print(output, end='')
        else:
            current_block.append(chunk)

    # Execute any remaining Docker code
    flush_docker_block()
__all__ = ["DockerPythonInterpreter", "execute_code"]