Add support for non-bool comparison operators. (#612)
This commit is contained in:
		
							parent
							
								
									9b96199d00
								
							
						
					
					
						commit
						5fd0a2e86e
					
				|  | @ -713,7 +713,7 @@ def evaluate_condition( | ||||||
|     static_tools: Dict[str, Callable], |     static_tools: Dict[str, Callable], | ||||||
|     custom_tools: Dict[str, Callable], |     custom_tools: Dict[str, Callable], | ||||||
|     authorized_imports: List[str], |     authorized_imports: List[str], | ||||||
| ) -> bool: | ) -> bool | object: | ||||||
|     left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) |     left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) | ||||||
|     comparators = [ |     comparators = [ | ||||||
|         evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators |         evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators | ||||||
|  | @ -747,6 +747,9 @@ def evaluate_condition( | ||||||
|         else: |         else: | ||||||
|             raise InterpreterError(f"Operator not supported: {op}") |             raise InterpreterError(f"Operator not supported: {op}") | ||||||
| 
 | 
 | ||||||
|  |         if not isinstance(current_result, bool): | ||||||
|  |             return current_result | ||||||
|  | 
 | ||||||
|         result = result & current_result |         result = result & current_result | ||||||
|         current_left = comparator |         current_left = comparator | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1181,6 +1181,31 @@ def test_get_safe_module_handle_lazy_imports(): | ||||||
|     assert getattr(safe_module, "non_lazy_attribute") == "ok" |     assert getattr(safe_module, "non_lazy_attribute") == "ok" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_non_standard_comparisons(): | ||||||
|  |     code = """ | ||||||
|  | class NonStdEqualsResult: | ||||||
|  |     def __init__(self, left:object, right:object): | ||||||
|  |         self._left = left | ||||||
|  |         self._right = right | ||||||
|  |     def __str__(self) -> str: | ||||||
|  |         return f'{self._left}=={self._right}' | ||||||
|  | 
 | ||||||
|  | class NonStdComparisonClass: | ||||||
|  |     def __init__(self, value: str ): | ||||||
|  |         self._value = value | ||||||
|  |     def __str__(self): | ||||||
|  |         return self._value | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return NonStdEqualsResult(self, other) | ||||||
|  | a = NonStdComparisonClass("a") | ||||||
|  | b = NonStdComparisonClass("b") | ||||||
|  | result = a == b | ||||||
|  |     """ | ||||||
|  |     result, _ = evaluate_python_code(code, state={}) | ||||||
|  |     assert not isinstance(result, bool) | ||||||
|  |     assert str(result) == "a==b" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class TestPrintContainer: | class TestPrintContainer: | ||||||
|     def test_initial_value(self): |     def test_initial_value(self): | ||||||
|         pc = PrintContainer() |         pc = PrintContainer() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue