fix final_answer issue in e2b_executor (#319)
This commit is contained in:
		
							parent
							
								
									2105811da6
								
							
						
					
					
						commit
						33b38e6cb7
					
				|  | @ -14,6 +14,7 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | import re | ||||||
| import base64 | import base64 | ||||||
| import pickle | import pickle | ||||||
| import textwrap | import textwrap | ||||||
|  | @ -45,6 +46,8 @@ class E2BExecutor: | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         self.custom_tools = {} |         self.custom_tools = {} | ||||||
|  |         self.final_answer = False | ||||||
|  |         self.final_answer_pattern = re.compile(r'^final_answer\((.*)\)$') | ||||||
|         self.sbx = Sandbox()  # "qywp2ctmu2q7jzprcf4j") |         self.sbx = Sandbox()  # "qywp2ctmu2q7jzprcf4j") | ||||||
|         # TODO: validate installing agents package or not |         # TODO: validate installing agents package or not | ||||||
|         # print("Installing agents package on remote executor...") |         # print("Installing agents package on remote executor...") | ||||||
|  | @ -85,6 +88,8 @@ class E2BExecutor: | ||||||
|         self.logger.log(tool_definition_execution.logs) |         self.logger.log(tool_definition_execution.logs) | ||||||
| 
 | 
 | ||||||
|     def run_code_raise_errors(self, code: str): |     def run_code_raise_errors(self, code: str): | ||||||
|  |         if self.final_answer_pattern.match(code): | ||||||
|  |             self.final_answer = True | ||||||
|         execution = self.sbx.run_code( |         execution = self.sbx.run_code( | ||||||
|             code, |             code, | ||||||
|         ) |         ) | ||||||
|  | @ -122,7 +127,7 @@ locals().update({key: value for key, value in pickle_dict.items()}) | ||||||
|         execution = self.run_code_raise_errors(code_action) |         execution = self.run_code_raise_errors(code_action) | ||||||
|         execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) |         execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) | ||||||
|         if not execution.results: |         if not execution.results: | ||||||
|             return None, execution_logs |             return None, execution_logs, self.final_answer | ||||||
|         else: |         else: | ||||||
|             for result in execution.results: |             for result in execution.results: | ||||||
|                 if result.is_main_result: |                 if result.is_main_result: | ||||||
|  | @ -144,7 +149,7 @@ locals().update({key: value for key, value in pickle_dict.items()}) | ||||||
|                         "text", |                         "text", | ||||||
|                     ]: |                     ]: | ||||||
|                         if getattr(result, attribute_name) is not None: |                         if getattr(result, attribute_name) is not None: | ||||||
|                             return getattr(result, attribute_name), execution_logs |                             return getattr(result, attribute_name), execution_logs, self.final_answer | ||||||
|             raise ValueError("No main result returned by executor!") |             raise ValueError("No main result returned by executor!") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -25,6 +25,9 @@ class MethodChecker(ast.NodeVisitor): | ||||||
|         self.class_attributes = class_attributes |         self.class_attributes = class_attributes | ||||||
|         self.errors = [] |         self.errors = [] | ||||||
|         self.check_imports = check_imports |         self.check_imports = check_imports | ||||||
|  |         self.typing_names = { | ||||||
|  |             'Any' | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|     def visit_arguments(self, node): |     def visit_arguments(self, node): | ||||||
|         """Collect function arguments""" |         """Collect function arguments""" | ||||||
|  | @ -97,6 +100,7 @@ class MethodChecker(ast.NodeVisitor): | ||||||
|                 or node.id in self.imports |                 or node.id in self.imports | ||||||
|                 or node.id in self.from_imports |                 or node.id in self.from_imports | ||||||
|                 or node.id in self.assigned_names |                 or node.id in self.assigned_names | ||||||
|  |                 or node.id in self.typing_names | ||||||
|             ): |             ): | ||||||
|                 self.errors.append(f"Name '{node.id}' is undefined.") |                 self.errors.append(f"Name '{node.id}' is undefined.") | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue