Auto correct wrong assignments to final_answer (#123)
* Auto correct wrong assignments to final_answer
This commit is contained in:
		
							parent
							
								
									e5d879feab
								
							
						
					
					
						commit
						d3cd0f9e09
					
				|  | @ -57,3 +57,6 @@ jobs: | ||||||
|       - name: Types tests |       - name: Types tests | ||||||
|         run: | |         run: | | ||||||
|           uv run pytest -sv ./tests/test_types.py |           uv run pytest -sv ./tests/test_types.py | ||||||
|  |       - name: Utils tests | ||||||
|  |         run: | | ||||||
|  |           uv run pytest -sv ./tests/test_utils.py | ||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							|  | @ -26,7 +26,11 @@ from rich.text import Text | ||||||
| 
 | 
 | ||||||
| from .default_tools import FinalAnswerTool | from .default_tools import FinalAnswerTool | ||||||
| from .e2b_executor import E2BExecutor | from .e2b_executor import E2BExecutor | ||||||
| from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter | from .local_python_executor import ( | ||||||
|  |     BASE_BUILTIN_MODULES, | ||||||
|  |     LocalPythonInterpreter, | ||||||
|  |     fix_final_answer_code, | ||||||
|  | ) | ||||||
| from .models import MessageRole | from .models import MessageRole | ||||||
| from .monitoring import Monitor | from .monitoring import Monitor | ||||||
| from .prompts import ( | from .prompts import ( | ||||||
|  | @ -895,7 +899,6 @@ class CodeAgent(MultiStepAgent): | ||||||
|             ) |             ) | ||||||
|             log_entry.llm_output = llm_output |             log_entry.llm_output = llm_output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             console.print_exception() |  | ||||||
|             raise AgentGenerationError(f"Error in generating model output:\n{e}") |             raise AgentGenerationError(f"Error in generating model output:\n{e}") | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|  | @ -917,10 +920,11 @@ class CodeAgent(MultiStepAgent): | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         try: |         try: | ||||||
|             code_action = parse_code_blob(llm_output) |             code_action = fix_final_answer_code(parse_code_blob(llm_output)) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             console.print_exception() |             error_msg = ( | ||||||
|             error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" |                 f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." | ||||||
|  |             ) | ||||||
|             raise AgentParsingError(error_msg) |             raise AgentParsingError(error_msg) | ||||||
| 
 | 
 | ||||||
|         log_entry.tool_call = ToolCall( |         log_entry.tool_call = ToolCall( | ||||||
|  | @ -944,8 +948,9 @@ class CodeAgent(MultiStepAgent): | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         observation = "" |         observation = "" | ||||||
|  |         is_final_answer = False | ||||||
|         try: |         try: | ||||||
|             output, execution_logs = self.python_executor( |             output, execution_logs, is_final_answer = self.python_executor( | ||||||
|                 code_action, |                 code_action, | ||||||
|                 self.state, |                 self.state, | ||||||
|             ) |             ) | ||||||
|  | @ -976,12 +981,6 @@ class CodeAgent(MultiStepAgent): | ||||||
|         observation += "Last output from code snippet:\n" + truncated_output |         observation += "Last output from code snippet:\n" + truncated_output | ||||||
|         log_entry.observations = observation |         log_entry.observations = observation | ||||||
| 
 | 
 | ||||||
|         is_final_answer = False |  | ||||||
|         for line in code_action.split("\n"): |  | ||||||
|             if line[: len("final_answer")] == "final_answer": |  | ||||||
|                 is_final_answer = True |  | ||||||
|                 break |  | ||||||
| 
 |  | ||||||
|         execution_outputs_console += [ |         execution_outputs_console += [ | ||||||
|             Text( |             Text( | ||||||
|                 f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}", |                 f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}", | ||||||
|  |  | ||||||
|  | @ -112,7 +112,7 @@ class PythonInterpreterTool(Tool): | ||||||
|                     state=state, |                     state=state, | ||||||
|                     static_tools=self.base_python_tools, |                     static_tools=self.base_python_tools, | ||||||
|                     authorized_imports=self.authorized_imports, |                     authorized_imports=self.authorized_imports, | ||||||
|                 ) |                 )[0]  # The second element is boolean is_final_answer | ||||||
|             ) |             ) | ||||||
|             return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" |             return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|  |  | ||||||
|  | @ -18,6 +18,7 @@ import ast | ||||||
| import builtins | import builtins | ||||||
| import difflib | import difflib | ||||||
| import math | import math | ||||||
|  | import re | ||||||
| from collections.abc import Mapping | from collections.abc import Mapping | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
| from typing import Any, Callable, Dict, List, Optional, Tuple | from typing import Any, Callable, Dict, List, Optional, Tuple | ||||||
|  | @ -129,6 +130,34 @@ def get_iterable(obj): | ||||||
|         raise InterpreterError("Object is not iterable") |         raise InterpreterError("Object is not iterable") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def fix_final_answer_code(code: str) -> str: | ||||||
|  |     """ | ||||||
|  |     Sometimes an LLM can try to assign a variable to final_answer, which would break the final_answer() tool. | ||||||
|  |     This function fixes this behaviour by replacing variable assignments to final_answer with final_answer_variable, | ||||||
|  |     while preserving function calls to final_answer(). | ||||||
|  |     """ | ||||||
|  |     # First, find if there's a direct assignment to final_answer | ||||||
|  |     # Use word boundary and negative lookbehind to ensure it's not an object attribute | ||||||
|  |     assignment_pattern = r"(?<!\.)(?<!\w)\bfinal_answer\s*=" | ||||||
|  |     if "final_answer(" not in code or not re.search(assignment_pattern, code): | ||||||
|  |         # If the final_answer tool is not called in this blob, then doing the replacement is hazardous because it could corrupt the model's memory for the next steps. | ||||||
|  |         # Let's not modify the code, and let the subsequent assignment error happen. | ||||||
|  |         return code | ||||||
|  | 
 | ||||||
|  |     # Pattern for replacing variable assignments | ||||||
|  |     # Looks for 'final_answer' followed by '=' with optional whitespace | ||||||
|  |     # Negative lookbehind ensures we don't match object attributes | ||||||
|  |     assignment_regex = r"(?<!\.)(?<!\w)(\bfinal_answer)(\s*=)" | ||||||
|  |     code = re.sub(assignment_regex, r"final_answer_variable\2", code) | ||||||
|  | 
 | ||||||
|  |     # Pattern for replacing variable usage but not function calls | ||||||
|  |     # Negative lookahead (?!\s*\() ensures we don't match function calls | ||||||
|  |     # Negative lookbehinds (?<!\.)(?<!\w) ensure we don't match object attributes or other variables | ||||||
|  |     variable_regex = r"(?<!\.)(?<!\w)(\bfinal_answer\b)(?!\s*\()" | ||||||
|  |     code = re.sub(variable_regex, "final_answer_variable", code) | ||||||
|  |     return code | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def evaluate_unaryop(expression, state, static_tools, custom_tools): | def evaluate_unaryop(expression, state, static_tools, custom_tools): | ||||||
|     operand = evaluate_ast(expression.operand, state, static_tools, custom_tools) |     operand = evaluate_ast(expression.operand, state, static_tools, custom_tools) | ||||||
|     if isinstance(expression.op, ast.USub): |     if isinstance(expression.op, ast.USub): | ||||||
|  | @ -224,6 +253,10 @@ def create_function(func_def, state, static_tools, custom_tools): | ||||||
|                 result = evaluate_ast(stmt, func_state, static_tools, custom_tools) |                 result = evaluate_ast(stmt, func_state, static_tools, custom_tools) | ||||||
|         except ReturnException as e: |         except ReturnException as e: | ||||||
|             result = e.value |             result = e.value | ||||||
|  | 
 | ||||||
|  |         if func_def.name == "__init__": | ||||||
|  |             return None | ||||||
|  | 
 | ||||||
|         return result |         return result | ||||||
| 
 | 
 | ||||||
|     return new_func |     return new_func | ||||||
|  | @ -484,41 +517,31 @@ def evaluate_call(call, state, static_tools, custom_tools): | ||||||
|         for keyword in call.keywords |         for keyword in call.keywords | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     if ( |     if func_name == "super": | ||||||
|         isinstance(func, type) and len(func.__module__.split(".")) > 1 |         if not args: | ||||||
|     ):  # Check for user-defined classes |             if "__class__" in state and "self" in state: | ||||||
|         # Instantiate the class using its constructor |                 return super(state["__class__"], state["self"]) | ||||||
|         obj = func.__new__(func)  # Create a new instance of the class |  | ||||||
|         if hasattr(obj, "__init__"):  # Check if the class has an __init__ method |  | ||||||
|             obj.__init__(*args, **kwargs)  # Call the __init__ method correctly |  | ||||||
|         return obj |  | ||||||
|     else: |  | ||||||
|         if func_name == "super": |  | ||||||
|             if not args: |  | ||||||
|                 if "__class__" in state and "self" in state: |  | ||||||
|                     return super(state["__class__"], state["self"]) |  | ||||||
|                 else: |  | ||||||
|                     raise InterpreterError("super() needs at least one argument") |  | ||||||
|             cls = args[0] |  | ||||||
|             if not isinstance(cls, type): |  | ||||||
|                 raise InterpreterError("super() argument 1 must be type") |  | ||||||
|             if len(args) == 1: |  | ||||||
|                 return super(cls) |  | ||||||
|             elif len(args) == 2: |  | ||||||
|                 instance = args[1] |  | ||||||
|                 return super(cls, instance) |  | ||||||
|             else: |             else: | ||||||
|                 raise InterpreterError("super() takes at most 2 arguments") |                 raise InterpreterError("super() needs at least one argument") | ||||||
|  |         cls = args[0] | ||||||
|  |         if not isinstance(cls, type): | ||||||
|  |             raise InterpreterError("super() argument 1 must be type") | ||||||
|  |         if len(args) == 1: | ||||||
|  |             return super(cls) | ||||||
|  |         elif len(args) == 2: | ||||||
|  |             instance = args[1] | ||||||
|  |             return super(cls, instance) | ||||||
|         else: |         else: | ||||||
|             if func_name == "print": |             raise InterpreterError("super() takes at most 2 arguments") | ||||||
|                 output = " ".join(map(str, args)) |     else: | ||||||
|                 global PRINT_OUTPUTS |         if func_name == "print": | ||||||
|                 PRINT_OUTPUTS += output + "\n" |             output = " ".join(map(str, args)) | ||||||
|                 # cap the number of lines |             global PRINT_OUTPUTS | ||||||
|                 return None |             PRINT_OUTPUTS += output + "\n" | ||||||
|             else:  # Assume it's a callable object |             # cap the number of lines | ||||||
|                 output = func(*args, **kwargs) |             return None | ||||||
|                 return output |         else:  # Assume it's a callable object | ||||||
|  |             return func(*args, **kwargs) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def evaluate_subscript(subscript, state, static_tools, custom_tools): | def evaluate_subscript(subscript, state, static_tools, custom_tools): | ||||||
|  | @ -990,6 +1013,11 @@ def truncate_print_outputs( | ||||||
|         return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n" |         return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class FinalAnswerException(Exception): | ||||||
|  |     def __init__(self, value): | ||||||
|  |         self.value = value | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def evaluate_python_code( | def evaluate_python_code( | ||||||
|     code: str, |     code: str, | ||||||
|     static_tools: Optional[Dict[str, Callable]] = None, |     static_tools: Optional[Dict[str, Callable]] = None, | ||||||
|  | @ -1029,6 +1057,12 @@ def evaluate_python_code( | ||||||
|     PRINT_OUTPUTS = "" |     PRINT_OUTPUTS = "" | ||||||
|     global OPERATIONS_COUNT |     global OPERATIONS_COUNT | ||||||
|     OPERATIONS_COUNT = 0 |     OPERATIONS_COUNT = 0 | ||||||
|  | 
 | ||||||
|  |     def final_answer(value): | ||||||
|  |         raise FinalAnswerException(value) | ||||||
|  | 
 | ||||||
|  |     static_tools["final_answer"] = final_answer | ||||||
|  | 
 | ||||||
|     try: |     try: | ||||||
|         for node in expression.body: |         for node in expression.body: | ||||||
|             result = evaluate_ast( |             result = evaluate_ast( | ||||||
|  | @ -1037,7 +1071,14 @@ def evaluate_python_code( | ||||||
|         state["print_outputs"] = truncate_content( |         state["print_outputs"] = truncate_content( | ||||||
|             PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT |             PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT | ||||||
|         ) |         ) | ||||||
|         return result |         is_final_answer = False | ||||||
|  |         return result, is_final_answer | ||||||
|  |     except FinalAnswerException as e: | ||||||
|  |         state["print_outputs"] = truncate_content( | ||||||
|  |             PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT | ||||||
|  |         ) | ||||||
|  |         is_final_answer = True | ||||||
|  |         return e.value, is_final_answer | ||||||
|     except InterpreterError as e: |     except InterpreterError as e: | ||||||
|         msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) |         msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) | ||||||
|         msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" |         msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" | ||||||
|  | @ -1059,9 +1100,11 @@ class LocalPythonInterpreter: | ||||||
|         } |         } | ||||||
|         # TODO: assert self.authorized imports are all installed locally |         # TODO: assert self.authorized imports are all installed locally | ||||||
| 
 | 
 | ||||||
|     def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]: |     def __call__( | ||||||
|  |         self, code_action: str, additional_variables: Dict | ||||||
|  |     ) -> Tuple[Any, str, bool]: | ||||||
|         self.state.update(additional_variables) |         self.state.update(additional_variables) | ||||||
|         output = evaluate_python_code( |         output, is_final_answer = evaluate_python_code( | ||||||
|             code_action, |             code_action, | ||||||
|             static_tools=self.static_tools, |             static_tools=self.static_tools, | ||||||
|             custom_tools=self.custom_tools, |             custom_tools=self.custom_tools, | ||||||
|  | @ -1069,7 +1112,7 @@ class LocalPythonInterpreter: | ||||||
|             authorized_imports=self.authorized_imports, |             authorized_imports=self.authorized_imports, | ||||||
|         ) |         ) | ||||||
|         logs = self.state["print_outputs"] |         logs = self.state["print_outputs"] | ||||||
|         return output, logs |         return output, logs, is_final_answer | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| __all__ = ["evaluate_python_code", "LocalPythonInterpreter"] | __all__ = ["evaluate_python_code", "LocalPythonInterpreter"] | ||||||
|  |  | ||||||
|  | @ -373,7 +373,7 @@ Here are the rules you should always follow to solve your task: | ||||||
| 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. | 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. | ||||||
| 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. | 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. | ||||||
| 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. | 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. | ||||||
| 7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables. | 7. Never create any notional variables in our code, as having these in your logs will derail you from the true variables. | ||||||
| 8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}} | 8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}} | ||||||
| 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. | 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. | ||||||
| 10. Don't give up! You're in charge of solving the task, not providing directions to solve it. | 10. Don't give up! You're in charge of solving the task, not providing directions to solve it. | ||||||
|  |  | ||||||
|  | @ -106,26 +106,35 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_code_blob(code_blob: str) -> str: | def parse_code_blob(code_blob: str) -> str: | ||||||
|     try: |     """Parses the LLM's output to get any code blob inside. Will return the code directly if it's code.""" | ||||||
|         pattern = r"```(?:py|python)?\n(.*?)\n```" |     pattern = r"```(?:py|python)?\n(.*?)\n```" | ||||||
|         match = re.search(pattern, code_blob, re.DOTALL) |     match = re.search(pattern, code_blob, re.DOTALL) | ||||||
|         if match is None: |     if match is None: | ||||||
|             raise ValueError( |         try:  # Maybe the LLM outputted a code blob directly | ||||||
|                 f"No match ground for regex pattern {pattern} in {code_blob=}." |             ast.parse(code_blob) | ||||||
|             ) |             return code_blob | ||||||
|         return match.group(1).strip() |         except SyntaxError: | ||||||
|  |             pass | ||||||
| 
 | 
 | ||||||
|     except Exception as e: |         if "final" in code_blob and "answer" in code_blob: | ||||||
|  |             raise ValueError( | ||||||
|  |                 f""" | ||||||
|  | The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. It seems like you're trying to return the final answer, you can do it as follows: | ||||||
|  | Code: | ||||||
|  | ```py | ||||||
|  | final_answer("YOUR FINAL ANSWER HERE") | ||||||
|  | ```<end_action>""".strip() | ||||||
|  |             ) | ||||||
|         raise ValueError( |         raise ValueError( | ||||||
|             f""" |             f""" | ||||||
| The code blob you used is invalid: due to the following error: {e} | The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. Make sure to include code with the correct pattern, for instance: | ||||||
| This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance: |  | ||||||
| Thoughts: Your thoughts | Thoughts: Your thoughts | ||||||
| Code: | Code: | ||||||
| ```py | ```py | ||||||
| # Your python code here | # Your python code here | ||||||
| ```<end_action>""" | ```<end_action>""".strip() | ||||||
|         ) |         ) | ||||||
|  |     return match.group(1).strip() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: | def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: | ||||||
|  |  | ||||||
|  | @ -444,3 +444,18 @@ final_answer("Final report.") | ||||||
| 
 | 
 | ||||||
|         report = manager_toolcalling_agent.run("Fake question.") |         report = manager_toolcalling_agent.run("Fake question.") | ||||||
|         assert report == "Final report." |         assert report == "Final report." | ||||||
|  | 
 | ||||||
|  |     def test_code_nontrivial_final_answer_works(self): | ||||||
|  |         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None): | ||||||
|  |             return """Code: | ||||||
|  | ```py | ||||||
|  | def nested_answer(): | ||||||
|  |     final_answer("Correct!") | ||||||
|  | 
 | ||||||
|  | nested_answer() | ||||||
|  | ```<end_code>""" | ||||||
|  | 
 | ||||||
|  |         agent = CodeAgent(tools=[], model=fake_code_model_final_answer) | ||||||
|  | 
 | ||||||
|  |         output = agent.run("Count to 3") | ||||||
|  |         assert output == "Correct!" | ||||||
|  |  | ||||||
|  | @ -23,6 +23,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS | ||||||
| from smolagents.local_python_executor import ( | from smolagents.local_python_executor import ( | ||||||
|     InterpreterError, |     InterpreterError, | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
|  |     fix_final_answer_code, | ||||||
| ) | ) | ||||||
| from smolagents.types import AGENT_TYPE_MAPPING | from smolagents.types import AGENT_TYPE_MAPPING | ||||||
| 
 | 
 | ||||||
|  | @ -79,19 +80,19 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_assign(self): |     def test_evaluate_assign(self): | ||||||
|         code = "x = 3" |         code = "x = 3" | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         assert result == 3 |         assert result == 3 | ||||||
|         self.assertDictEqual(state, {"x": 3, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|         code = "x = y" |         code = "x = y" | ||||||
|         state = {"y": 5} |         state = {"y": 5} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         # evaluate returns the value of the last assignment. |         # evaluate returns the value of the last assignment. | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
|         self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|         code = "a=1;b=None" |         code = "a=1;b=None" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         # evaluate returns the value of the last assignment. |         # evaluate returns the value of the last assignment. | ||||||
|         assert result is None |         assert result is None | ||||||
| 
 | 
 | ||||||
|  | @ -107,7 +108,7 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_call(self): |     def test_evaluate_call(self): | ||||||
|         code = "y = add_two(x)" |         code = "y = add_two(x)" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {"add_two": add_two}, state=state) |         result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
|         self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|  | @ -119,14 +120,14 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_constant(self): |     def test_evaluate_constant(self): | ||||||
|         code = "x = 3" |         code = "x = 3" | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         assert result == 3 |         assert result == 3 | ||||||
|         self.assertDictEqual(state, {"x": 3, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_dict(self): |     def test_evaluate_dict(self): | ||||||
|         code = "test_dict = {'x': x, 'y': add_two(x)}" |         code = "test_dict = {'x': x, 'y': add_two(x)}" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {"add_two": add_two}, state=state) |         result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) | ||||||
|         self.assertDictEqual(result, {"x": 3, "y": 5}) |         self.assertDictEqual(result, {"x": 3, "y": 5}) | ||||||
|         self.assertDictEqual( |         self.assertDictEqual( | ||||||
|             state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} |             state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} | ||||||
|  | @ -135,7 +136,7 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_expression(self): |     def test_evaluate_expression(self): | ||||||
|         code = "x = 3\ny = 5" |         code = "x = 3\ny = 5" | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         # evaluate returns the value of the last assignment. |         # evaluate returns the value of the last assignment. | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
|         self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) | ||||||
|  | @ -143,7 +144,7 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_f_string(self): |     def test_evaluate_f_string(self): | ||||||
|         code = "text = f'This is x: {x}.'" |         code = "text = f'This is x: {x}.'" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         # evaluate returns the value of the last assignment. |         # evaluate returns the value of the last assignment. | ||||||
|         assert result == "This is x: 3." |         assert result == "This is x: 3." | ||||||
|         self.assertDictEqual( |         self.assertDictEqual( | ||||||
|  | @ -153,13 +154,13 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_if(self): |     def test_evaluate_if(self): | ||||||
|         code = "if x <= 3:\n    y = 2\nelse:\n    y = 5" |         code = "if x <= 3:\n    y = 2\nelse:\n    y = 5" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         # evaluate returns the value of the last assignment. |         # evaluate returns the value of the last assignment. | ||||||
|         assert result == 2 |         assert result == 2 | ||||||
|         self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|         state = {"x": 8} |         state = {"x": 8} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         # evaluate returns the value of the last assignment. |         # evaluate returns the value of the last assignment. | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
|         self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""}) | ||||||
|  | @ -167,27 +168,27 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_list(self): |     def test_evaluate_list(self): | ||||||
|         code = "test_list = [x, add_two(x)]" |         code = "test_list = [x, add_two(x)]" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {"add_two": add_two}, state=state) |         result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) | ||||||
|         self.assertListEqual(result, [3, 5]) |         self.assertListEqual(result, [3, 5]) | ||||||
|         self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_name(self): |     def test_evaluate_name(self): | ||||||
|         code = "y = x" |         code = "y = x" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         assert result == 3 |         assert result == 3 | ||||||
|         self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_subscript(self): |     def test_evaluate_subscript(self): | ||||||
|         code = "test_list = [x, add_two(x)]\ntest_list[1]" |         code = "test_list = [x, add_two(x)]\ntest_list[1]" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {"add_two": add_two}, state=state) |         result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
|         self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|         code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']" |         code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']" | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {"add_two": add_two}, state=state) |         result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
|         self.assertDictEqual( |         self.assertDictEqual( | ||||||
|             state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} |             state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} | ||||||
|  | @ -215,14 +216,14 @@ for result in search_results: | ||||||
|     def test_evaluate_for(self): |     def test_evaluate_for(self): | ||||||
|         code = "x = 0\nfor i in range(3):\n    x = i" |         code = "x = 0\nfor i in range(3):\n    x = i" | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {"range": range}, state=state) |         result, _ = evaluate_python_code(code, {"range": range}, state=state) | ||||||
|         assert result == 2 |         assert result == 2 | ||||||
|         self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_binop(self): |     def test_evaluate_binop(self): | ||||||
|         code = "y + x" |         code = "y + x" | ||||||
|         state = {"x": 3, "y": 6} |         state = {"x": 3, "y": 6} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|         assert result == 9 |         assert result == 9 | ||||||
|         self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""}) |         self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""}) | ||||||
| 
 | 
 | ||||||
|  | @ -234,27 +235,27 @@ def recur_fibo(n): | ||||||
|     else: |     else: | ||||||
|         return(recur_fibo(n-1) + recur_fibo(n-2)) |         return(recur_fibo(n-1) + recur_fibo(n-2)) | ||||||
| recur_fibo(6)""" | recur_fibo(6)""" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result == 8 |         assert result == 8 | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_string_methods(self): |     def test_evaluate_string_methods(self): | ||||||
|         code = "'hello'.replace('h', 'o').split('e')" |         code = "'hello'.replace('h', 'o').split('e')" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result == ["o", "llo"] |         assert result == ["o", "llo"] | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_slicing(self): |     def test_evaluate_slicing(self): | ||||||
|         code = "'hello'[1:3][::-1]" |         code = "'hello'[1:3][::-1]" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result == "le" |         assert result == "le" | ||||||
| 
 | 
 | ||||||
|     def test_access_attributes(self): |     def test_access_attributes(self): | ||||||
|         code = "integer = 1\nobj_class = integer.__class__\nobj_class" |         code = "integer = 1\nobj_class = integer.__class__\nobj_class" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result is int |         assert result is int | ||||||
| 
 | 
 | ||||||
|     def test_list_comprehension(self): |     def test_list_comprehension(self): | ||||||
|         code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])" |         code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result == "t-h-e-s-e-a-g-u-l-l" |         assert result == "t-h-e-s-e-a-g-u-l-l" | ||||||
| 
 | 
 | ||||||
|     def test_string_indexing(self): |     def test_string_indexing(self): | ||||||
|  | @ -267,12 +268,12 @@ for block in text_block: | ||||||
|     for col in range(len(text_block[0])): |     for col in range(len(text_block[0])): | ||||||
|         sentence += block[col] |         sentence += block[col] | ||||||
|         """ |         """ | ||||||
|         result = evaluate_python_code(code, {"len": len, "range": range}, state={}) |         result, _ = evaluate_python_code(code, {"len": len, "range": range}, state={}) | ||||||
|         assert result == "THESEAGULL" |         assert result == "THESEAGULL" | ||||||
| 
 | 
 | ||||||
|     def test_tuples(self): |     def test_tuples(self): | ||||||
|         code = "x = (1, 2, 3)\nx[1]" |         code = "x = (1, 2, 3)\nx[1]" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result == 2 |         assert result == 2 | ||||||
| 
 | 
 | ||||||
|         code = """ |         code = """ | ||||||
|  | @ -325,35 +326,35 @@ print(check_digits) | ||||||
| 
 | 
 | ||||||
|     def test_listcomp(self): |     def test_listcomp(self): | ||||||
|         code = "x = [i for i in range(3)]" |         code = "x = [i for i in range(3)]" | ||||||
|         result = evaluate_python_code(code, {"range": range}, state={}) |         result, _ = evaluate_python_code(code, {"range": range}, state={}) | ||||||
|         assert result == [0, 1, 2] |         assert result == [0, 1, 2] | ||||||
| 
 | 
 | ||||||
|     def test_break_continue(self): |     def test_break_continue(self): | ||||||
|         code = "for i in range(10):\n    if i == 5:\n        break\ni" |         code = "for i in range(10):\n    if i == 5:\n        break\ni" | ||||||
|         result = evaluate_python_code(code, {"range": range}, state={}) |         result, _ = evaluate_python_code(code, {"range": range}, state={}) | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
| 
 | 
 | ||||||
|         code = "for i in range(10):\n    if i == 5:\n        continue\ni" |         code = "for i in range(10):\n    if i == 5:\n        continue\ni" | ||||||
|         result = evaluate_python_code(code, {"range": range}, state={}) |         result, _ = evaluate_python_code(code, {"range": range}, state={}) | ||||||
|         assert result == 9 |         assert result == 9 | ||||||
| 
 | 
 | ||||||
|     def test_call_int(self): |     def test_call_int(self): | ||||||
|         code = "import math\nstr(math.ceil(149))" |         code = "import math\nstr(math.ceil(149))" | ||||||
|         result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={}) |         result, _ = evaluate_python_code(code, {"str": lambda x: str(x)}, state={}) | ||||||
|         assert result == "149" |         assert result == "149" | ||||||
| 
 | 
 | ||||||
|     def test_lambda(self): |     def test_lambda(self): | ||||||
|         code = "f = lambda x: x + 2\nf(3)" |         code = "f = lambda x: x + 2\nf(3)" | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
| 
 | 
 | ||||||
|     def test_dictcomp(self): |     def test_dictcomp(self): | ||||||
|         code = "x = {i: i**2 for i in range(3)}" |         code = "x = {i: i**2 for i in range(3)}" | ||||||
|         result = evaluate_python_code(code, {"range": range}, state={}) |         result, _ = evaluate_python_code(code, {"range": range}, state={}) | ||||||
|         assert result == {0: 0, 1: 1, 2: 4} |         assert result == {0: 0, 1: 1, 2: 4} | ||||||
| 
 | 
 | ||||||
|         code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" |         code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, {"print": print}, state={}, authorized_imports=["pandas"] |             code, {"print": print}, state={}, authorized_imports=["pandas"] | ||||||
|         ) |         ) | ||||||
|         assert result == {102: "b"} |         assert result == {102: "b"} | ||||||
|  | @ -362,17 +363,17 @@ print(check_digits) | ||||||
| shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')} | shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')} | ||||||
| shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()} | shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()} | ||||||
| """ | """ | ||||||
|         result = evaluate_python_code(code, {}, state={}) |         result, _ = evaluate_python_code(code, {}, state={}) | ||||||
|         assert result == {"A": ("a", "b"), "B": ("a", "b")} |         assert result == {"A": ("a", "b"), "B": ("a", "b")} | ||||||
| 
 | 
 | ||||||
|     def test_tuple_assignment(self): |     def test_tuple_assignment(self): | ||||||
|         code = "a, b = 0, 1\nb" |         code = "a, b = 0, 1\nb" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == 1 |         assert result == 1 | ||||||
| 
 | 
 | ||||||
|     def test_while(self): |     def test_while(self): | ||||||
|         code = "i = 0\nwhile i < 3:\n    i += 1\ni" |         code = "i = 0\nwhile i < 3:\n    i += 1\ni" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == 3 |         assert result == 3 | ||||||
| 
 | 
 | ||||||
|         # test infinite loop |         # test infinite loop | ||||||
|  | @ -393,7 +394,7 @@ while i < n and house_positions[i] <= loc: | ||||||
| 
 | 
 | ||||||
|     def test_generator(self): |     def test_generator(self): | ||||||
|         code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)" |         code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == [1, 4, 9, 16, 25] |         assert result == [1, 4, 9, 16, 25] | ||||||
| 
 | 
 | ||||||
|     def test_boolops(self): |     def test_boolops(self): | ||||||
|  | @ -403,7 +404,7 @@ else: | ||||||
|     best_city = "Manhattan" |     best_city = "Manhattan" | ||||||
|     best_city |     best_city | ||||||
|     """ |     """ | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} |             code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} | ||||||
|         ) |         ) | ||||||
|         assert result == "Brooklyn" |         assert result == "Brooklyn" | ||||||
|  | @ -416,7 +417,7 @@ else: | ||||||
|     best_city = "Manhattan" |     best_city = "Manhattan" | ||||||
|     best_city |     best_city | ||||||
|     """ |     """ | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} |             code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} | ||||||
|         ) |         ) | ||||||
|         assert result == "Sacramento" |         assert result == "Sacramento" | ||||||
|  | @ -431,51 +432,51 @@ if char.isalpha(): | ||||||
| 
 | 
 | ||||||
|     def test_imports(self): |     def test_imports(self): | ||||||
|         code = "import math\nmath.sqrt(4)" |         code = "import math\nmath.sqrt(4)" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == 2.0 |         assert result == 2.0 | ||||||
| 
 | 
 | ||||||
|         code = ( |         code = ( | ||||||
|             "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" |             "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" | ||||||
|         ) |         ) | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == "lose" |         assert result == "lose" | ||||||
| 
 | 
 | ||||||
|         code = "import time, re\ntime.sleep(0.1)" |         code = "import time, re\ntime.sleep(0.1)" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result is None |         assert result is None | ||||||
| 
 | 
 | ||||||
|         code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()" |         code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == 1 |         assert result == 1 | ||||||
| 
 | 
 | ||||||
|         code = "import itertools\nlist(itertools.islice(range(10), 3))" |         code = "import itertools\nlist(itertools.islice(range(10), 3))" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == [0, 1, 2] |         assert result == [0, 1, 2] | ||||||
| 
 | 
 | ||||||
|         code = "import re\nre.search('a', 'abc').group()" |         code = "import re\nre.search('a', 'abc').group()" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == "a" |         assert result == "a" | ||||||
| 
 | 
 | ||||||
|         code = "import stat\nstat.S_ISREG(0o100644)" |         code = "import stat\nstat.S_ISREG(0o100644)" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result |         assert result | ||||||
| 
 | 
 | ||||||
|         code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])" |         code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == 2.8 |         assert result == 2.8 | ||||||
| 
 | 
 | ||||||
|         code = "import unicodedata\nunicodedata.name('A')" |         code = "import unicodedata\nunicodedata.name('A')" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == "LATIN CAPITAL LETTER A" |         assert result == "LATIN CAPITAL LETTER A" | ||||||
| 
 | 
 | ||||||
|         # Test submodules are handled properly, thus not raising error |         # Test submodules are handled properly, thus not raising error | ||||||
|         code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" |         code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] |             code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" |         code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] |             code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | @ -491,25 +492,25 @@ if char.isalpha(): | ||||||
| 
 | 
 | ||||||
|     def test_multiple_comparators(self): |     def test_multiple_comparators(self): | ||||||
|         code = "0 <= -1 < 4 and 0 <= -5 < 4" |         code = "0 <= -1 < 4 and 0 <= -5 < 4" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert not result |         assert not result | ||||||
| 
 | 
 | ||||||
|         code = "0 <= 1 < 4 and 0 <= -5 < 4" |         code = "0 <= 1 < 4 and 0 <= -5 < 4" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert not result |         assert not result | ||||||
| 
 | 
 | ||||||
|         code = "0 <= 4 < 4 and 0 <= 3 < 4" |         code = "0 <= 4 < 4 and 0 <= 3 < 4" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert not result |         assert not result | ||||||
| 
 | 
 | ||||||
|         code = "0 <= 3 < 4 and 0 <= 3 < 4" |         code = "0 <= 3 < 4 and 0 <= 3 < 4" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result |         assert result | ||||||
| 
 | 
 | ||||||
|     def test_print_output(self): |     def test_print_output(self): | ||||||
|         code = "print('Hello world!')\nprint('Ok no one cares')" |         code = "print('Hello world!')\nprint('Ok no one cares')" | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) | ||||||
|         assert result is None |         assert result is None | ||||||
|         assert state["print_outputs"] == "Hello world!\nOk no one cares\n" |         assert state["print_outputs"] == "Hello world!\nOk no one cares\n" | ||||||
| 
 | 
 | ||||||
|  | @ -525,7 +526,7 @@ function()""" | ||||||
| 
 | 
 | ||||||
|     def test_tuple_target_in_iterator(self): |     def test_tuple_target_in_iterator(self): | ||||||
|         code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]" |         code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == "Samuel" |         assert result == "Samuel" | ||||||
| 
 | 
 | ||||||
|     def test_classes(self): |     def test_classes(self): | ||||||
|  | @ -618,7 +619,7 @@ def var_args_method(self, *args, **kwargs): | ||||||
| var_args_method(1, 2, 3, x=4, y=5) | var_args_method(1, 2, 3, x=4, y=5) | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {"sum": sum}, state=state) |         result, _ = evaluate_python_code(code, {"sum": sum}, state=state) | ||||||
|         assert result == 15 |         assert result == 15 | ||||||
| 
 | 
 | ||||||
|     def test_exceptions(self): |     def test_exceptions(self): | ||||||
|  | @ -648,7 +649,7 @@ except ValueError as e: | ||||||
|     def test_types_as_objects(self): |     def test_types_as_objects(self): | ||||||
|         code = "type_a = float(2); type_b = str; type_c = int" |         code = "type_a = float(2); type_b = str; type_c = int" | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code( |         result, is_final_answer = evaluate_python_code( | ||||||
|             code, {"float": float, "str": str, "int": int}, state=state |             code, {"float": float, "str": str, "int": int}, state=state | ||||||
|         ) |         ) | ||||||
|         assert result is int |         assert result is int | ||||||
|  | @ -659,7 +660,7 @@ food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1} | ||||||
| unique_food_items = [item for item, count in food_item_counts.items() if count == 1] | unique_food_items = [item for item, count in food_item_counts.items() if count == 1] | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result, is_final_answer = evaluate_python_code(code, {}, state=state) | ||||||
|         assert result == ["orange", "pear"] |         assert result == ["orange", "pear"] | ||||||
| 
 | 
 | ||||||
|     def test_nonsimple_augassign(self): |     def test_nonsimple_augassign(self): | ||||||
|  | @ -742,8 +743,9 @@ def f(a, b=333, n=1000): | ||||||
|     return b + n |     return b + n | ||||||
| n = f(1, n=667) | n = f(1, n=667) | ||||||
| """ | """ | ||||||
|         res = evaluate_python_code(code, {}, {}) |         res, is_final_answer = evaluate_python_code(code, {}, {}) | ||||||
|         assert res == 1000 |         assert res == 1000 | ||||||
|  |         assert not is_final_answer | ||||||
| 
 | 
 | ||||||
|     def test_set(self): |     def test_set(self): | ||||||
|         code = """ |         code = """ | ||||||
|  | @ -767,8 +769,11 @@ while True: | ||||||
|         break |         break | ||||||
| 
 | 
 | ||||||
| i""" | i""" | ||||||
|         result = evaluate_python_code(code, {"print": print, "round": round}, state={}) |         result, is_final_answer = evaluate_python_code( | ||||||
|  |             code, {"print": print, "round": round}, state={} | ||||||
|  |         ) | ||||||
|         assert result == 3 |         assert result == 3 | ||||||
|  |         assert not is_final_answer | ||||||
| 
 | 
 | ||||||
|     def test_return(self): |     def test_return(self): | ||||||
|         # test early returns |         # test early returns | ||||||
|  | @ -781,7 +786,7 @@ def add_one(n, shift): | ||||||
| add_one(1, 1) | add_one(1, 1) | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code( |         result, is_final_answer = evaluate_python_code( | ||||||
|             code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state |             code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state | ||||||
|         ) |         ) | ||||||
|         assert result == 2 |         assert result == 2 | ||||||
|  | @ -794,7 +799,7 @@ def returns_none(a): | ||||||
| returns_none(1) | returns_none(1) | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code( |         result, is_final_answer = evaluate_python_code( | ||||||
|             code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state |             code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state | ||||||
|         ) |         ) | ||||||
|         assert result is None |         assert result is None | ||||||
|  | @ -812,7 +817,7 @@ out = [i for sublist in all_res for i in sublist] | ||||||
| out[:10] | out[:10] | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code( |         result, is_final_answer = evaluate_python_code( | ||||||
|             code, {"print": print, "range": range}, state=state |             code, {"print": print, "range": range}, state=state | ||||||
|         ) |         ) | ||||||
|         assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] |         assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] | ||||||
|  | @ -829,7 +834,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0] | ||||||
| parts_with_5_set_count[['Quantity', 'SetCount']].values[1] | parts_with_5_set_count[['Quantity', 'SetCount']].values[1] | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, {}, state=state, authorized_imports=["pandas"] |             code, {}, state=state, authorized_imports=["pandas"] | ||||||
|         ) |         ) | ||||||
|         assert np.array_equal(result, [-1, 5]) |         assert np.array_equal(result, [-1, 5]) | ||||||
|  | @ -842,7 +847,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]}) | ||||||
| # Filter the DataFrame to get only the rows with outdated atomic numbers | # Filter the DataFrame to get only the rows with outdated atomic numbers | ||||||
| filtered_df = df.loc[df['AtomicNumber'].isin([104])] | filtered_df = df.loc[df['AtomicNumber'].isin([104])] | ||||||
| """ | """ | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, {"print": print}, state={}, authorized_imports=["pandas"] |             code, {"print": print}, state={}, authorized_imports=["pandas"] | ||||||
|         ) |         ) | ||||||
|         assert np.array_equal(result.values[0], [104, 1]) |         assert np.array_equal(result.values[0], [104, 1]) | ||||||
|  | @ -855,7 +860,9 @@ data = pd.DataFrame.from_dict([ | ||||||
| ]) | ]) | ||||||
| survival_rate_by_class = data.groupby('Pclass')['Survived'].mean() | survival_rate_by_class = data.groupby('Pclass')['Survived'].mean() | ||||||
| """ | """ | ||||||
|         result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"]) |         result, _ = evaluate_python_code( | ||||||
|  |             code, {}, state={}, authorized_imports=["pandas"] | ||||||
|  |         ) | ||||||
|         assert result.values[1] == 0.5 |         assert result.values[1] == 0.5 | ||||||
| 
 | 
 | ||||||
|     def test_starred(self): |     def test_starred(self): | ||||||
|  | @ -877,7 +884,7 @@ coords_barcelona = (41.3869, 2.1660) | ||||||
| 
 | 
 | ||||||
| distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) | distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) | ||||||
| """ | """ | ||||||
|         result = evaluate_python_code( |         result, _ = evaluate_python_code( | ||||||
|             code, {"print": print, "map": map}, state={}, authorized_imports=["math"] |             code, {"print": print, "map": map}, state={}, authorized_imports=["math"] | ||||||
|         ) |         ) | ||||||
|         assert round(result, 1) == 622395.4 |         assert round(result, 1) == 622395.4 | ||||||
|  | @ -894,5 +901,42 @@ for worker, (start, end) in shifts.items(): | ||||||
|     shift_intervals[worker] = end |     shift_intervals[worker] = end | ||||||
| shift_intervals | shift_intervals | ||||||
| """ | """ | ||||||
|         result = evaluate_python_code(code, {"print": print, "map": map}, state={}) |         result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}) | ||||||
|         assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"} |         assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"} | ||||||
|  | 
 | ||||||
|  |     def test_fix_final_answer_code(self): | ||||||
|  |         test_cases = [ | ||||||
|  |             ( | ||||||
|  |                 "final_answer = 3.21\nfinal_answer(final_answer)", | ||||||
|  |                 "final_answer_variable = 3.21\nfinal_answer(final_answer_variable)", | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 "x = final_answer(5)\nfinal_answer = x + 1\nfinal_answer(final_answer)", | ||||||
|  |                 "x = final_answer(5)\nfinal_answer_variable = x + 1\nfinal_answer(final_answer_variable)", | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 "def func():\n    final_answer = 42\n    return final_answer(final_answer)", | ||||||
|  |                 "def func():\n    final_answer_variable = 42\n    return final_answer(final_answer_variable)", | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 "final_answer(5)  # Should not change function calls", | ||||||
|  |                 "final_answer(5)  # Should not change function calls", | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 "obj.final_answer = 5  # Should not change object attributes", | ||||||
|  |                 "obj.final_answer = 5  # Should not change object attributes", | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 "final_answer=3.21;final_answer(final_answer)", | ||||||
|  |                 "final_answer_variable=3.21;final_answer(final_answer_variable)", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         for i, (input_code, expected) in enumerate(test_cases, 1): | ||||||
|  |             result = fix_final_answer_code(input_code) | ||||||
|  |             assert result == expected, f""" | ||||||
|  |     Test case {i} failed: | ||||||
|  |     Input:    {input_code} | ||||||
|  |     Expected: {expected} | ||||||
|  |     Got:      {result} | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  | @ -1,86 +1,39 @@ | ||||||
| import os | # coding=utf-8 | ||||||
| import shutil | # Copyright 2024 HuggingFace Inc. | ||||||
| import tempfile | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
| import unittest | import unittest | ||||||
| from pathlib import Path | import pytest | ||||||
|  | 
 | ||||||
|  | from smolagents.utils import parse_code_blob | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def str_to_bool(value) -> int: | class AgentTextTests(unittest.TestCase): | ||||||
|     """ |     def test_parse_code_blob(self): | ||||||
|     Converts a string representation of truth to `True` (1) or `False` (0). |         with pytest.raises(ValueError): | ||||||
|  |             parse_code_blob("Wrong blob!") | ||||||
| 
 | 
 | ||||||
|     True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`; |         # Parsing markdown with code blobs should work | ||||||
|     """ |         output = parse_code_blob(""" | ||||||
|     value = value.lower() | Here is how to solve the problem: | ||||||
|     if value in ("y", "yes", "t", "true", "on", "1"): | Code: | ||||||
|         return 1 | ```py | ||||||
|     elif value in ("n", "no", "f", "false", "off", "0"): | import numpy as np | ||||||
|         return 0 | ```<end_code> | ||||||
|     else: | """) | ||||||
|         raise ValueError(f"invalid truth value {value}") |         assert output == "import numpy as np" | ||||||
| 
 | 
 | ||||||
| 
 |         # Parsing code blobs should work | ||||||
| def get_int_from_env(env_keys, default): |         code_blob = "import numpy as np" | ||||||
|     """Returns the first positive env value found in the `env_keys` list or the default.""" |         output = parse_code_blob(code_blob) | ||||||
|     for e in env_keys: |         assert output == code_blob | ||||||
|         val = int(os.environ.get(e, -1)) |  | ||||||
|         if val >= 0: |  | ||||||
|             return val |  | ||||||
|     return default |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def parse_flag_from_env(key, default=False): |  | ||||||
|     """Returns truthy value for `key` from the env if available else the default.""" |  | ||||||
|     value = os.environ.get(key, str(default)) |  | ||||||
|     return ( |  | ||||||
|         str_to_bool(value) == 1 |  | ||||||
|     )  # As its name indicates `str_to_bool` actually returns an int... |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def skip(test_case): |  | ||||||
|     "Decorator that skips a test unconditionally" |  | ||||||
|     return unittest.skip("Test was skipped")(test_case) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def slow(test_case): |  | ||||||
|     """ |  | ||||||
|     Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a |  | ||||||
|     truthy value to run them. |  | ||||||
|     """ |  | ||||||
|     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class TempDirTestCase(unittest.TestCase): |  | ||||||
|     """ |  | ||||||
|     A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its |  | ||||||
|     data at the start of a test, and then destroys it at the end of the TestCase. |  | ||||||
| 
 |  | ||||||
|     Useful for when a class or API requires a single constant folder throughout its use, such as Weights and Biases |  | ||||||
| 
 |  | ||||||
|     The temporary directory location will be stored in `self.tmpdir` |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     clear_on_setup = True |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def setUpClass(cls): |  | ||||||
|         "Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`" |  | ||||||
|         cls.tmpdir = Path(tempfile.mkdtemp()) |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def tearDownClass(cls): |  | ||||||
|         "Remove `cls.tmpdir` after test suite has finished" |  | ||||||
|         if os.path.exists(cls.tmpdir): |  | ||||||
|             shutil.rmtree(cls.tmpdir) |  | ||||||
| 
 |  | ||||||
|     def setUp(self): |  | ||||||
|         "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`" |  | ||||||
|         if self.clear_on_setup: |  | ||||||
|             for path in self.tmpdir.glob("**/*"): |  | ||||||
|                 if path.is_file(): |  | ||||||
|                     path.unlink() |  | ||||||
|                 elif path.is_dir(): |  | ||||||
|                     shutil.rmtree(path) |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue